# Renamed synth_data_gen.py to synth-data-gen.py to avoid confusing the markdown engine.
# %%
|
|
import math
|
|
import os
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from tqdm import tqdm
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# %%
|
|
@dataclass
class Config:
    """Hyperparameters for the two-modality time-series VAE experiment.

    Was a plain class with annotated attributes; `dataclass` was imported
    but unused, so this is now a proper dataclass (same defaults, plus a
    generated __init__/__repr__ so fields can be overridden per run).
    """

    # --- Data ---
    n_train: int = 8000       # number of training sequences
    n_val: int = 1000         # number of validation sequences
    T: int = 120              # timesteps for modality A
    F: int = 6                # features for modality A
    noise_std: float = 0.05   # observation noise std (both modalities)

    # --- Second modality ---
    F2: int = 6               # features for modality B
    T2_mult: int = 5          # modality-2 has 5x timesteps
    # NOTE: computed once at class-definition time from the defaults above;
    # overriding T in __init__ does NOT update T2 (same as the old class).
    T2: int = T * T2_mult     # 600

    # --- Model ---
    hidden_size: int = 32
    latent_dim: int = 32
    num_layers: int = 1

    # --- Training ---
    batch_size: int = 128
    lr: float = 1e-3
    epochs: int = 30
    kl_warmup_epochs: int = 30
    grad_clip: float = 1.0

    # --- Optional recon weights ---
    w1: float = 1.0
    w2: float = 1.0

    # --- System ---
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    out_dir: str = "runs/ts_vae_2modal"
|
|
|
|
|
# %%
|
|
def triangle_wave(phase):
    """Unit-amplitude triangle wave.

    Args:
        phase: angle in radians.

    Returns:
        Triangle-wave value in [-1, 1].
    """
    # Position within the current cycle, in [0, 1).
    cycle_pos = (phase / (2 * math.pi)) % 1.0
    # Distance from mid-cycle is in [0, 0.5]; scale to [-1, 1].
    peak_dist = abs(cycle_pos - 0.5)
    return 4.0 * peak_dist - 1.0
|
|
|
|
def sawtooth_wave(phase):
    """Sawtooth wave rising linearly over each cycle.

    Args:
        phase: angle in radians.

    Returns:
        Sawtooth value in [-1, 1).
    """
    # Fractional cycle position in [0, 1); Python's % keeps it non-negative.
    cycle_pos = (phase / (2 * math.pi)) % 1.0
    # Map [0, 1) linearly onto [-1, 1).
    return cycle_pos + cycle_pos - 1.0
|
|
|
|
def square_wave(phase):
    """Square wave taking values in {-1.0, +1.0}.

    Args:
        phase: angle in radians.

    Returns:
        +1.0 where sin(phase) >= 0, else -1.0.
    """
    # Sign of the underlying sinusoid decides the level; the boundary
    # (sin == 0) maps to +1.0, matching the >= convention.
    return -1.0 if math.sin(phase) < 0.0 else 1.0
