from random import random
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
from vector_quantize_pytorch import FSQ

from zcodec.models.components.transformer import TransformerBlock


class AdaLayerNormScale(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim * 3)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x, c):
        x = self.norm(x)
        scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
        # Build a broadcastable shape (B, 1, ..., 1, dim) for the (B, dim) modulation tensors.
        shape = [x.shape[0]] + [1] * (x.dim() - 2) + [x.shape[-1]]
        scale, bias, gate = map(lambda t: t.view(*shape), (scale, bias, gate))
        x = x * (1 + scale) + bias
        return x, gate
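
# AdaLayerNormScale above follows the DiT-style adaLN pattern: the conditioning vector c
# yields a per-channel scale and bias applied to the normalized input, plus a gate that is
# returned to the caller (typically applied to the residual branch). It is kept as a
# standalone helper and is not referenced elsewhere in this file.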


class GaussianFourierTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
        return x
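
# For input t of shape (B,), GaussianFourierTimeEmbedding above returns (B, 2 * dim)
# features (sin and cos halves). AdaLNFlowPredictor below instantiates it with dim // 2
# so the time embedding width matches the model width (for even dim).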


LAYER_FACTORIES = {}


def register_flow_layer_factory(name):
    def decorator(fn):
        LAYER_FACTORIES[name] = fn
        return fn

    return decorator


# Note: the registry keys below are assumed; the original listing defines the registry and
# the factory functions, but the decorator lines (and hence the exact keys) were missing.
@register_flow_layer_factory("convnext")
def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)


@register_flow_layer_factory("mlp")
def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return AdaLNMLP(dim)


@register_flow_layer_factory("transformer")
def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)


def init_weights(m: nn.Module):
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def init_adaln_weights(m: nn.Module):
    nn.init.trunc_normal_(m.weight, std=0.02)
    nn.init.zeros_(m.bias)


def modulate(x, scale, shift):
    return x * (1 + scale[:, None]) + shift[:, None]


class AdaLNFlowPredictor(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        dim: int,
        n_layer: int,
        layer_factory: str,
        cond_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()
        layer_factory = LAYER_FACTORIES[layer_factory]
        self.layers = nn.ModuleList(
            [
                layer_factory(dim, i, n_layer, is_causal=is_causal)
                for i in range(n_layer)
            ]
        )
        if cond_dim is None:
            cond_dim = feat_dim
        self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
        self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
        self.final_adaln_proj = nn.Linear(dim, dim * 2)
        self.out_proj = nn.Linear(dim, feat_dim)
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.time_emb = GaussianFourierTimeEmbedding(dim // 2)
        self.apply(init_weights)
        for l in self.adaln_proj:
            init_adaln_weights(l)
        init_adaln_weights(self.final_adaln_proj)

    def forward(
        self,
        x_t: torch.Tensor,
        x_mu: torch.Tensor,
        t: torch.Tensor,
    ):
        # x_t: (B, feat_dim, T), x_mu: (B, cond_dim, T), t: (B,) time values.
        x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
        x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)
        t_emb = self.time_emb(t)
        for l, adaln in zip(self.layers, self.adaln_proj):
            scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
            x = l(x, scale_shift=(scale, shift), gate=gate)
        scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
        x = self.final_norm(x.transpose(1, 2))
        x = modulate(x, scale, shift)
        x = self.out_proj(x).transpose(1, 2)
        return x
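
# Usage sketch for AdaLNFlowPredictor (illustrative values, not from the original source;
# "convnext" is the registry key assumed above):
#
#   predictor = AdaLNFlowPredictor(feat_dim=64, dim=256, n_layer=8, layer_factory="convnext")
#   x_t = torch.randn(2, 64, 100)   # noisy features at time t, channel-first (B, C, T)
#   x_mu = torch.randn(2, 64, 100)  # conditioning features (cond_dim defaults to feat_dim)
#   t = torch.rand(2)               # one time value per batch element
#   out = predictor(x_t, x_mu, t)   # same shape as x_t: (2, 64, 100)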


class AdaLNMLP(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        # Defined but not used in forward: scale/shift/gate are computed by the caller
        # (AdaLNFlowPredictor) and passed in.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
        )

    def forward(self, x, scale_shift, gate):
        x = x.transpose(-1, -2)  # (B, C, T) -> (B, T, C)
        h = modulate(self.in_ln(x), *scale_shift)
        h = self.mlp(h)
        return (x + gate[:, None] * h).transpose(-1, -2)


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block adapted from https://github.com/facebookresearch/ConvNeXt to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int, optional): Dimensionality of the intermediate layer.
            Defaults to ``dim * 3``.
        layer_scale_init_value (float): Initial value for the layer scale.
            Values <= 0 disable layer scaling. Defaults to 0.0.
        elementwise_affine_ln (bool): Whether the LayerNorm has learnable affine
            parameters. Defaults to True.
        is_causal (bool): Use left-only (causal) padding in the depthwise conv.
            Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int | None = None,
        layer_scale_init_value: float = 0.0,
        elementwise_affine_ln: bool = True,
        is_causal: bool = False,
    ):
        super().__init__()
        intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6, elementwise_affine=elementwise_affine_ln)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.is_causal = is_causal

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        if self.is_causal:
            # Left-pad by kernel_size - 1 so each output frame depends only on past frames.
            x = torch.nn.functional.pad(x, (6, 0))
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * scale[:, None] + shift[:, None]
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        if gate is not None:
            x = gate[:, None] * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x
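
# A ConvNeXtBlock maps (B, C, T) -> (B, C, T). When scale_shift and gate (each of shape
# (B, C)) are provided, scale/shift modulate the features right after the LayerNorm and
# the gate rescales the branch output just before the residual add.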


class ConvNextNet(nn.Module):
    def __init__(
        self,
        dim: int,
        n_layers: int,
        intermediate_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()
        self.net = nn.Sequential(
            *[
                ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        # (B, T, C) in -> run the blocks channel-first -> (B, T, C) out.
        return self.net(x.transpose(1, 2)).transpose(1, 2)


def convnext_factory(dim, n_layers, is_causal=False):
    return ConvNextNet(dim, n_layers, is_causal=is_causal)


def convnextformer_factory(
    dim, n_layers, n_convnext_per_transformer_block, is_causal=False
):
    # Interleave groups of ConvNeXt blocks with single self-attention transformer blocks.
    layers = []
    for i in range(0, n_layers, n_convnext_per_transformer_block + 1):
        layers.append(
            ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
        )
        layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
    return nn.Sequential(*layers)
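
# Example (illustrative): convnextformer_factory(dim=256, n_layers=8,
# n_convnext_per_transformer_block=3) steps the loop by 4, producing
# [ConvNextNet(3 blocks), TransformerBlock, ConvNextNet(3 blocks), TransformerBlock],
# i.e. 8 layers in total counting each transformer block.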


class AutoEncoder(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_dim: int,
        num_layers: int,
        net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
        out_dim: int | None = None,
        convnextformer_num_conv_per_transformer: int = 3,
        causal_transformer: bool = False,
        bottleneck_size: int | None = None,
        vae: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.embed = nn.Linear(feat_dim, hidden_dim)
        if out_dim is None:
            out_dim = feat_dim
        self.unembed = nn.Linear(hidden_dim, out_dim)
        if net_factory == "convnext":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
        elif net_factory == "convnextformer_decoder":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
        elif net_factory == "convnextformer":
            self.encoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
        # With vae=True the bottleneck outputs 2 * bottleneck_size features (mu and logvar).
        self.bottleneck = (
            nn.Linear(hidden_dim, bottleneck_size * (1 + vae))
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.unbottleneck = (
            nn.Linear(bottleneck_size, hidden_dim)
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.vae = vae

    def reparameterize(
        self,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        deterministic: bool = False,
        drop_vae_rate: float = 0.0,
    ) -> torch.Tensor:
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        if drop_vae_rate > 0.0:
            # Per sample: with probability drop_vae_rate, skip the noise and use mu directly.
            to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
            eps = torch.randn_like(std)
            eps[to_drop] = 0.0
        else:
            if deterministic:
                eps = torch.zeros_like(std)
            else:
                eps = torch.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        return kl.sum(dim=-1).mean()
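
    # The KL term above is the closed-form divergence between the diagonal Gaussian
    # posterior N(mu, exp(logvar)) and a standard normal prior,
    #   KL = -0.5 * sum_d (1 + logvar_d - mu_d**2 - exp(logvar_d)),
    # summed over the latent dimension and averaged over the remaining dimensions.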

    def forward(
        self, x: torch.Tensor, drop_vae_rate: float = 0.0
    ) -> tuple[torch.Tensor, dict]:
        # Encode
        x = self.embed(x)
        x = self.encoder_net(x)
        x = self.bottleneck(x)
        if self.vae:
            mu, logvar = x.chunk(2, dim=-1)
            loss = {
                "kl_div": self.kl_divergence(mu, logvar),
                "_mu_mean": mu.mean(),
                "_mu_std": mu.std(),
                "_logvar_mean": logvar.mean(),
                "_logvar_std": logvar.std(),
            }
            x = self.reparameterize(
                mu,
                logvar,
                drop_vae_rate=drop_vae_rate,
            )
        else:
            loss = {}
        # Decode
        x = self.unbottleneck(x)
        x = self.decoder_net(x)
        x = self.unembed(x)
        return x, loss

    def encode(self, x: torch.Tensor, deterministic: bool = False):
        x = self.embed(x)
        x = self.encoder_net(x)
        x = self.bottleneck(x)
        if self.vae:
            x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
        return x

    def decode(self, latent: torch.Tensor):
        x = self.unbottleneck(latent)
        x = self.decoder_net(x)
        x = self.unembed(x)
        return x
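

if __name__ == "__main__":
    # Minimal smoke test (not part of the original source): exercises only the
    # ConvNeXt-based paths, with illustrative dimensions. The "convnext" registry key
    # and the constructor arguments are assumptions, not documented defaults.
    B, T, feat_dim, hidden_dim, bottleneck = 2, 50, 64, 128, 16

    ae = AutoEncoder(
        feat_dim=feat_dim,
        hidden_dim=hidden_dim,
        num_layers=4,
        net_factory="convnext",
        bottleneck_size=bottleneck,
        vae=True,
    )
    x = torch.randn(B, T, feat_dim)       # AutoEncoder runs channel-last (B, T, C)
    recon, losses = ae(x)
    print(recon.shape, losses["kl_div"].item())

    z = ae.encode(x, deterministic=True)  # (B, T, bottleneck)
    print(ae.decode(z).shape)             # (B, T, feat_dim)

    flow = AdaLNFlowPredictor(
        feat_dim=feat_dim, dim=hidden_dim, n_layer=4, layer_factory="convnext"
    )
    x_t = torch.randn(B, feat_dim, T)     # flow predictor runs channel-first (B, C, T)
    x_mu = torch.randn(B, feat_dim, T)
    t = torch.rand(B)
    print(flow(x_t, x_mu, t).shape)       # (B, feat_dim, T)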