Files
slurm-gpu-test/synth-data-gen.py
Erik Thuning 0a0f955cc2 Added README
Changed name of synth_data_gen.py to synth-data-gen.py to avoid confusing the markdown engine
2026-03-04 19:30:27 +01:00

652 lines
20 KiB
Python

# %%
import math
import os
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
# %%
class Config:
    """All experiment hyper-parameters as class attributes (read via Config())."""

    # --- data ---
    n_train: int = 8000
    n_val: int = 1000
    T: int = 120
    F: int = 6
    noise_std: float = 0.05

    # --- second modality ---
    F2: int = 6
    T2_mult: int = 5       # modality-2 has 5x timesteps
    T2: int = T * T2_mult  # 600

    # --- model ---
    hidden_size: int = 32
    latent_dim: int = 32
    num_layers: int = 1

    # --- training ---
    batch_size: int = 128
    lr: float = 1e-3
    epochs: int = 30
    kl_warmup_epochs: int = 30
    grad_clip: float = 1.0

    # --- optional reconstruction weights ---
    w1: float = 1.0
    w2: float = 1.0

    # --- system ---
    # Resolved once, at class-creation time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    out_dir: str = "runs/ts_vae_2modal"
# %%
def triangle_wave(phase):
    """Triangle wave in [-1, 1] for a phase given in radians."""
    # Map phase to a fraction of a cycle in [0, 1).
    cycles = (phase / (2 * math.pi)) % 1.0
    # Piecewise-linear tent: peaks at cycle 0, dips to -1 at cycle 0.5.
    return 4.0 * abs(cycles - 0.5) - 1.0
def sawtooth_wave(phase):
    """Sawtooth wave in [-1, 1] for a phase given in radians."""
    cycles = (phase / (2 * math.pi)) % 1.0
    # Linear ramp from -1 (cycle start) up to just below +1.
    return 2.0 * cycles - 1.0
def square_wave(phase):
    """Square wave in {-1, +1}: sign of sin(phase), with sin==0 mapping to +1."""
    if math.sin(phase) >= 0.0:
        return 1.0
    return -1.0