|
|
|
|
|
|
class TwoModalSyntheticTimeSeries(Dataset):
    """
    Synthetic, correlated two-modality time-series dataset.

    Returns per item:
        x1: [T, F1]  (low-rate modality, sinusoids + harmonic)
        x2: [5T, F2] (high-rate modality, triangle/sawtooth/square)
    Both span the same time interval [-1, 1] but x2 is sampled 5x more densely.
    The two modalities share per-sample amplitude/phase/trend/bias so they
    are correlated.

    NOTE: Gaussian observation noise is drawn freshly on every __getitem__
    call, so repeated access to the same index yields different noise.

    Perf fix: modality-B waveforms were built with a per-timestep Python
    loop (list append over `t2.tolist()`); `_wave2` is now vectorized with
    torch ops, computed in float64 to match the old scalar double math.
    """

    def __init__(
        self,
        n: int,
        T: int,
        F1: int,
        F2: int,
        noise_std1: float = 0.05,
        noise_std2: float = 0.05,
        seed: int = 0,
        shape2: str = "triangle",  # "triangle" | "sawtooth" | "square"
        freq2_mult: float = 2.0,   # shape frequency multiplier within modality B
    ):
        super().__init__()
        self.n = n
        self.T = T
        self.T2 = 5 * T
        self.F1 = F1
        self.F2 = F2
        self.noise_std1 = noise_std1
        self.noise_std2 = noise_std2
        self.shape2 = shape2
        self.freq2_mult = freq2_mult

        rng = np.random.RandomState(seed)

        # Per-sample latent factors shared across both modalities (for correlation).
        # Draw order must stay fixed so a given seed reproduces the same data.
        self.amp = rng.uniform(0.6, 1.4, size=n).astype(np.float32)    # amplitude
        self.phase = rng.uniform(-math.pi, math.pi, size=n).astype(np.float32)
        self.trend = rng.uniform(-0.3, 0.3, size=n).astype(np.float32)  # linear trend
        self.bias = rng.uniform(-0.2, 0.2, size=n).astype(np.float32)   # offset

        # Feature-wise frequencies (modality A)
        self.freq1 = rng.uniform(0.6, 1.6, size=F1).astype(np.float32)

        # Feature-wise frequencies (modality B), generally higher
        self.freq2 = rng.uniform(1.5, 3.5, size=F2).astype(np.float32) * float(freq2_mult)

    def __len__(self):
        return self.n

    def _wave2(self, phase: torch.Tensor) -> torch.Tensor:
        """Modality-B waveform, applied elementwise to a tensor of phases.

        Args:
            phase: tensor of phases in radians (any shape).

        Returns:
            Same-shaped tensor of waveform values in [-1, 1].
        """
        if self.shape2 == "triangle":
            # Fractional cycle position in [0, 1); triangle = 4|u - 0.5| - 1.
            u = (phase / (2.0 * math.pi)) % 1.0
            return 4.0 * torch.abs(u - 0.5) - 1.0
        if self.shape2 == "sawtooth":
            u = (phase / (2.0 * math.pi)) % 1.0
            return 2.0 * u - 1.0
        if self.shape2 == "square":
            # +1 where sin >= 0, else -1 (matches scalar square_wave).
            s = torch.sin(phase)
            return torch.where(s >= 0.0, torch.ones_like(s), -torch.ones_like(s))
        raise ValueError(f"Unknown shape2={self.shape2}")

    def __getitem__(self, idx):
        # Time grids: same span, different sampling density
        t1 = torch.linspace(-1.0, 1.0, steps=self.T, dtype=torch.float32)   # [T]
        t2 = torch.linspace(-1.0, 1.0, steps=self.T2, dtype=torch.float32)  # [5T]

        amp = float(self.amp[idx])
        ph = float(self.phase[idx])
        tr = float(self.trend[idx])
        b = float(self.bias[idx])

        # --------------------
        # Modality A (sinusoids)
        # --------------------
        x1 = torch.zeros(self.T, self.F1, dtype=torch.float32)
        for f in range(self.F1):
            w = float(self.freq1[f]) * 2.0 * math.pi
            # correlated structure: shared amp/phase/trend/bias
            x1[:, f] = (
                amp * torch.sin(w * t1 + ph)
                + 0.3 * amp * torch.sin(2.0 * w * t1 + 0.5 * ph)  # add harmonic
                + tr * t1
                + b
            )
        x1 = x1 + self.noise_std1 * torch.randn_like(x1)

        # --------------------
        # Modality B (different shape, higher sampling rate)
        # --------------------
        x2 = torch.zeros(self.T2, self.F2, dtype=torch.float32)
        # float64 phases reproduce the old per-scalar double-precision math.
        t2_64 = t2.to(torch.float64)
        for f in range(self.F2):
            w = float(self.freq2[f]) * 2.0 * math.pi
            # Same shared amp/phase/trend/bias so modalities correlate;
            # small deterministic shift (+0.4) vs modality A.
            wave = self._wave2(w * t2_64 + (ph + 0.4)).to(torch.float32)
            x2[:, f] = (
                amp * wave
                + 0.15 * amp * torch.sin(3.0 * w * t2 + ph)  # slight extra structure
                + tr * t2
                + b
            )
        x2 = x2 + self.noise_std2 * torch.randn_like(x2)

        return x1, x2
