# %%
import math
import os
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# tqdm is optional (progress bars only); degrade to a plain iterator when it
# is not installed so the library parts of this module stay importable.
try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover
    def tqdm(iterable, **kwargs):
        return iterable


# %%
class Config:
    """All hyperparameters for data generation, model and training."""

    # Data
    n_train: int = 8000
    n_val: int = 1000
    T: int = 120            # timesteps, modality A
    F: int = 6              # features, modality A
    noise_std: float = 0.05

    # Second modality
    F2: int = 6
    T2_mult: int = 5        # modality-2 has 5x timesteps
    T2: int = T * T2_mult   # 600

    # Model
    hidden_size: int = 32
    latent_dim: int = 32
    num_layers: int = 1

    # Training
    batch_size: int = 128
    lr: float = 1e-3
    epochs: int = 30
    kl_warmup_epochs: int = 30
    grad_clip: float = 1.0

    # Reconstruction weights per modality
    w1: float = 1.0
    w2: float = 1.0

    # System
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    out_dir: str = "runs/ts_vae_2modal"


# %%
def triangle_wave(phase):
    """Triangle wave in [-1, 1] for a scalar phase given in radians."""
    u = (phase / (2 * math.pi)) % 1.0  # convert to cycles in [0, 1)
    # Triangle: 4|u - 0.5| - 1 maps [0, 1) onto [-1, 1]
    return 4.0 * abs(u - 0.5) - 1.0


def sawtooth_wave(phase):
    """Sawtooth wave in [-1, 1] for a scalar phase given in radians."""
    u = (phase / (2 * math.pi)) % 1.0
    return 2.0 * u - 1.0


def square_wave(phase):
    """Square wave in {-1, +1} for a scalar phase given in radians."""
    return 1.0 if math.sin(phase) >= 0.0 else -1.0