class TwoModalSyntheticTimeSeries(Dataset):
    """
    Synthetic two-modality time-series dataset driven by shared latent factors.

    Each item returns:
        x1: [T, F1]           low-rate modality (sinusoids plus a harmonic)
        x2: [T2_mult*T, F2]   high-rate modality (triangle/sawtooth/square)

    Both modalities span the same interval [-1, 1]; x2 is sampled T2_mult
    times more densely (default 5, matching the original hard-coded factor).
    Per-sample amplitude / phase / trend / bias are shared across modalities
    so the two signals are correlated.

    NOTE: the latent factors come from a seeded numpy RandomState, but the
    additive observation noise uses torch's global RNG, so repeated
    __getitem__ calls on the same index yield different noise realisations.
    """

    def __init__(
        self,
        n: int,
        T: int,
        F1: int,
        F2: int,
        noise_std1: float = 0.05,
        noise_std2: float = 0.05,
        seed: int = 0,
        shape2: str = "triangle",  # "triangle" | "sawtooth" | "square"
        freq2_mult: float = 2.0,   # shape frequency multiplier within modality B
        T2_mult: int = 5,          # modality-B oversampling factor (was hard-coded 5)
    ):
        super().__init__()
        self.n = n
        self.T = T
        self.T2_mult = T2_mult
        self.T2 = T2_mult * T
        self.F1 = F1
        self.F2 = F2
        self.noise_std1 = noise_std1
        self.noise_std2 = noise_std2
        self.shape2 = shape2
        self.freq2_mult = freq2_mult
        rng = np.random.RandomState(seed)
        # Per-sample latent factors shared across both modalities (for correlation).
        self.amp = rng.uniform(0.6, 1.4, size=n).astype(np.float32)     # amplitude
        self.phase = rng.uniform(-math.pi, math.pi, size=n).astype(np.float32)
        self.trend = rng.uniform(-0.3, 0.3, size=n).astype(np.float32)  # linear trend
        self.bias = rng.uniform(-0.2, 0.2, size=n).astype(np.float32)   # offset
        # Feature-wise frequencies (modality A).
        self.freq1 = rng.uniform(0.6, 1.6, size=F1).astype(np.float32)
        # Feature-wise frequencies (modality B), generally higher.
        self.freq2 = rng.uniform(1.5, 3.5, size=F2).astype(np.float32) * float(freq2_mult)

    def __len__(self):
        return self.n

    def _wave2(self, phase: float) -> float:
        """Evaluate the modality-B waveform at a scalar phase (radians)."""
        if self.shape2 == "triangle":
            return float(triangle_wave(phase))
        if self.shape2 == "sawtooth":
            return float(sawtooth_wave(phase))
        if self.shape2 == "square":
            return float(square_wave(phase))
        raise ValueError(f"Unknown shape2={self.shape2}")

    def __getitem__(self, idx):
        # Time grids: identical span [-1, 1], different sampling density.
        t1 = torch.linspace(-1.0, 1.0, steps=self.T, dtype=torch.float32)   # [T]
        t2 = torch.linspace(-1.0, 1.0, steps=self.T2, dtype=torch.float32)  # [T2_mult*T]
        amp = float(self.amp[idx])
        ph = float(self.phase[idx])
        tr = float(self.trend[idx])
        b = float(self.bias[idx])

        # --------------------
        # Modality A (sinusoids)
        # --------------------
        x1 = torch.zeros(self.T, self.F1, dtype=torch.float32)
        for f in range(self.F1):
            w = float(self.freq1[f]) * 2.0 * math.pi
            # Correlated structure: shared amp/phase/trend/bias.
            x1[:, f] = (
                amp * torch.sin(w * t1 + ph)
                + 0.3 * amp * torch.sin(2.0 * w * t1 + 0.5 * ph)  # add harmonic
                + tr * t1
                + b
            )
        x1 = x1 + self.noise_std1 * torch.randn_like(x1)

        # --------------------
        # Modality B (different shape, higher sampling rate)
        # --------------------
        x2 = torch.zeros(self.T2, self.F2, dtype=torch.float32)
        for f in range(self.F2):
            w = float(self.freq2[f]) * 2.0 * math.pi
            # Same shared amp/phase/trend/bias so modalities correlate;
            # the 0.4 offset is a small deterministic shift vs modality A.
            vals = [self._wave2(w * tt + (ph + 0.4)) for tt in t2.tolist()]
            wave = torch.tensor(vals, dtype=torch.float32)
            x2[:, f] = (
                amp * wave
                + 0.15 * amp * torch.sin(3.0 * w * t2 + ph)  # slight extra structure
                + tr * t2
                + b
            )
        x2 = x2 + self.noise_std2 * torch.randn_like(x2)
        return x1, x2
# %%
class EncoderGRU(nn.Module):
    """Encode a [B, T, F] sequence into a fixed-size embedding.

    The embedding is the final hidden state of the top GRU layer; for a
    bidirectional GRU the two directions are concatenated.
    """

    def __init__(self, input_dim: int, hidden_size: int, num_layers: int, bidirectional: bool = False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dir_mult = 2 if bidirectional else 1
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            # Inter-layer dropout only makes sense with stacked layers.
            dropout=0.1 if num_layers > 1 else 0.0,
        )

    def forward(self, x):
        """x: [B, T, F] -> embedding [B, dir_mult * hidden_size]."""
        batch = x.size(0)
        _, hidden = self.gru(x)  # [num_layers * dir_mult, B, hidden]
        # Keep only the top layer (both directions when bidirectional).
        top = hidden.view(self.num_layers, self.dir_mult, batch, self.hidden_size)[-1]
        return top.transpose(0, 1).reshape(batch, self.dir_mult * self.hidden_size)