|
|
|
|
|
|
# %%
|
|
class EncoderGRU(nn.Module):
    """GRU encoder summarizing a sequence [B, T, F] into a fixed vector.

    The summary is the final hidden state of the last GRU layer; when
    bidirectional, the two directions are concatenated, giving
    [B, dir_mult * hidden_size].
    """

    def __init__(self, input_dim: int, hidden_size: int, num_layers: int, bidirectional: bool = False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dir_mult = 2 if bidirectional else 1

        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout is only meaningful with stacked layers.
            dropout=0.1 if num_layers > 1 else 0.0,
        )

    def forward(self, x):
        """x: [B, T, F] -> embedding [B, dir_mult * hidden_size]."""
        batch = x.size(0)
        _, hidden = self.gru(x)  # [num_layers * dir_mult, B, hidden]
        # The hidden stack is ordered layer-major, so the trailing dir_mult
        # entries belong to the last layer.
        last_layer = hidden[-self.dir_mult:]  # [dir_mult, B, hidden]
        return last_layer.transpose(0, 1).reshape(batch, self.dir_mult * self.hidden_size)
|
|
|
|
|
|
# %%
|
|
class DecoderGRUTime(nn.Module):
    """GRU decoder conditioned on a latent vector plus a time channel.

    At every step the GRU input is [z, t] with t running linearly from -1
    to 1 over T steps, and the initial hidden state is a learned projection
    of z. Outputs the per-step reconstruction mean [B, T, output_dim].
    """

    def __init__(self, latent_dim: int, hidden_size: int, num_layers: int, output_dim: int):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = output_dim

        # Maps z to one initial hidden state per GRU layer.
        self.z_to_h0 = nn.Linear(latent_dim, num_layers * hidden_size)

        self.gru = nn.GRU(
            input_size=latent_dim + 1,  # z plus time channel
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        self.to_out = nn.Linear(hidden_size, output_dim)

    def forward(self, z, T: int):
        """z: [B, latent_dim] -> reconstruction means [B, T, output_dim].

        Raises:
            RuntimeError: if z's feature dimension does not match latent_dim.
        """
        batch, z_dim = z.shape
        if z_dim != self.latent_dim:
            raise RuntimeError(f"Decoder expected z dim {self.latent_dim}, got {z_dim}")

        # Broadcast z over time and append the normalized time channel.
        time_axis = torch.linspace(-1.0, 1.0, steps=T, device=z.device, dtype=z.dtype)
        step_inputs = torch.cat(
            [
                z.unsqueeze(1).expand(batch, T, z_dim),       # [B, T, Z]
                time_axis.view(1, T, 1).expand(batch, T, 1),  # [B, T, 1]
            ],
            dim=-1,
        )

        # Learned initial hidden state: [num_layers, B, hidden].
        h0 = self.z_to_h0(z).view(self.num_layers, batch, self.hidden_size).contiguous()

        out_seq, _ = self.gru(step_inputs, h0)
        return self.to_out(out_seq)  # [B, T, output_dim]
|
|
|
|
# %%
|
|
class TwoModalTimeSeriesVAE(nn.Module):
    """VAE over two time-series modalities sharing a single latent z.

    Each modality gets its own GRU encoder; the two embeddings are fused by
    an MLP into the posterior q(z | x1, x2), and two time-conditioned GRU
    decoders reconstruct each modality's mean from the same z.
    """

    def __init__(self, F1: int, F2: int, hidden_size: int, num_layers: int, latent_dim: int, bidi: bool = False):
        super().__init__()
        self.latent_dim = latent_dim

        # Per-modality encoders.
        self.enc1 = EncoderGRU(F1, hidden_size, num_layers, bidirectional=bidi)
        self.enc2 = EncoderGRU(F2, hidden_size, num_layers, bidirectional=bidi)

        enc_out_dim = (2 if bidi else 1) * hidden_size
        fusion_in = enc_out_dim + enc_out_dim

        # Fusion MLP over the concatenated embeddings.
        self.fuse = nn.Sequential(
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
        )
        self.to_mu = nn.Linear(fusion_in, latent_dim)
        self.to_logvar = nn.Linear(fusion_in, latent_dim)

        # Per-modality decoders from the shared latent.
        self.dec1 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F1)
        self.dec2 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F2)

    def encode(self, x1, x2):
        """Fuse the two modality embeddings into posterior params (mu, logvar)."""
        fused = self.fuse(torch.cat([self.enc1(x1), self.enc2(x2)], dim=-1))
        return self.to_mu(fused), self.to_logvar(fused)

    def forward(self, x1, x2, sample: bool = True, eps_scale: float = 1.0):
        """x1: [B, T1, F1], x2: [B, T2, F2] -> (mu_x1, mu_x2, mu_z, logvar_z).

        With sample=False (or eps_scale=0) the decode uses z = mu_z.
        """
        mu_z, logvar_z = self.encode(x1, x2)
        z = reparameterize(mu_z, logvar_z, eps_scale=eps_scale) if sample else mu_z

        mu_x1 = self.dec1(z, x1.size(1))  # [B, T1, F1]
        mu_x2 = self.dec2(z, x2.size(1))  # [B, T2, F2]
        return mu_x1, mu_x2, mu_z, logvar_z