class TwoModalSyntheticTimeSeries(Dataset):
    """Synthetic, correlated two-modality time series.

    Each item returns:
        x1: [T, F1]   low-rate modality (sinusoids + harmonic)
        x2: [5T, F2]  high-rate modality (triangle/sawtooth/square)

    Both modalities span the same interval [-1, 1] but x2 is sampled 5x more
    densely. Both are driven by the same per-sample latent factors (amplitude,
    phase, trend, bias), so the modalities are correlated.
    """

    def __init__(
        self,
        n: int,
        T: int,
        F1: int,
        F2: int,
        noise_std1: float = 0.05,
        noise_std2: float = 0.05,
        seed: int = 0,
        shape2: str = "triangle",   # "triangle" | "sawtooth" | "square"
        freq2_mult: float = 2.0,    # frequency multiplier within modality B
    ):
        super().__init__()
        self.n = n
        self.T = T
        self.T2 = 5 * T
        self.F1 = F1
        self.F2 = F2
        self.noise_std1 = noise_std1
        self.noise_std2 = noise_std2
        self.shape2 = shape2
        self.freq2_mult = freq2_mult

        rng = np.random.RandomState(seed)

        # Per-sample latent factors shared across both modalities (correlation)
        self.amp = rng.uniform(0.6, 1.4, size=n).astype(np.float32)     # amplitude
        self.phase = rng.uniform(-math.pi, math.pi, size=n).astype(np.float32)
        self.trend = rng.uniform(-0.3, 0.3, size=n).astype(np.float32)  # linear trend
        self.bias = rng.uniform(-0.2, 0.2, size=n).astype(np.float32)   # offset

        # Feature-wise frequencies: modality A, and (generally higher) modality B
        self.freq1 = rng.uniform(0.6, 1.6, size=F1).astype(np.float32)
        self.freq2 = rng.uniform(1.5, 3.5, size=F2).astype(np.float32) * float(freq2_mult)

    def __len__(self):
        return self.n

    def _wave2(self, phase: float) -> float:
        """Scalar modality-B waveform (kept for API compatibility)."""
        if self.shape2 == "triangle":
            return float(triangle_wave(phase))
        if self.shape2 == "sawtooth":
            return float(sawtooth_wave(phase))
        if self.shape2 == "square":
            return float(square_wave(phase))
        raise ValueError(f"Unknown shape2={self.shape2}")

    def _wave2_batch(self, phases: torch.Tensor) -> torch.Tensor:
        """Vectorized modality-B waveform over a tensor of phases (radians)."""
        u = (phases / (2 * math.pi)) % 1.0  # torch % is remainder: matches scalar math
        if self.shape2 == "triangle":
            return 4.0 * (u - 0.5).abs() - 1.0
        if self.shape2 == "sawtooth":
            return 2.0 * u - 1.0
        if self.shape2 == "square":
            ones = torch.ones_like(phases)
            # sin(phase) >= 0 -> +1 else -1, matching square_wave()
            return torch.where(torch.sin(phases) >= 0.0, ones, -ones)
        raise ValueError(f"Unknown shape2={self.shape2}")

    def __getitem__(self, idx):
        # Time grids: same span, different sampling density
        t1 = torch.linspace(-1.0, 1.0, steps=self.T, dtype=torch.float32)   # [T]
        t2 = torch.linspace(-1.0, 1.0, steps=self.T2, dtype=torch.float32)  # [5T]

        amp = float(self.amp[idx])
        ph = float(self.phase[idx])
        tr = float(self.trend[idx])
        b = float(self.bias[idx])

        # --------------------
        # Modality A (sinusoids)
        # --------------------
        x1 = torch.zeros(self.T, self.F1, dtype=torch.float32)
        for f in range(self.F1):
            w = float(self.freq1[f]) * 2.0 * math.pi
            # Correlated structure: shared amp/phase/trend/bias
            x1[:, f] = (
                amp * torch.sin(w * t1 + ph)
                + 0.3 * amp * torch.sin(2.0 * w * t1 + 0.5 * ph)  # harmonic
                + tr * t1
                + b
            )
        x1 = x1 + self.noise_std1 * torch.randn_like(x1)

        # --------------------
        # Modality B (different shape, higher sampling rate)
        # --------------------
        x2 = torch.zeros(self.T2, self.F2, dtype=torch.float32)
        for f in range(self.F2):
            w = float(self.freq2[f]) * 2.0 * math.pi
            # Same shared amp/phase/trend/bias so modalities correlate;
            # small deterministic phase shift (+0.4) vs modality A.
            wave = self._wave2_batch(w * t2 + (ph + 0.4))
            x2[:, f] = (
                amp * wave
                + 0.15 * amp * torch.sin(3.0 * w * t2 + ph)  # slight extra structure
                + tr * t2
                + b
            )
        x2 = x2 + self.noise_std2 * torch.randn_like(x2)

        return x1, x2


# %%
class EncoderGRU(nn.Module):
    """GRU encoder: sequence [B, T, F] -> fixed embedding [B, dir*hidden]."""

    def __init__(self, input_dim: int, hidden_size: int, num_layers: int,
                 bidirectional: bool = False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dir_mult = 2 if bidirectional else 1
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.1 if num_layers > 1 else 0.0,  # dropout only between layers
        )

    def forward(self, x):
        # x: [B, T, F]; h: [num_layers*dir, B, hidden]
        _, h = self.gru(x)
        # Take last layer (and both directions if bidirectional)
        h_last = h.view(self.num_layers, self.dir_mult, x.size(0), self.hidden_size)[-1]  # [dir, B, hidden]
        emb = h_last.transpose(0, 1).reshape(x.size(0), self.dir_mult * self.hidden_size)  # [B, dir*hidden]
        return emb


