import json
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import nn
from torchdyn.core import NeuralODE

from .modules import AdaLNFlowPredictor, AutoEncoder


@contextmanager
def suppress_stdout():
    """Temporarily redirect stdout to os.devnull (used below to silence
    NeuralODE's construction messages)."""
    original_stdout = sys.stdout
    try:
        sys.stdout = open(os.devnull, "w")
        yield
    finally:
        sys.stdout.close()
        sys.stdout = original_stdout


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    """Build an LR-multiplier function: linear warmup to 1, then a cosine
    decay whose multiplier ends at end_lr / start_lr."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


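# A minimal wiring sketch (optimizer choice and step counts are illustrative):
# torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's base lr (set to
# start_lr) by lr_lambda(step), so the factor warms up linearly to 1 and then
# decays toward end_lr / start_lr.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(
#       optimizer,
#       cosine_schedule_with_warmup(
#           warmup_steps=1_000, total_steps=100_000, start_lr=1e-4, end_lr=1e-6
#       ),
#   )
#   # each training step: optimizer.step(); scheduler.step()

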
@dataclass
class PatchVAEConfig:
    latent_dim: int
    hidden_dim: int
    latent_scaling: tuple[list[float], list[float]] | None
    flow_factory: str
    num_flow_layers: int
    autoencoder_factory: str
    num_autoencoder_layers: int
    convnextformer_num_conv_per_transformer: int = 3
    wavvae_path: str | None = None
    fsq_levels: list[int] | None = None
    bottleneck_size: int | None = None
    latent_stride: int = 2
    vae: bool = False
    causal_transformer: bool = False
    cond_dim: int | None = None
    is_causal: bool = False


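# An illustrative config.json (values are made up, not from a released
# checkpoint); from_pretrained below rebuilds it via PatchVAEConfig(**json):
#
#   {
#     "latent_dim": 64,
#     "hidden_dim": 512,
#     "latent_scaling": null,
#     "flow_factory": "...",
#     "num_flow_layers": 8,
#     "autoencoder_factory": "...",
#     "num_autoencoder_layers": 8,
#     "cond_dim": 512,
#     "latent_stride": 2
#   }

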
class PatchVAE(nn.Module):
    def __init__(self, cfg: PatchVAEConfig):
        super().__init__()
        # Both the flow network and the autoencoder operate on patches of
        # `latent_stride` consecutive frames folded into the channel axis.
        self.flow_net = AdaLNFlowPredictor(
            feat_dim=cfg.latent_dim * cfg.latent_stride,
            dim=cfg.hidden_dim,
            n_layer=cfg.num_flow_layers,
            layer_factory=cfg.flow_factory,
            cond_dim=cfg.cond_dim,
            is_causal=cfg.is_causal,
        )
        self.autoencoder = AutoEncoder(
            cfg.latent_dim * cfg.latent_stride,
            cfg.hidden_dim,
            cfg.num_autoencoder_layers,
            cfg.autoencoder_factory,
            out_dim=cfg.cond_dim,
            vae=cfg.vae,
            bottleneck_size=cfg.bottleneck_size,
            convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
            is_causal=cfg.is_causal,
        )
        if cfg.latent_scaling is not None:
            mean, std = cfg.latent_scaling
            self.register_buffer("mean_latent_scaling", torch.tensor(mean))
            self.register_buffer("std_latent_scaling", torch.tensor(std))
        else:
            self.mean_latent_scaling = None
            self.std_latent_scaling = None
        self.latent_stride = cfg.latent_stride
        self.latent_dim = cfg.latent_dim
        self.wavvae = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        # Accept either a local checkpoint directory or a Hugging Face Hub
        # repo id to download.
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)
        with open(Path(path) / "config.json", "r") as f:
            config = json.load(f)
        config = PatchVAEConfig(**config)
        model = cls(config).to(map_location)
        state_dict = load_file(
            Path(path) / "model.st",
            device=map_location,
        )
        model.load_state_dict(state_dict, assign=True)
        if config.wavvae_path is not None:
            from .. import WavVAE

            model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
        else:
            model.wavvae = None
        return model
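
    # Expected checkpoint layout (local directory or Hub repo), as read above:
    #   config.json -> PatchVAEConfig fields
    #   model.st    -> safetensors state dict
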
    def wavvae_from_pretrained(
        self,
        pretrained_model_name_or_path: str,
        *args,
        **kwargs,
    ):
        from .. import WavVAE

        self.wavvae = WavVAE.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
        )

    def encode(self, wav: torch.Tensor):
        assert self.wavvae is not None, (
            "please provide a WavVAE model to encode from a waveform"
        )
        z = self.wavvae.encode(wav)
        return self.encode_patch(z)

    def decode(self, patchvae_latent: torch.Tensor, **kwargs):
        assert self.wavvae is not None, (
            "please provide a WavVAE model to decode to a waveform"
        )
        z = self.decode_patch(patchvae_latent, **kwargs)
        return self.wavvae.decode(z)

    def normalize_z(self, z: torch.Tensor):
        if self.mean_latent_scaling is not None:
            z = (z - self.mean_latent_scaling) / self.std_latent_scaling
        return z

    def denormalize_z(self, z: torch.Tensor):
        if self.std_latent_scaling is not None:
            z = z * self.std_latent_scaling + self.mean_latent_scaling
        return z

    def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
        B, T, D = z.shape
        z = self.normalize_z(z)
        if self.latent_stride > 1:
            # Trim T to a multiple of the stride, then fold `latent_stride`
            # consecutive frames into one patch along the channel axis.
            z = z[:, : T - T % self.latent_stride]
            z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
        return self.autoencoder.encode(z, deterministic=deterministic)
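
    # Shape walk-through for encode_patch with latent_stride=2 (sizes
    # illustrative): z (B=1, T=101, D=64) is trimmed to T=100 and folded to
    # (1, 50, 128) before autoencoder.encode compresses it.
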
    def decode_patch(
        self,
        latent: torch.Tensor,
        cfg: float = 2.0,
        num_steps: int = 15,
        solver: str = "euler",
        sensitivity: str = "adjoint",
        temperature: float = 1.0,
        **kwargs,
    ):
        with torch.no_grad():
            z_cond = self.autoencoder.decode(latent).transpose(1, 2)
            if cfg == 1.0:
                # No guidance: a single conditional pass per solver step.
                def solver_fn(t, Xt, *args, **kwargs):
                    return self.flow_net(Xt, z_cond, t.unsqueeze(0))

            else:
                # Classifier-free guidance: run the conditional and the
                # unconditional (zeroed condition) passes in one batch and
                # blend them as uncond + cfg * (cond - uncond).
                z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)

                def solver_fn(t, Xt, *args, **kwargs):
                    flow = self.flow_net(
                        Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0)
                    )
                    cond, uncond = flow.chunk(2, dim=0)
                    return uncond + cfg * (cond - uncond)

            with suppress_stdout():
                node_ = NeuralODE(
                    solver_fn,
                    solver=solver,
                    sensitivity=sensitivity,
                    **kwargs,
                )
            t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
            patch_dim = self.latent_dim * self.latent_stride
            x0 = torch.randn(
                z_cond.shape[0],
                patch_dim,
                z_cond.shape[2],
                device=z_cond.device,
            )
            traj = node_.trajectory(
                x0 * temperature,
                t_span=t_span,
            )
            # Keep the final ODE state and unfold patches back into frames.
            y_hat = traj[-1].transpose(1, 2)
            B, T, D = y_hat.shape
            y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride)
            return self.denormalize_z(y_hat)
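
    # decode_patch inverts encode_patch's folding: the ODE integrates the
    # learned flow from scaled Gaussian noise at t=0 to a patch latent at
    # t=1 over num_steps solver steps, and the final (B, T, D*stride) state
    # is unfolded back to (B, T*stride, D) frames.
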
    def forward(
        self,
        z: torch.Tensor,
        t: torch.Tensor,
        drop_cond_rate: float = 0.0,
        drop_vae_rate: float = 0.0,
        sigma: float = 1e-4,
    ):
        z = self.normalize_z(z)
        B, T, D = z.shape
        if self.latent_stride > 1:
            # T is expected to already be a multiple of `latent_stride` here.
            z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
        prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)
        if drop_cond_rate > 0.0:
            # Randomly zero the condition so the model also learns the
            # unconditional flow needed for classifier-free guidance.
            to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
            prior[to_drop] = 0.0
        # Conditional flow matching: sample x_t on the path
        # x_t = (1 - (1 - sigma) * t) * x0 + t * x1 and regress its velocity
        # target x1 - (1 - sigma) * x0.
        x0 = torch.randn_like(z)
        x1 = z
        flow_target = x1 - (1 - sigma) * x0
        alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
        xt = alpha * x0 + t.view(-1, 1, 1) * x1
        pred = self.flow_net(
            xt.transpose(1, 2),
            prior.transpose(1, 2),
            t,
        )
        flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)
        return flow_loss, ae_loss, prior
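

# End-to-end usage sketch (the repo id and the waveform shape are illustrative,
# not a real release; WavVAE's exact input shape may differ):
#
#   model = PatchVAE.from_pretrained("user/patchvae", map_location="cuda")
#   wav = torch.randn(1, 24_000, device="cuda")            # ~1 s of audio
#   latent = model.encode(wav)                             # waveform -> patch latent
#   wav_hat = model.decode(latent, cfg=2.0, num_steps=15)  # patch latent -> waveform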