|
|
|
|
|
|
# %%
|
|
# -----------------------------
|
|
# Losses
|
|
# -----------------------------
|
|
def reparameterize(mu, logvar, eps_scale: float = 1.0):
    """Draw z ~ N(mu, diag(exp(logvar))) via the reparameterization trick.

    Args:
        mu, logvar: [B, Z] posterior parameters.
        eps_scale: scales the noise; 0.0 makes the draw deterministic (z = mu).
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma) * eps_scale
    return mu + noise * sigma
|
|
|
|
def kl_diag_gaussian(mu, logvar, free_bits: float = 0.0):
    """Per-sample KL(q || p) with p = N(0, I) for a diagonal Gaussian q.

    Args:
        mu, logvar: [B, Z].
        free_bits: if > 0, clamp each sample's KL to at least
            free_bits * Z (the "free bits" trick).

    Returns:
        KL per sample: [B].
    """
    # 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar) over the latent dim.
    per_dim = torch.exp(logvar) + mu * mu - 1.0 - logvar
    kl = 0.5 * per_dim.sum(dim=-1)
    if free_bits > 0.0:
        kl = kl.clamp(min=free_bits * mu.shape[-1])
    return kl
|
|
|
|
def gaussian_nll(x, mu_x, logvar_x):
    """Elementwise Gaussian negative log-likelihood (constant term dropped).

    Args:
        x, mu_x, logvar_x: [B, T, F].

    Returns:
        Per-element NLL, 0.5 * (logvar + (x - mu)^2 / var): [B, T, F].
    """
    residual = x - mu_x
    return 0.5 * (logvar_x + residual.pow(2) / torch.exp(logvar_x))
|
|
|
|
def compute_loss_two_modal(x1, x2, mu_x1, mu_x2, mu_z, logvar_z, beta: float, w1: float = 1.0, w2: float = 1.0):
    """beta-VAE objective for the two-modality model.

    Per-sample MSE reconstruction for each modality (weighted w1/w2) plus
    beta-weighted KL against the standard-normal prior.

    Returns:
        (loss, mean_recon, mean_kl, mean_recon1, mean_recon2) where loss is
        a scalar tensor and the remaining entries are Python floats.
    """
    sample_recon1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))  # [B]
    sample_recon2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))  # [B]
    sample_recon = w1 * sample_recon1 + w2 * sample_recon2

    sample_kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=0.0)  # [B]
    loss = (sample_recon + beta * sample_kl).mean()

    return (
        loss,
        sample_recon.mean().item(),
        sample_kl.mean().item(),
        sample_recon1.mean().item(),
        sample_recon2.mean().item(),
    )
|
|
|
|
|
|
|
|
# %%
|
|
# -----------------------------
|
|
# Utilities: plotting
|
|
# -----------------------------
|
|
def plot_reconstructions(x, x_hat, outpath: str, n_examples: int = 3):
    """
    Save a figure overlaying real vs. reconstructed series, feature 0 only
    (to keep it readable).

    Args:
        x, x_hat: [B, T, F] tensors (real data and reconstruction means).
        outpath: destination image path; parent dirs are created if needed.
        n_examples: max number of rows to plot (capped at batch size).
    """
    real = x.detach().cpu().numpy()
    recon = x_hat.detach().cpu().numpy()

    rows = min(n_examples, real.shape[0])
    plt.figure(figsize=(10, 3 * rows))
    for row in range(rows):
        plt.subplot(rows, 1, row + 1)
        plt.plot(real[row, :, 0], label="real")
        plt.plot(recon[row, :, 0], label="recon")
        plt.legend()
        plt.title(f"Example {row} (feature 0)")
    plt.tight_layout()

    parent = os.path.dirname(outpath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    plt.savefig(outpath)
    plt.close()
|
|
|
|
|
|
def plot_samples(model, T: int, device: str, outpath: str, n_samples: int = 3, modality: int = 1):
    """Sample z ~ N(0, I) and plot decoded (mean) trajectories, feature 0 only.

    Bug fix: the previous version referenced `model.encoder.to_mu` and
    `model.decoder`, which `TwoModalTimeSeriesVAE` does not define (it has
    `enc1`/`enc2`, `to_mu`, and `dec1`/`dec2`), so the call always raised
    AttributeError. The latent size now comes from `model.latent_dim` and
    decoding uses the decoder for the requested modality.

    Args:
        model: trained TwoModalTimeSeriesVAE.
        T: number of timesteps to decode.
        device: device string for the latent samples.
        outpath: destination image path; parent dirs are created if needed.
        n_samples: number of latent draws / subplot rows.
        modality: 1 -> decode with model.dec1, otherwise model.dec2.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, model.latent_dim, device=device)
        decoder = model.dec1 if modality == 1 else model.dec2
        x_synth = decoder(z, T)  # mean-only decoder -> [n, T, F]

    x_synth = x_synth.detach().cpu().numpy()

    plt.figure(figsize=(10, 3 * n_samples))
    for i in range(n_samples):
        plt.subplot(n_samples, 1, i + 1)
        plt.plot(x_synth[i, :, 0])
        plt.title(f"Synthetic sample {i} (feature 0)")
    plt.tight_layout()
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath)
    plt.close()