# %%
class DecoderGRUTime(nn.Module):
    """GRU decoder conditioned on z and an explicit time channel.

    The latent z is repeated at every timestep and concatenated with a
    normalized time grid in [-1, 1]; z also initializes the hidden state.
    """

    def __init__(self, latent_dim: int, hidden_size: int, num_layers: int, output_dim: int):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.z_to_h0 = nn.Linear(latent_dim, num_layers * hidden_size)
        self.gru = nn.GRU(
            input_size=latent_dim + 1,  # z plus time channel
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        self.to_out = nn.Linear(hidden_size, output_dim)

    def forward(self, z, T: int):
        # z: [B, Z]
        B, Z = z.shape
        if Z != self.latent_dim:
            raise RuntimeError(f"Decoder expected z dim {self.latent_dim}, got {Z}")

        # Repeat z across time: [B, T, Z]
        z_rep = z.unsqueeze(1).expand(B, T, Z)

        # Time channel: [B, T, 1]
        t = torch.linspace(-1.0, 1.0, steps=T, device=z.device, dtype=z.dtype)
        t = t.view(1, T, 1).expand(B, T, 1)

        # Concatenate: [B, T, Z+1]
        inp = torch.cat([z_rep, t], dim=-1)

        # Init hidden: [num_layers, B, hidden]
        h0 = self.z_to_h0(z).view(self.num_layers, B, self.hidden_size).contiguous()
        y, _ = self.gru(inp, h0)
        mu_hat = self.to_out(y)  # [B, T, F]
        return mu_hat


# %%
class TwoModalTimeSeriesVAE(nn.Module):
    """VAE over two time-series modalities with a shared latent z.

    Each modality has its own GRU encoder; the embeddings are fused by an MLP
    into a single diagonal-Gaussian posterior q(z|x1,x2). Each modality has its
    own time-conditioned GRU decoder.
    """

    def __init__(self, F1: int, F2: int, hidden_size: int, num_layers: int,
                 latent_dim: int, bidi: bool = False):
        super().__init__()
        self.latent_dim = latent_dim
        self.enc1 = EncoderGRU(F1, hidden_size, num_layers, bidirectional=bidi)
        self.enc2 = EncoderGRU(F2, hidden_size, num_layers, bidirectional=bidi)
        enc_out_dim = (2 if bidi else 1) * hidden_size
        fusion_in = enc_out_dim + enc_out_dim
        self.fuse = nn.Sequential(
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
        )
        self.to_mu = nn.Linear(fusion_in, latent_dim)
        self.to_logvar = nn.Linear(fusion_in, latent_dim)
        self.dec1 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F1)
        self.dec2 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F2)

    def encode(self, x1, x2):
        """Return (mu_z, logvar_z) of the fused posterior, each [B, Z]."""
        e1 = self.enc1(x1)  # [B, H]
        e2 = self.enc2(x2)  # [B, H]
        h = torch.cat([e1, e2], dim=-1)
        h = self.fuse(h)
        mu_z = self.to_mu(h)
        logvar_z = self.to_logvar(h)
        return mu_z, logvar_z

    def forward(self, x1, x2, sample: bool = True, eps_scale: float = 1.0):
        # x1: [B, T, F1], x2: [B, 5T, F2]
        B, T1, _ = x1.shape
        _, T2, _ = x2.shape
        mu_z, logvar_z = self.encode(x1, x2)
        # sample=False gives the deterministic (posterior-mean) autoencoder path
        if sample:
            z = reparameterize(mu_z, logvar_z, eps_scale=eps_scale)
        else:
            z = mu_z
        mu_x1 = self.dec1(z, T1)  # [B, T, F1]
        mu_x2 = self.dec2(z, T2)  # [B, 5T, F2]
        return mu_x1, mu_x2, mu_z, logvar_z


# %%
# -----------------------------
# Losses
# -----------------------------
def reparameterize(mu, logvar, eps_scale: float = 1.0):
    """Reparameterization trick: z = mu + eps_scale * eps * std, eps ~ N(0, I)."""
    # mu, logvar: [B, Z]
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std) * eps_scale
    return mu + eps * std


def kl_diag_gaussian(mu, logvar, free_bits: float = 0.0):
    """KL(q || N(0, I)) per sample, shape [B].

    With free_bits > 0, the per-sample KL is clamped below at
    free_bits * latent_dim so the posterior cannot fully collapse.
    """
    # KL(q||p) with p=N(0,I): 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar)
    kl = 0.5 * (torch.exp(logvar) + mu**2 - 1.0 - logvar).sum(dim=-1)
    if free_bits > 0.0:
        kl = torch.clamp(kl, min=free_bits * mu.shape[-1])
    return kl