# %%
class DecoderGRUTime(nn.Module):
    """Decode a latent vector into a length-T sequence of output means.

    At every timestep the GRU input is [z, t], where t is a normalized
    time channel in [-1, 1]; the initial hidden state is a linear map of z.
    """

    def __init__(self, latent_dim: int, hidden_size: int, num_layers: int, output_dim: int):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.z_to_h0 = nn.Linear(latent_dim, num_layers * hidden_size)
        self.gru = nn.GRU(
            input_size=latent_dim + 1,  # z plus time channel
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1 if num_layers > 1 else 0.0,
        )
        self.to_out = nn.Linear(hidden_size, output_dim)

    def forward(self, z, T: int):
        """z: [B, Z] -> mean sequence [B, T, output_dim].

        Raises RuntimeError if z's last dimension does not match latent_dim.
        """
        batch, z_dim = z.shape
        if z_dim != self.latent_dim:
            raise RuntimeError(f"Decoder expected z dim {self.latent_dim}, got {z_dim}")
        # Broadcast z across the time axis: [B, T, Z].
        tiled = z.unsqueeze(1).expand(batch, T, z_dim)
        # Normalized time channel: [B, T, 1].
        time = torch.linspace(-1.0, 1.0, steps=T, device=z.device, dtype=z.dtype)
        time = time.view(1, T, 1).expand(batch, T, 1)
        seq_in = torch.cat([tiled, time], dim=-1)  # [B, T, Z+1]
        # Initial hidden state from z: [num_layers, B, hidden].
        h0 = self.z_to_h0(z).view(self.num_layers, batch, self.hidden_size).contiguous()
        out, _ = self.gru(seq_in, h0)
        return self.to_out(out)  # [B, T, output_dim]
# %%
class TwoModalTimeSeriesVAE(nn.Module):
    """VAE over two time-series modalities with one shared latent z.

    Each modality gets its own GRU encoder and decoder; the two encoder
    embeddings are fused by a small MLP that parameterizes q(z | x1, x2).
    """

    def __init__(self, F1: int, F2: int, hidden_size: int, num_layers: int, latent_dim: int, bidi: bool = False):
        super().__init__()
        self.latent_dim = latent_dim
        self.enc1 = EncoderGRU(F1, hidden_size, num_layers, bidirectional=bidi)
        self.enc2 = EncoderGRU(F2, hidden_size, num_layers, bidirectional=bidi)
        enc_out_dim = (2 if bidi else 1) * hidden_size
        fusion_in = enc_out_dim + enc_out_dim
        self.fuse = nn.Sequential(
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
            nn.Linear(fusion_in, fusion_in),
            nn.ReLU(),
        )
        self.to_mu = nn.Linear(fusion_in, latent_dim)
        self.to_logvar = nn.Linear(fusion_in, latent_dim)
        self.dec1 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F1)
        self.dec2 = DecoderGRUTime(latent_dim, hidden_size, num_layers, F2)

    def encode(self, x1, x2):
        """Return (mu_z, logvar_z) of q(z | x1, x2), each [B, Z]."""
        fused = self.fuse(torch.cat([self.enc1(x1), self.enc2(x2)], dim=-1))
        return self.to_mu(fused), self.to_logvar(fused)

    def forward(self, x1, x2, sample: bool = True, eps_scale: float = 1.0):
        """x1: [B, T1, F1], x2: [B, T2, F2] -> (mu_x1, mu_x2, mu_z, logvar_z).

        With sample=False the posterior mean is decoded deterministically;
        otherwise z is drawn via the reparameterization trick.
        """
        len1 = x1.shape[1]
        len2 = x2.shape[1]
        mu_z, logvar_z = self.encode(x1, x2)
        z = reparameterize(mu_z, logvar_z, eps_scale=eps_scale) if sample else mu_z
        return self.dec1(z, len1), self.dec2(z, len2), mu_z, logvar_z
# %%
# -----------------------------
# Losses
# -----------------------------
def reparameterize(mu, logvar, eps_scale: float = 1.0):
    """Reparameterization trick: z = mu + eps_scale * eps * std, eps ~ N(0, I).

    mu, logvar: [B, Z].  eps_scale=0 makes the draw deterministic (returns mu).
    """
    std = (0.5 * logvar).exp()
    noise = torch.randn_like(std) * eps_scale
    return mu + noise * std
def kl_diag_gaussian(mu, logvar, free_bits: float = 0.0):
    """Per-sample KL( N(mu, diag(exp(logvar))) || N(0, I) ) -> [B].

    Closed form: 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar).
    free_bits > 0 floors the KL at free_bits * latent_dim per sample.
    """
    kl = 0.5 * (logvar.exp() + mu.pow(2) - 1.0 - logvar).sum(dim=-1)
    if free_bits > 0.0:
        floor = free_bits * mu.shape[-1]
        kl = kl.clamp(min=floor)
    return kl
