from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn

from zcodec.models.components.transformer import TransformerBlock


class AdaLayerNormScale(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim * 3)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x, c):
        x = self.norm(x)
        scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
        # Reshape the (B, C) modulation tensors so they broadcast over the
        # middle dimensions of x.
        shape = [x.shape[0]] + [1] * (x.dim() - 2) + [x.shape[-1]]
        scale, bias, gate = map(lambda m: m.view(*shape), (scale, bias, gate))
        x = x * (1 + scale) + bias
        return x, gate


class GaussianFourierTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Fixed random projection; frozen, not trained.
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
        return x


LAYER_FACTORIES = {}


def register_flow_layer_factory(name):
    def decorator(fn):
        LAYER_FACTORIES[name] = fn
        return fn

    return decorator


@register_flow_layer_factory("convnext")
def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)


@register_flow_layer_factory("mlp")
def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return AdaLNMLP(dim)


@register_flow_layer_factory("sa_transformer")
def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)


def init_weights(m: nn.Module):
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


def init_adaln_weights(m: nn.Module):
    nn.init.trunc_normal_(m.weight, std=0.02)
    nn.init.zeros_(m.bias)


def modulate(x, scale, shift):
    return x * (1 + scale[:, None]) + shift[:, None]


class AdaLNFlowPredictor(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        dim: int,
        n_layer: int,
        layer_factory: str,
        cond_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()
        layer_factory = LAYER_FACTORIES[layer_factory]
        self.layers = nn.ModuleList(
            [
                layer_factory(dim, i, n_layer, is_causal=is_causal)
                for i in range(n_layer)
            ]
        )
        if cond_dim is None:
            cond_dim = feat_dim
        self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
        self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
        self.final_adaln_proj = nn.Linear(dim, dim * 2)
        self.out_proj = nn.Linear(dim, feat_dim)
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        # sin/cos concatenation doubles the embedding width back to dim.
        self.time_emb = GaussianFourierTimeEmbedding(dim // 2)
        self.apply(init_weights)
        for l in self.adaln_proj:
            init_adaln_weights(l)
        init_adaln_weights(self.final_adaln_proj)

    def forward(
        self,
        x_t: torch.Tensor,
        x_mu: torch.Tensor,
        t: torch.Tensor,
    ):
        # x_t and x_mu are (B, C, T); concatenate the noisy sample and the
        # conditioning along the channel axis before projecting.
        x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
        x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)
        t_emb = self.time_emb(t)
        for l, adaln in zip(self.layers, self.adaln_proj):
            scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
            x = l(x, scale_shift=(scale, shift), gate=gate)
        scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
        x = self.final_norm(x.transpose(1, 2))
        x = modulate(x, scale, shift)
        x = self.out_proj(x).transpose(1, 2)
        return x


class AdaLNMLP(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )
        # NOTE: not used in forward(); the modulation parameters are supplied
        # externally (see AdaLNFlowPredictor.adaln_proj).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
        )

    def forward(self, x, scale_shift, gate):
        x = x.transpose(-1, -2)
        h = modulate(self.in_ln(x), *scale_shift)
        h = self.mlp(h)
        return (x + gate[:, None] * h).transpose(-1, -2)


class ConvNeXtBlock(nn.Module):
    """ConvNeXt block adapted from https://github.com/facebookresearch/ConvNeXt
    to a 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int, optional): Dimensionality of the intermediate
            layer. Defaults to ``3 * dim``.
        layer_scale_init_value (float, optional): Initial value for the layer
            scale; ``0.0`` disables scaling. Defaults to 0.0.
        elementwise_affine_ln (bool, optional): Whether the LayerNorm carries
            learnable affine parameters. Defaults to True.
        is_causal (bool, optional): If True, use left-only padding in the
            depthwise conv. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int | None = None,
        layer_scale_init_value: float = 0.0,
        elementwise_affine_ln: bool = True,
        is_causal: bool = False,
    ):
        super().__init__()
        intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(
            dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
        )
        # Pointwise/1x1 convs, implemented with linear layers.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.is_causal = is_causal

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        if self.is_causal:
            # Left-pad by kernel_size - 1 so the conv sees only past frames.
            x = F.pad(x, (6, 0))
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        if scale_shift is not None:
            # NOTE: plain scale here, unlike the (1 + scale) used in modulate().
            scale, shift = scale_shift
            x = x * scale[:, None] + shift[:, None]
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        if gate is not None:
            x = gate[:, None] * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        x = residual + x
        return x


class ConvNextNet(nn.Module):
    def __init__(
        self,
        dim: int,
        n_layers: int,
        intermediate_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()
        self.net = nn.Sequential(
            *[
                ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        # (B, T, C) -> (B, C, T) for the conv stack, then back.
        return self.net(x.transpose(1, 2)).transpose(1, 2)


def convnext_factory(dim, n_layers, is_causal=False):
    return ConvNextNet(dim, n_layers, is_causal=is_causal)


def convnextformer_factory(
    dim, n_layers, n_convnext_per_transformer_block, is_causal=False
):
    # Interleave runs of ConvNeXt blocks with single transformer blocks.
    layers = []
    for _ in range(0, n_layers, n_convnext_per_transformer_block + 1):
        layers.append(
            ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
        )
        layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
    return nn.Sequential(*layers)


class AutoEncoder(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_dim: int,
        num_layers: int,
        net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
        out_dim: int | None = None,
        convnextformer_num_conv_per_transformer: int = 3,
        causal_transformer: bool = False,
        bottleneck_size: int | None = None,
        vae: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.embed = nn.Linear(feat_dim, hidden_dim)
        if out_dim is None:
            out_dim = feat_dim
        self.unembed = nn.Linear(hidden_dim, out_dim)
        if net_factory == "convnext":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
        elif net_factory == "convnextformer_decoder":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
        elif net_factory == "convnextformer":
            self.encoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
        else:
            raise ValueError(f"Unknown net_factory: {net_factory!r}")
        # With vae=True the bottleneck emits both mu and logvar, so it is
        # twice as wide.
        self.bottleneck = (
            nn.Linear(hidden_dim, bottleneck_size * (2 if vae else 1))
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.unbottleneck = (
            nn.Linear(bottleneck_size, hidden_dim)
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.vae = vae

    def reparameterize(
        self,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        deterministic: bool = False,
        drop_vae_rate: float = 0.0,
    ) -> torch.Tensor:
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        if drop_vae_rate > 0.0:
            # Zero the sampling noise for a random subset of batch items.
            to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
            eps = torch.randn_like(std)
            eps[to_drop] = 0.0
        elif deterministic:
            eps = torch.zeros_like(std)
        else:
            eps = torch.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        return kl.sum(dim=-1).mean()

    def forward(
        self, x: torch.Tensor, drop_vae_rate: float = 0.0
    ) -> tuple[torch.Tensor, dict]:
        # Encode
        x = self.embed(x)
        x = self.encoder_net(x)
        x = self.bottleneck(x)
        if self.vae:
            mu, logvar = x.chunk(2, dim=-1)
            loss = {
                "kl_div": self.kl_divergence(mu, logvar),
                "_mu_mean": mu.mean(),
                "_mu_std": mu.std(),
                "_logvar_mean": logvar.mean(),
                "_logvar_std": logvar.std(),
            }
            x = self.reparameterize(mu, logvar, drop_vae_rate=drop_vae_rate)
        else:
            loss = {}
        # Decode
        x = self.unbottleneck(x)
        x = self.decoder_net(x)
        x = self.unembed(x)
        return x, loss

    def encode(self, x: torch.Tensor, deterministic: bool = False):
        x = self.embed(x)
        x = self.encoder_net(x)
        x = self.bottleneck(x)
        if self.vae:
            x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
        return x

    def decode(self, latent: torch.Tensor):
        x = self.unbottleneck(latent)
        x = self.decoder_net(x)
        x = self.unembed(x)
        return x
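

# ----------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API).
# The shapes below are assumptions inferred from the transposes above:
# AutoEncoder takes (batch, time, feat_dim), AdaLNFlowPredictor takes
# (batch, feat_dim, time) plus a per-item timestep; hyperparameters arbitrary.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # VAE autoencoder round trip: (B, T, C) in, (B, T, C) out plus a loss dict.
    ae = AutoEncoder(
        feat_dim=80,
        hidden_dim=256,
        num_layers=4,
        net_factory="convnext",
        bottleneck_size=32,
        vae=True,
    )
    feats = torch.randn(2, 100, 80)
    recon, losses = ae(feats)
    assert recon.shape == feats.shape
    print("AutoEncoder out:", recon.shape, {k: float(v) for k, v in losses.items()})

    # Deterministic encoding returns the posterior mean (eps = 0).
    latent = ae.encode(feats, deterministic=True)  # (B, T, bottleneck_size)

    # Flow predictor over the latent space: channel-first inputs, t in [0, 1].
    flow = AdaLNFlowPredictor(feat_dim=32, dim=256, n_layer=4, layer_factory="mlp")
    x_t = torch.randn(2, 32, 100)
    t = torch.rand(2)
    v = flow(x_t, latent.transpose(1, 2), t)
    print("AdaLNFlowPredictor out:", v.shape)  # (B, feat_dim, T)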