def gaussian_nll(x, mu_x, logvar_x):
    """Elementwise Gaussian negative log-likelihood, shape [B, T, F].

    Includes the 0.5*log(2*pi) normalization constant (the original omitted
    it; the constant does not affect gradients but makes the value a proper NLL).
    """
    return 0.5 * (math.log(2.0 * math.pi) + logvar_x
                  + (x - mu_x) ** 2 / torch.exp(logvar_x))


def compute_loss_two_modal(x1, x2, mu_x1, mu_x2, mu_z, logvar_z,
                           beta: float, w1: float = 1.0, w2: float = 1.0):
    """Weighted two-modality reconstruction MSE + beta * KL.

    Returns (loss, recon_mean, kl_mean, recon1_mean, recon2_mean) where the
    means (after the first element) are Python floats for logging.
    """
    # Per-sample reconstruction error for each modality: [B]
    recon1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))
    recon2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))
    recon = w1 * recon1 + w2 * recon2
    kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=0.0)  # [B]
    loss = (recon + beta * kl).mean()
    return loss, recon.mean().item(), kl.mean().item(), recon1.mean().item(), recon2.mean().item()


# %%
# -----------------------------
# Utilities: plotting
# -----------------------------
def plot_reconstructions(x, x_hat, outpath: str, n_examples: int = 3):
    """Plot a few real-vs-reconstruction examples (feature 0 only) to a file."""
    import matplotlib.pyplot as plt  # lazy: plotting is optional at import time

    x = x.detach().cpu().numpy()
    x_hat = x_hat.detach().cpu().numpy()
    n = min(n_examples, x.shape[0])
    plt.figure(figsize=(10, 3 * n))
    for i in range(n):
        plt.subplot(n, 1, i + 1)
        plt.plot(x[i, :, 0], label="real")
        plt.plot(x_hat[i, :, 0], label="recon")
        plt.legend()
        plt.title(f"Example {i} (feature 0)")
    plt.tight_layout()
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath)
    plt.close()


def plot_samples(model, T: int, device: str, outpath: str,
                 n_samples: int = 3, modality: int = 1):
    """Decode z ~ N(0, I) through one modality's decoder and plot feature 0.

    Fixed: the previous version referenced model.encoder / model.decoder,
    attributes that do not exist on TwoModalTimeSeriesVAE (it has to_mu,
    dec1, dec2), so it crashed when called. `modality` selects dec1 or dec2.
    """
    import matplotlib.pyplot as plt  # lazy: plotting is optional at import time

    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, model.latent_dim, device=device)
        decoder = model.dec1 if modality == 1 else model.dec2
        x_synth = decoder(z, T)  # mean-only decoder -> [n, T, F]
    x_synth = x_synth.detach().cpu().numpy()
    plt.figure(figsize=(10, 3 * n_samples))
    for i in range(n_samples):
        plt.subplot(n_samples, 1, i + 1)
        plt.plot(x_synth[i, :, 0])
        plt.title(f"Synthetic sample {i} (feature 0)")
    plt.tight_layout()
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath)
    plt.close()