|
|
|
|
|
|
# %%
|
|
def det_check_two_modal(model, loader, device, tag="DET", n_batches=1):
    """Print deterministic reconstruction diagnostics for both modalities.

    Encodes the first `n_batches` batches with sample=False (z = posterior
    mean) and prints, per modality: overall and per-feature MSE, per-feature
    std of data vs. reconstruction, their ratio (a variance-collapse
    indicator), per-feature mean error, and per-feature correlation between
    reconstruction and data; finally prints latent summary statistics.

    NOTE(review): switches the model to eval mode and does not restore the
    previous mode; also runs without torch.no_grad(), so the forward passes
    retain their autograd graphs.
    """
    model.eval()
    # Accumulators for inputs, reconstruction means, and posterior params.
    xs1, xs2 = [], []
    mus1, mus2 = [], []
    zs, logvars = [], []

    it = iter(loader)
    for _ in range(n_batches):
        x1, x2 = next(it)
        x1 = x1.to(device)
        x2 = x2.to(device)
        # Deterministic forward: z is the posterior mean (no sampling).
        mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False, eps_scale=0.0)
        xs1.append(x1); xs2.append(x2)
        mus1.append(mu_x1); mus2.append(mu_x2)
        zs.append(mu_z); logvars.append(logvar_z)

    # Concatenate the collected batches along the batch dimension.
    x1 = torch.cat(xs1, dim=0)  # [B,T,F1]
    x2 = torch.cat(xs2, dim=0)  # [B,5T,F2]
    mu1 = torch.cat(mus1, dim=0)
    mu2 = torch.cat(mus2, dim=0)
    mu_z = torch.cat(zs, dim=0)  # [B,Z]
    logvar_z = torch.cat(logvars, dim=0)

    def _stats(x, mu):
        # Per-feature statistics over all (batch, time) positions.
        B, T, Fdim = x.shape
        xf = x.reshape(B*T, Fdim)
        mf = mu.reshape(B*T, Fdim)

        mse_feat = ((mf - xf) ** 2).mean(dim=0)  # [F]
        mse_all = mse_feat.mean()

        mean_x = xf.mean(dim=0)
        mean_m = mf.mean(dim=0)
        mean_err = (mean_m - mean_x)  # per-feature reconstruction bias

        std_x = xf.std(dim=0)
        std_m = mf.std(dim=0)
        # Ratio << 1 indicates the reconstruction is under-dispersed.
        std_ratio = std_m / (std_x + 1e-8)

        x0 = xf - mean_x
        m0 = mf - mean_m
        # NOTE(review): population-mean numerator over unbiased (sample)
        # stds, so this is only approximately a Pearson correlation.
        corr = (x0 * m0).mean(dim=0) / ((x0.std(dim=0) * m0.std(dim=0)) + 1e-8)

        return mse_all, mse_feat, std_x, std_m, std_ratio, mean_err, corr

    s1 = _stats(x1, mu1)
    s2 = _stats(x2, mu2)

    print(f"{tag} MOD1:")
    print(f" mse_all = {s1[0].item():.6f}")
    print(f" mse_feat = {s1[1].detach().cpu().numpy()}")
    print(f" std_x = {s1[2].detach().cpu().numpy()}")
    print(f" std_mu = {s1[3].detach().cpu().numpy()}")
    print(f" std_ratio(std_mu/std_x) = {s1[4].detach().cpu().numpy()}")
    print(f" mean_err(mu-x) = {s1[5].detach().cpu().numpy()}")
    print(f" corr(mu,x) = {s1[6].detach().cpu().numpy()}")

    print(f"{tag} MOD2:")
    print(f" mse_all = {s2[0].item():.6f}")
    print(f" mse_feat = {s2[1].detach().cpu().numpy()}")
    print(f" std_x = {s2[2].detach().cpu().numpy()}")
    print(f" std_mu = {s2[3].detach().cpu().numpy()}")
    print(f" std_ratio(std_mu/std_x) = {s2[4].detach().cpu().numpy()}")
    print(f" mean_err(mu-x) = {s2[5].detach().cpu().numpy()}")
    print(f" corr(mu,x) = {s2[6].detach().cpu().numpy()}")
    print(
        f"{tag} LATENT:"
        f" mu_z mean={mu_z.mean().item():.4f}, std={mu_z.std().item():.4f}"
        f" | logvar_z mean={logvar_z.mean().item():.4f}"
    )
