import json
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import nn
from torchdyn.core import NeuralODE

from .modules import AdaLNFlowPredictor, AutoEncoder


@contextmanager
def suppress_stdout():
    """Temporarily redirect ``sys.stdout`` to ``os.devnull``.

    The previous ``sys.stdout`` object is restored on exit, including when
    the managed body raises.
    """
    original_stdout = sys.stdout
    # Open the sink *before* touching sys.stdout: if open() fails, the real
    # stdout is left untouched (and, crucially, is never closed).
    with open(os.devnull, "w") as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Restore first; the enclosing `with` then closes only the
            # devnull handle opened here.
            sys.stdout = original_stdout


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Build a step -> learning-rate *multiplier* function.

    The returned callable yields a factor relative to ``start_lr``:

    * for ``step < warmup_steps`` it ramps linearly from 0 towards 1;
    * afterwards it follows a half-cosine from 1 down to
      ``end_lr / start_lr``, reached at ``total_steps``;
    * past ``total_steps`` it stays at ``end_lr / start_lr``.

    Args:
        warmup_steps: Number of linear warm-up steps.
        total_steps: Step at which the cosine decay reaches its minimum.
        start_lr: Peak learning rate; must be non-zero since it is used as
            the divisor.
        end_lr: Learning rate reached at the end of the decay.

    Returns:
        A function mapping an integer step to a float multiplier.
    """

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        # Clamp so that training past `total_steps` holds the final value
        # instead of riding the cosine back up.
        progress = min(1.0, progress)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


@dataclass
class PatchVAEConfig:
    """Hyper-parameters used to construct a :class:`PatchVAE`."""

    # Channel dimension of the (un-patched) input latent.
    latent_dim: int
    # Hidden width passed to both the flow predictor and the autoencoder.
    hidden_dim: int
    # Optional (mean, std) pair used to normalise latents; None disables it.
    latent_scaling: tuple[list[float], list[float]] | None
    # Layer factory name and depth for the flow predictor.
    flow_factory: str
    num_flow_layers: int
    # Layer factory name and depth for the autoencoder.
    autoencoder_factory: str
    num_autoencoder_layers: int
    # Forwarded verbatim to AutoEncoder.
    convnextformer_num_conv_per_transformer: int = 3
    # If set, `from_pretrained` also loads a WavVAE from this path/repo id.
    wavvae_path: str | None = None
    # Not read in this file.
    fsq_levels: list[int] | None = None
    # Forwarded to AutoEncoder.
    bottleneck_size: int | None = None
    # Number of consecutive frames folded into one patch.
    latent_stride: int = 2
    # Forwarded to AutoEncoder as `vae`.
    vae: bool = False
    # Not read in this file.
    causal_transformer: bool = False
    # Conditioning width: AutoEncoder `out_dim` and flow `cond_dim`.
    cond_dim: int | None = None
    # Forwarded to both sub-modules as `is_causal`.
    is_causal: bool = False


class PatchVAE(nn.Module):
    """Patch-level autoencoder with a flow-based decoder over latents.

    Input latents of shape ``(B, T, D)`` are optionally normalised, folded
    into patches of ``latent_stride`` consecutive frames
    (``(B, T // stride, D * stride)``), and encoded by ``autoencoder``.
    Decoding integrates an ODE whose vector field is ``flow_net``
    conditioned on the autoencoder's decoded output.
    """

    def __init__(self, cfg: PatchVAEConfig):
        """Build the flow predictor and autoencoder described by ``cfg``."""
        super().__init__()
        patch_dim = cfg.latent_dim * cfg.latent_stride
        self.flow_net = AdaLNFlowPredictor(
            feat_dim=patch_dim,
            dim=cfg.hidden_dim,
            n_layer=cfg.num_flow_layers,
            layer_factory=cfg.flow_factory,
            cond_dim=cfg.cond_dim,
            is_causal=cfg.is_causal,
        )
        self.autoencoder = AutoEncoder(
            patch_dim,
            cfg.hidden_dim,
            cfg.num_autoencoder_layers,
            cfg.autoencoder_factory,
            out_dim=cfg.cond_dim,
            vae=cfg.vae,
            bottleneck_size=cfg.bottleneck_size,
            convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
            is_causal=cfg.is_causal,
        )
        if cfg.latent_scaling is not None:
            mean, std = cfg.latent_scaling
            # Buffers so the statistics follow `.to(device)` and are part of
            # the state dict.
            self.register_buffer("mean_latent_scaling", torch.tensor(mean))
            self.register_buffer("std_latent_scaling", torch.tensor(std))
        else:
            self.mean_latent_scaling = None
            self.std_latent_scaling = None

        self.latent_stride = cfg.latent_stride
        self.latent_dim = cfg.latent_dim
        # Optional waveform codec; set by `from_pretrained` or
        # `wavvae_from_pretrained`.
        self.wavvae = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        """Load a model from a local directory or a Hugging Face Hub repo.

        The directory must contain ``config.json`` (keyword arguments for
        :class:`PatchVAEConfig`) and ``model.st`` (safetensors weights).

        Args:
            pretrained_model_name_or_path: Existing local path, otherwise
                treated as a repo id and fetched with ``snapshot_download``.
            map_location: Device the model and weights are placed on.

        Returns:
            The loaded ``PatchVAE``. If the config sets ``wavvae_path``, the
            corresponding WavVAE is loaded and attached as ``model.wavvae``.
        """
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            # Imported lazily so the hub client is only needed for remote loads.
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)

        with open(Path(path) / "config.json", "r", encoding="utf-8") as f:
            config = PatchVAEConfig(**json.load(f))

        model = cls(config).to(map_location)
        state_dict = load_file(
            Path(path) / "model.st",
            device=map_location,
        )
        model.load_state_dict(state_dict, assign=True)

        # `model.wavvae` is already None from __init__ when no path is given.
        if config.wavvae_path is not None:
            from .. import WavVAE

            model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
        return model

    def wavvae_from_pretrained(
        self,
        pretrained_model_name_or_path: str,
        *args,
        **kwargs,
    ):
        """Load a WavVAE and attach it as ``self.wavvae``.

        Extra positional and keyword arguments are forwarded to
        ``WavVAE.from_pretrained``.
        """
        from .. import WavVAE

        self.wavvae = WavVAE.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
        )

    def encode(self, wav: torch.Tensor):
        """Encode a waveform: ``wavvae.encode`` followed by ``encode_patch``.

        Raises:
            AssertionError: If no WavVAE has been attached.
        """
        assert self.wavvae is not None, (
            "please provide WavVAE model to encode from waveform"
        )
        z = self.wavvae.encode(wav)
        return self.encode_patch(z)

    def decode(self, patchvae_latent: torch.Tensor, **kwargs):
        """Decode a patch latent to a waveform.

        Runs ``decode_patch`` (receiving ``**kwargs``) and then
        ``wavvae.decode``.

        Raises:
            AssertionError: If no WavVAE has been attached.
        """
        assert self.wavvae is not None, (
            "please provide WavVAE model to decode to waveform"
        )
        z = self.decode_patch(patchvae_latent, **kwargs)
        return self.wavvae.decode(z)

    def normalize_z(self, z: torch.Tensor):
        """Return ``(z - mean) / std`` when latent scaling is configured."""
        if self.mean_latent_scaling is not None:
            z = (z - self.mean_latent_scaling) / self.std_latent_scaling
        return z

    def denormalize_z(self, z: torch.Tensor):
        """Inverse of :meth:`normalize_z`: ``z * std + mean`` when configured."""
        if self.std_latent_scaling is not None:
            z = z * self.std_latent_scaling + self.mean_latent_scaling
        return z

    def _patchify(self, z: torch.Tensor) -> torch.Tensor:
        """Fold ``latent_stride`` consecutive frames into the channel axis.

        Maps ``(B, T, D)`` to ``(B, T // stride, D * stride)``, dropping
        trailing frames that do not fill a whole patch. A stride of 1 or
        less returns ``z`` unchanged.
        """
        stride = self.latent_stride
        if stride <= 1:
            return z
        B, T, D = z.shape
        z = z[:, : T - T % stride]
        return z.reshape(B, T // stride, D * stride)

    def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
        """Normalise, patchify and encode a ``(B, T, D)`` latent.

        Args:
            z: Latent of shape ``(B, T, D)``.
            deterministic: Forwarded to ``autoencoder.encode``.

        Returns:
            Whatever ``autoencoder.encode`` returns for the patched input.
        """
        z = self._patchify(self.normalize_z(z))
        return self.autoencoder.encode(z, deterministic=deterministic)

    def decode_patch(
        self,
        latent: torch.Tensor,
        cfg: float = 2.0,
        num_steps: int = 15,
        solver: str = "euler",
        sensitivity: str = "adjoint",
        temperature: float = 1.0,
        **kwargs,
    ):
        """Sample an un-patched latent from a patch latent by solving an ODE.

        Args:
            latent: Patch latent accepted by ``autoencoder.decode``.
            cfg: Guidance weight. With exactly ``1.0`` only the conditional
                flow is evaluated; otherwise the field is
                ``uncond + cfg * (cond - uncond)``, where the unconditional
                branch uses an all-zero conditioning tensor.
            num_steps: Number of integration intervals on ``[0, 1]``.
            solver: Solver name passed to ``NeuralODE``.
            sensitivity: Sensitivity method passed to ``NeuralODE``.
            temperature: Scale applied to the initial Gaussian noise.
            **kwargs: Extra keyword arguments for ``NeuralODE``.

        Returns:
            Tensor of shape ``(B, T' * stride, patch_dim // stride)`` after
            de-normalisation, where ``T'`` is the conditioning length.
        """
        with torch.no_grad():
            # Channel-first conditioning: (B, C, T').
            z_cond = self.autoencoder.decode(latent).transpose(1, 2)

            if cfg == 1.0:

                def solver_fn(t, Xt, *_args, **_kwargs):
                    return self.flow_net(Xt, z_cond, t.unsqueeze(0))

            else:
                # Stack conditional and zeroed (unconditional) inputs so both
                # are evaluated in a single forward pass.
                z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)

                def solver_fn(t, Xt, *_args, **_kwargs):
                    flow = self.flow_net(
                        Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0)
                    )
                    cond, uncond = flow.chunk(2, dim=0)
                    return uncond + cfg * (cond - uncond)

            # NOTE(review): presumably silences messages torchdyn prints when
            # wrapping a plain callable -- confirm.
            with suppress_stdout():
                node_ = NeuralODE(
                    solver_fn,
                    solver=solver,
                    sensitivity=sensitivity,
                    **kwargs,
                )

            t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
            patch_dim = self.latent_dim * self.latent_stride
            x0 = torch.randn(
                z_cond.shape[0],
                patch_dim,
                z_cond.shape[2],
                device=z_cond.device,
            )
            traj = node_.trajectory(
                x0 * temperature,
                t_span=t_span,
            )

            # Last trajectory point is the state at t = 1; go back to
            # channel-last and unfold each patch into `stride` frames.
            y_hat = traj[-1].transpose(1, 2)
            B, T, D = y_hat.shape
            y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride)
            y_hat = self.denormalize_z(y_hat)
        return y_hat

    def forward(
        self,
        z: torch.Tensor,
        t: torch.Tensor,
        drop_cond_rate: float = 0.0,
        drop_vae_rate: float = 0.0,
        sigma: float = 1e-4,
    ):
        """Compute the flow-matching and autoencoder training losses.

        Args:
            z: Latent of shape ``(B, T, D)``. Trailing frames that do not
                fill a whole patch are dropped, as in :meth:`encode_patch`.
            t: Time values, reshaped with ``view(-1, 1, 1)`` for broadcasting
                (presumably one value per batch item -- confirm with caller).
            drop_cond_rate: Probability of zeroing the conditioning of a
                batch item.
            drop_vae_rate: Forwarded to the autoencoder.
            sigma: Minimum-noise constant of the interpolation path.

        Returns:
            Tuple ``(flow_loss, ae_loss, prior)`` where ``prior`` is the
            conditioning produced by the autoencoder (after dropping).
        """
        z = self._patchify(self.normalize_z(z))

        prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)

        if drop_cond_rate > 0.0:
            to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
            prior[to_drop] = 0.0

        x0 = torch.randn_like(z)
        x1 = z

        # Path: xt = (1 - (1 - sigma) * t) * x0 + t * x1, whose time
        # derivative is the regression target x1 - (1 - sigma) * x0.
        flow_target = x1 - (1 - sigma) * x0
        alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
        xt = alpha * x0 + t.view(-1, 1, 1) * x1

        pred = self.flow_net(
            xt.transpose(1, 2),
            prior.transpose(1, 2),
            t,
        )
        flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)
        return flow_loss, ae_loss, prior