# %%
def det_check_two_modal(model, loader, device, tag="DET", n_batches=1):
    """Deterministic diagnostics: per-feature MSE, std ratio, mean error and
    correlation between reconstructions and data, plus latent statistics.

    Uses sample=False so results are reproducible across calls.
    """
    model.eval()
    xs1, xs2 = [], []
    mus1, mus2 = [], []
    zs, logvars = [], []
    it = iter(loader)
    for _ in range(n_batches):
        x1, x2 = next(it)
        x1 = x1.to(device)
        x2 = x2.to(device)
        mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False, eps_scale=0.0)
        xs1.append(x1); xs2.append(x2)
        mus1.append(mu_x1); mus2.append(mu_x2)
        zs.append(mu_z); logvars.append(logvar_z)

    x1 = torch.cat(xs1, dim=0)  # [B, T, F1]
    x2 = torch.cat(xs2, dim=0)  # [B, 5T, F2]
    mu1 = torch.cat(mus1, dim=0)
    mu2 = torch.cat(mus2, dim=0)
    mu_z = torch.cat(zs, dim=0)  # [B, Z]
    logvar_z = torch.cat(logvars, dim=0)

    def _stats(x, mu):
        # Flatten batch and time so statistics are per feature
        B, T, Fdim = x.shape
        xf = x.reshape(B * T, Fdim)
        mf = mu.reshape(B * T, Fdim)
        mse_feat = ((mf - xf) ** 2).mean(dim=0)
        mse_all = mse_feat.mean()
        mean_x = xf.mean(dim=0)
        mean_m = mf.mean(dim=0)
        mean_err = (mean_m - mean_x)
        std_x = xf.std(dim=0)
        std_m = mf.std(dim=0)
        std_ratio = std_m / (std_x + 1e-8)  # ~1 when recon variance matches data
        x0 = xf - mean_x
        m0 = mf - mean_m
        corr = (x0 * m0).mean(dim=0) / ((x0.std(dim=0) * m0.std(dim=0)) + 1e-8)
        return mse_all, mse_feat, std_x, std_m, std_ratio, mean_err, corr

    s1 = _stats(x1, mu1)
    s2 = _stats(x2, mu2)

    print(f"{tag} MOD1:")
    print(f"  mse_all = {s1[0].item():.6f}")
    print(f"  mse_feat = {s1[1].detach().cpu().numpy()}")
    print(f"  std_x = {s1[2].detach().cpu().numpy()}")
    print(f"  std_mu = {s1[3].detach().cpu().numpy()}")
    print(f"  std_ratio(std_mu/std_x) = {s1[4].detach().cpu().numpy()}")
    print(f"  mean_err(mu-x) = {s1[5].detach().cpu().numpy()}")
    print(f"  corr(mu,x) = {s1[6].detach().cpu().numpy()}")
    print(f"{tag} MOD2:")
    print(f"  mse_all = {s2[0].item():.6f}")
    print(f"  mse_feat = {s2[1].detach().cpu().numpy()}")
    print(f"  std_x = {s2[2].detach().cpu().numpy()}")
    print(f"  std_mu = {s2[3].detach().cpu().numpy()}")
    print(f"  std_ratio(std_mu/std_x) = {s2[4].detach().cpu().numpy()}")
    print(f"  mean_err(mu-x) = {s2[5].detach().cpu().numpy()}")
    print(f"  corr(mu,x) = {s2[6].detach().cpu().numpy()}")
    print(
        f"{tag} LATENT:"
        f" mu_z mean={mu_z.mean().item():.4f}, std={mu_z.std().item():.4f}"
        f" | logvar_z mean={logvar_z.mean().item():.4f}"
    )