def gaussian_nll(x, mu_x, logvar_x):
    """Elementwise Gaussian negative log-likelihood, [B, T, F].

    Omits the constant 0.5*log(2*pi) term, which does not affect gradients.
    """
    sq_err = (x - mu_x).pow(2)
    return 0.5 * (logvar_x + sq_err / logvar_x.exp())
def compute_loss_two_modal(x1, x2, mu_x1, mu_x2, mu_z, logvar_z, beta: float, w1: float = 1.0, w2: float = 1.0):
    """beta-VAE loss over two modalities.

    Returns (loss, recon_mean, kl_mean, recon1_mean, recon2_mean): `loss`
    is a scalar tensor (differentiable), the remaining values are floats.
    """
    # Per-sample reconstruction MSE, averaged over time and features: [B].
    per_sample_r1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))
    per_sample_r2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))
    recon = w1 * per_sample_r1 + w2 * per_sample_r2
    kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=0.0)  # [B]
    total = (recon + beta * kl).mean()
    return (
        total,
        recon.mean().item(),
        kl.mean().item(),
        per_sample_r1.mean().item(),
        per_sample_r2.mean().item(),
    )
# %%
# -----------------------------
# Utilities: plotting
# -----------------------------
def plot_reconstructions(x, x_hat, outpath: str, n_examples: int = 3):
    """
    Plot up to n_examples real-vs-reconstructed series (feature 0 only, to
    keep it readable) and save the figure to outpath (parent dirs created).
    """
    real = x.detach().cpu().numpy()
    recon = x_hat.detach().cpu().numpy()
    count = min(n_examples, real.shape[0])
    plt.figure(figsize=(10, 3 * count))
    for idx in range(count):
        plt.subplot(count, 1, idx + 1)
        plt.plot(real[idx, :, 0], label="real")
        plt.plot(recon[idx, :, 0], label="recon")
        plt.legend()
        plt.title(f"Example {idx} (feature 0)")
    plt.tight_layout()
    parent = os.path.dirname(outpath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    plt.savefig(outpath)
    plt.close()
def plot_samples(model, T: int, device: str, outpath: str, n_samples: int = 3):
    """Sample z ~ N(0, I), decode modality-1 means, and plot feature 0.

    The figure is saved to outpath (parent dirs created as needed).

    NOTE(review): the original referenced `model.encoder.to_mu` and
    `model.decoder`, attributes that do not exist on TwoModalTimeSeriesVAE
    (it exposes `latent_dim`, `to_mu`, `dec1`, `dec2`), so calling this
    function raised AttributeError.  It now samples from the prior using
    the model's latent size and decodes with the modality-1 decoder.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, model.latent_dim, device=device)
        x_synth = model.dec1(z, T)  # mean-only decoder -> [n, T, F1]
    x_synth = x_synth.detach().cpu().numpy()
    plt.figure(figsize=(10, 3 * n_samples))
    for i in range(n_samples):
        plt.subplot(n_samples, 1, i + 1)
        plt.plot(x_synth[i, :, 0])
        plt.title(f"Synthetic sample {i} (feature 0)")
    plt.tight_layout()
    out_dir = os.path.dirname(outpath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(outpath)
    plt.close()
# %%
def det_check_two_modal(model, loader, device, tag="DET", n_batches=1):
    """Print deterministic (mean-only) reconstruction diagnostics.

    Pulls n_batches from loader, runs the model with sample=False, and
    prints per-feature MSE / std / mean-error / correlation for both
    modalities, plus summary stats of the latent posterior parameters.
    """
    model.eval()
    inputs1, inputs2 = [], []
    recons1, recons2 = [], []
    latents, latent_lvs = [], []
    it = iter(loader)
    for _ in range(n_batches):
        x1, x2 = next(it)
        x1 = x1.to(device)
        x2 = x2.to(device)
        mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False, eps_scale=0.0)
        inputs1.append(x1)
        inputs2.append(x2)
        recons1.append(mu_x1)
        recons2.append(mu_x2)
        latents.append(mu_z)
        latent_lvs.append(logvar_z)
    x1 = torch.cat(inputs1, dim=0)       # [B, T, F1]
    x2 = torch.cat(inputs2, dim=0)       # [B, 5T, F2]
    mu1 = torch.cat(recons1, dim=0)
    mu2 = torch.cat(recons2, dim=0)
    mu_z = torch.cat(latents, dim=0)     # [B, Z]
    logvar_z = torch.cat(latent_lvs, dim=0)

    def _stats(x, mu):
        # Flatten batch and time so every statistic is per feature.
        B, T, Fdim = x.shape
        xf = x.reshape(B * T, Fdim)
        mf = mu.reshape(B * T, Fdim)
        mse_feat = ((mf - xf) ** 2).mean(dim=0)
        mse_all = mse_feat.mean()
        mean_x = xf.mean(dim=0)
        mean_m = mf.mean(dim=0)
        mean_err = mean_m - mean_x
        std_x = xf.std(dim=0)
        std_m = mf.std(dim=0)
        std_ratio = std_m / (std_x + 1e-8)
        # Pearson correlation per feature (epsilon guards div-by-zero).
        x0 = xf - mean_x
        m0 = mf - mean_m
        corr = (x0 * m0).mean(dim=0) / ((x0.std(dim=0) * m0.std(dim=0)) + 1e-8)
        return mse_all, mse_feat, std_x, std_m, std_ratio, mean_err, corr

    def _report(label, s):
        print(f"{tag} {label}:")
        print(f" mse_all = {s[0].item():.6f}")
        print(f" mse_feat = {s[1].detach().cpu().numpy()}")
        print(f" std_x = {s[2].detach().cpu().numpy()}")
        print(f" std_mu = {s[3].detach().cpu().numpy()}")
        print(f" std_ratio(std_mu/std_x) = {s[4].detach().cpu().numpy()}")
        print(f" mean_err(mu-x) = {s[5].detach().cpu().numpy()}")
        print(f" corr(mu,x) = {s[6].detach().cpu().numpy()}")

    _report("MOD1", _stats(x1, mu1))
    _report("MOD2", _stats(x2, mu2))
    print(
        f"{tag} LATENT:"
        f" mu_z mean={mu_z.mean().item():.4f}, std={mu_z.std().item():.4f}"
        f" | logvar_z mean={logvar_z.mean().item():.4f}"
    )
# %%
def main():
    """Train the two-modality VAE on synthetic data and save checkpoints/plots.

    Schedule: the first `ae_epochs` epochs run as a plain autoencoder
    (deterministic z, beta=0).  Afterwards sampling is enabled, the KL
    weight warms up linearly to `beta_cap`, and the learning rate drops
    by 10x once at the transition.

    Fixes vs the original: the LR-decay block was duplicated verbatim
    (a no-op, since it assigned the same value twice), and the free-bits
    value was recomputed on every validation batch.
    """
    cfg = Config()

    model = TwoModalTimeSeriesVAE(
        F1=cfg.F,
        F2=cfg.F2,
        hidden_size=cfg.hidden_size,
        num_layers=cfg.num_layers,
        latent_dim=cfg.latent_dim,
        bidi=False,
    ).to(cfg.device)
    optim = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    train_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_train,
        T=cfg.T,
        F1=cfg.F,  # reuse cfg.F for modality A
        F2=cfg.F,  # modality B uses the same feature count here
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=1,
        shape2="triangle",
        freq2_mult=2.0,
    )
    val_ds = TwoModalSyntheticTimeSeries(
        n=cfg.n_val,
        T=cfg.T,
        F1=cfg.F,
        F2=cfg.F,
        noise_std1=cfg.noise_std,
        noise_std2=cfg.noise_std,
        seed=2,
        shape2="triangle",
        freq2_mult=2.0,
    )

    # --- environment diagnostics ---
    print("torch version:", torch.__version__)
    print("cuda available:", torch.cuda.is_available())
    print("cuda device count:", torch.cuda.device_count())
    if torch.cuda.is_available():
        print("current device:", torch.cuda.current_device())
        print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
        print("memory allocated (MB):", torch.cuda.memory_allocated() / 1024**2)
        print("memory reserved (MB):", torch.cuda.memory_reserved() / 1024**2)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, drop_last=False)

    # Sanity-check that model and data land on the same device.
    print("model device:", next(model.parameters()).device)
    batch = next(iter(train_loader))
    x1, x2 = batch
    x1 = x1.to(cfg.device)
    print("batch device:", x1.device)

    # -----------------------------
    # Training schedule params
    # -----------------------------
    ae_epochs = 10           # pure-autoencoder warmup epochs (beta = 0)
    beta_cap = 0.01          # maximum KL weight
    beta_warmup_epochs = 10  # linear warmup length after the AE phase

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        if epoch <= ae_epochs:
            sample = False
            eps_scale = 0.0
            beta = 0.0
        else:
            sample = True
            eps_scale = 1.0
            beta = min(beta_cap, (epoch - ae_epochs) / beta_warmup_epochs * beta_cap)
            # Drop the LR once when switching from AE to VAE training.
            if epoch == ae_epochs + 1:
                for g in optim.param_groups:
                    g["lr"] = cfg.lr * 0.1

        tr_loss = 0.0
        tr_recon = 0.0
        tr_kl = 0.0
        steps = 0
        for (x1, x2) in tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.epochs}"):
            x1 = x1.to(cfg.device)
            x2 = x2.to(cfg.device)
            mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=sample, eps_scale=eps_scale)
            loss, recon_m, kl_m, recon1_m, recon2_m = compute_loss_two_modal(
                x1, x2, mu_x1, mu_x2, mu_z, logvar_z, beta=beta, w1=1.0, w2=1.0
            )
            optim.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            optim.step()
            tr_loss += loss.item()
            tr_recon += recon_m
            tr_kl += kl_m
            steps += 1
        tr_loss /= steps
        tr_recon /= steps
        tr_kl /= steps

        # -----------------------------
        # Validation (deterministic for stable diagnostics)
        # -----------------------------
        model.eval()
        va_loss = 0.0
        va_recon = 0.0
        va_kl = 0.0
        va_steps = 0
        # Free-bits floor only once beta is active; computed once per epoch.
        fb = 0.0 if beta == 0.0 else 0.1
        with torch.no_grad():
            for (x1, x2) in val_loader:
                x1 = x1.to(cfg.device)
                x2 = x2.to(cfg.device)
                mu_x1, mu_x2, mu_z, logvar_z = model(x1, x2, sample=False)
                recon1 = F.mse_loss(mu_x1, x1, reduction="none").mean(dim=(1, 2))
                recon2 = F.mse_loss(mu_x2, x2, reduction="none").mean(dim=(1, 2))
                recon = recon1 + recon2
                kl = kl_diag_gaussian(mu_z, logvar_z, free_bits=fb)
                loss = (recon + beta * kl).mean()
                va_loss += loss.item()
                va_recon += recon.mean().item()
                va_kl += kl.mean().item()
                va_steps += 1
        va_loss /= va_steps
        va_recon /= va_steps
        va_kl /= va_steps

        # Periodic deterministic diagnostics.
        if epoch in [1, 5, 10, 20, cfg.epochs]:
            det_check_two_modal(model, train_loader, cfg.device, "TRAIN")
            det_check_two_modal(model, val_loader, cfg.device, "VAL")

        print(
            f"Epoch {epoch:02d} | beta={beta:.3f} | "
            f"Train loss={tr_loss:.4f} (recon={tr_recon:.4f}, kl={tr_kl:.4f}) | "
            f"Val loss={va_loss:.4f} (recon={va_recon:.4f}, kl={va_kl:.4f})"
        )

        # Reconstruction plots for both modalities.
        if epoch in {1, 5, 10, cfg.epochs}:
            with torch.no_grad():
                batch = next(iter(val_loader))
                x1_plot, x2_plot = batch[0].to(cfg.device), batch[1].to(cfg.device)
                mu_x1_plot, mu_x2_plot, _, _ = model(x1_plot, x2_plot, sample=False)
            plot_reconstructions(
                x1_plot[:8], mu_x1_plot[:8],
                os.path.join(cfg.out_dir, f"recon_m1_epoch{epoch}.png")
            )
            plot_reconstructions(
                x2_plot[:8], mu_x2_plot[:8],
                os.path.join(cfg.out_dir, f"recon_m2_epoch{epoch}.png")
            )

    # Save the final model weights.
    model_path = os.path.join(cfg.out_dir, "ts_vae.pt")
    os.makedirs(cfg.out_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print("Saved to:", model_path)
# %%
# Guard the entry point so importing this module does not start training.
if __name__ == "__main__":
    main()