|
|
|
|
# %%
|
|
def main():
    """Train the two-modality time-series VAE end to end.

    Schedule: the first `ae_epochs` epochs run as a plain autoencoder
    (beta = 0, no sampling); afterwards sampling is enabled, the learning
    rate is dropped 10x once, and beta is warmed up linearly to `beta_cap`.
    Saves periodic reconstruction plots and the final state dict under
    cfg.out_dir.

    Fixes vs. the previous version (no behavior change at current config
    values): the one-time LR drop appeared twice verbatim (removed the
    duplicate); the datasets were built with F2=cfg.F while the model used
    F2=cfg.F2 (now both use cfg.F2); the loss call hard-coded w1/w2 instead
    of using cfg.w1/cfg.w2; `fb` was recomputed inside the validation loop.
    """
    cfg = Config()

    model = TwoModalTimeSeriesVAE(
        F1=cfg.F,
        F2=cfg.F2,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_layers,
        latent_dim=cfg.latent_dim,
        bidi=False,
    ).to(cfg.device)

    optim = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    # Datasets: widths must match the model construction above
    # (modality A = cfg.F features, modality B = cfg.F2 features).
    train_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_train,
        T=cfg.T,
        F1=cfg.F,
        F2=cfg.F2,
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=1,
        shape2="triangle",
        freq2_mult=2.0,
    )
    val_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_val,
        T=cfg.T,
        F1=cfg.F,
        F2=cfg.F2,
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=2,  # different seed so validation has its own latent factors
        shape2="triangle",
        freq2_mult=2.0,
    )

    # Environment diagnostics.
    print("torch version:", torch.__version__)
    print("cuda available:", torch.cuda.is_available())
    print("cuda device count:", torch.cuda.device_count())
    if torch.cuda.is_available():
        print("current device:", torch.cuda.current_device())
        print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
        print("memory allocated (MB):", torch.cuda.memory_allocated() / 1024**2)
        print("memory reserved (MB):", torch.cuda.memory_reserved() / 1024**2)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)

    # One-batch smoke check: model and data should end up on the same device.
    print("model device:", next(model.parameters()).device)
    x1, x2 = next(iter(train_loader))
    x1 = x1.to(cfg.device)
    print("batch device:", x1.device)

    # -----------------------------
    # Training schedule params
    # -----------------------------
    ae_epochs = 10           # pure-autoencoder warmup (beta = 0, no sampling)
    beta_cap = 0.01          # final KL weight
    beta_warmup_epochs = 10  # epochs to ramp beta from 0 up to beta_cap

    for epoch in range(1, cfg.epochs + 1):
        model.train()

        if epoch <= ae_epochs:
            sample = False
            eps_scale = 0.0
            beta = 0.0
        else:
            sample = True
            eps_scale = 1.0
            beta = min(beta_cap, (epoch - ae_epochs) / beta_warmup_epochs * beta_cap)
            # One-time LR drop when switching from AE warmup to VAE training.
            if epoch == ae_epochs + 1:
                for g in optim.param_groups:
                    g["lr"] = cfg.lr * 0.1

        tr_loss = 0.0
        tr_recon = 0.0
        tr_kl = 0.0
        steps = 0

        for (x1, x2) in tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs}"):
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)

            mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=sample, eps_scale=eps_scale)

            loss, recon_m, kl_m, recon1_m, recon2_m = compute_loss_two_modal(
                x1, x2, mu_x1, mu_x2, mu_z, logvar_z, beta=beta, w1=cfg.w1, w2=cfg.w2
            )

            optim.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optim.step()

            tr_loss += loss.item()
            tr_recon += recon_m
            tr_kl += kl_m
            steps += 1

        tr_loss /= steps
        tr_recon /= steps
        tr_kl /= steps

        # -----------------------------
        # Validation (deterministic for stable diagnostics)
        # -----------------------------
        model.eval()
        va_loss = 0.0
        va_recon = 0.0
        va_kl = 0.0
        va_steps = 0
        fb = 0.0 if beta == 0.0 else 0.1  # free bits: off during AE warmup

        with torch.no_grad():
            for (x1, x2) in val_loader:
                x1 = x1.to(cfg.device)
                x2 = x2.to(cfg.device)

                mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False)

                recon1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))
                recon2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))
                recon = recon1 + recon2

                kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=fb)
                loss = (recon + beta * kl).mean()

                va_loss += loss.item()
                va_recon += recon.mean().item()
                va_kl += kl.mean().item()
                va_steps += 1

        va_loss /= va_steps
        va_recon /= va_steps
        va_kl /= va_steps

        # Periodic deterministic diagnostics on both splits.
        if epoch in [1, 5, 10, 20, cfg.epochs]:
            det_check_two_modal(model, train_loader, cfg.device, "TRAIN")
            det_check_two_modal(model, val_loader, cfg.device, "VAL")

        print(
            f"Epoch {epoch:02d} | beta={beta:.3f} | "
            f"Train loss={tr_loss:.4f} (recon={tr_recon:.4f}, kl={tr_kl:.4f}) | "
            f"Val loss={va_loss:.4f} (recon={va_recon:.4f}, kl={va_kl:.4f})"
        )

        # Periodic reconstruction plots (both modalities).
        if epoch in {1, 5, 10, cfg.epochs}:
            with torch.no_grad():
                batch = next(iter(val_loader))
                x1_plot, x2_plot = batch[0].to(cfg.device), batch[1].to(cfg.device)
                mu_x1_plot, mu_x2_plot, _, _ = model(x1_plot, x2_plot, sample=False)

                plot_reconstructions(
                    x1_plot[:8], mu_x1_plot[:8],
                    os.path.join(cfg.out_dir, f"recon_m1_epoch{epoch}.png")
                )
                plot_reconstructions(
                    x2_plot[:8], mu_x2_plot[:8],
                    os.path.join(cfg.out_dir, f"recon_m2_epoch{epoch}.png")
                )

    model_path = os.path.join(cfg.out_dir, "ts_vae.pt")
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print("Saved to:", model_path)
|
|
|
|
|
|
# %%
|
|
# Run only when executed as a script (or notebook cell), not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|