# %%
def main():
    """Train the two-modal VAE: deterministic AE warm-up, then beta warm-up."""
    cfg = Config()

    model = TwoModalTimeSeriesVAE(
        F1=cfg.F,
        F2=cfg.F2,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_layers,
        latent_dim=cfg.latent_dim,
        bidi=False,
    ).to(cfg.device)

    optim = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    train_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_train,
        T=cfg.T,
        F1=cfg.F,   # reuse the existing F for modality A
        F2=cfg.F,   # or choose a different number
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=1,
        shape2="triangle",
        freq2_mult=2.0,
    )
    val_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_val,
        T=cfg.T,
        F1=cfg.F,
        F2=cfg.F,
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=2,
        shape2="triangle",
        freq2_mult=2.0,
    )

    print("torch version:", torch.__version__)
    print("cuda available:", torch.cuda.is_available())
    print("cuda device count:", torch.cuda.device_count())
    if torch.cuda.is_available():
        print("current device:", torch.cuda.current_device())
        print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
        print("memory allocated (MB):", torch.cuda.memory_allocated() / 1024**2)
        print("memory reserved (MB):", torch.cuda.memory_reserved() / 1024**2)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)

    print("model device:", next(model.parameters()).device)
    batch = next(iter(train_loader))
    x1, x2 = batch
    x1 = x1.to(cfg.device)
    print("batch device:", x1.device)

    # -----------------------------
    # Training schedule params
    # -----------------------------
    ae_epochs = 10          # plain-autoencoder warm-up (no sampling, no KL)
    beta_cap = 0.01         # final KL weight
    beta_warmup_epochs = 10 # epochs to ramp beta from 0 to beta_cap

    for epoch in range(1, cfg.epochs + 1):
        model.train()

        if epoch <= ae_epochs:
            sample = False
            eps_scale = 0.0
            beta = 0.0
        else:
            sample = True
            eps_scale = 1.0
            beta = min(beta_cap, (epoch - ae_epochs) / beta_warmup_epochs * beta_cap)
            # Drop the learning rate once, when switching from AE to VAE mode
            # (the original code repeated this block twice; once is enough).
            if epoch == ae_epochs + 1:
                for g in optim.param_groups:
                    g["lr"] = cfg.lr * 0.1

        tr_loss = 0.0
        tr_recon = 0.0
        tr_kl = 0.0
        steps = 0
        for (x1, x2) in tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs}"):
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=sample, eps_scale=eps_scale)
            loss, recon_m, kl_m, recon1_m, recon2_m = compute_loss_two_modal(
                x1, x2, mu_x1, mu_x2, mu_z, logvar_z,
                beta=beta, w1=cfg.w1, w2=cfg.w2,  # honor configured recon weights
            )
            optim.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optim.step()
            tr_loss += loss.item()
            tr_recon += recon_m
            tr_kl += kl_m
            steps += 1
        tr_loss /= steps
        tr_recon /= steps
        tr_kl /= steps

        # -----------------------------
        # Validation (deterministic for stable diagnostics)
        # -----------------------------
        model.eval()
        va_loss = 0.0
        va_recon = 0.0
        va_kl = 0.0
        va_steps = 0
        with torch.no_grad():
            for (x1, x2) in val_loader:
                x1 = x1.to(cfg.device)
                x2 = x2.to(cfg.device)
                mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False)
                recon1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))
                recon2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))
                recon = recon1 + recon2
                # free_bits=0.0 so validation matches the training objective
                # (the original used 0.1 here while claiming consistency).
                kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=0.0)
                loss = (recon + beta * kl).mean()
                va_loss += loss.item()
                va_recon += recon.mean().item()
                va_kl += kl.mean().item()
                va_steps += 1
        va_loss /= va_steps
        va_recon /= va_steps
        va_kl /= va_steps

        # Diagnostics at a few milestone epochs
        if epoch in [1, 5, 10, 20, cfg.epochs]:
            det_check_two_modal(model, train_loader, cfg.device, "TRAIN")
            det_check_two_modal(model, val_loader, cfg.device, "VAL")

        print(
            f"Epoch {epoch:02d} | beta={beta:.3f} | "
            f"Train loss={tr_loss:.4f} (recon={tr_recon:.4f}, kl={tr_kl:.4f}) | "
            f"Val loss={va_loss:.4f} (recon={va_recon:.4f}, kl={va_kl:.4f})"
        )

        # Plots (2-modal)
        if epoch in {1, 5, 10, cfg.epochs}:
            with torch.no_grad():
                batch = next(iter(val_loader))
                # DataLoader can return (x1, x2) as tuple or list
                x1_plot, x2_plot = batch[0].to(cfg.device), batch[1].to(cfg.device)
                mu_x1_plot, mu_x2_plot, _, _ = model(x1_plot, x2_plot, sample=False)
            plot_reconstructions(
                x1_plot[:8], mu_x1_plot[:8],
                os.path.join(cfg.out_dir, f"recon_m1_epoch{epoch}.png")
            )
            plot_reconstructions(
                x2_plot[:8], mu_x2_plot[:8],
                os.path.join(cfg.out_dir, f"recon_m2_epoch{epoch}.png")
            )

    model_path = os.path.join(cfg.out_dir, "ts_vae.pt")
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print("Saved to:", model_path)


# %%
if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) does not start training.
    main()