from typing import Literal
import torch
import torch.nn.functional as F
from torch import nn
from zcodec.models.components.transformer import TransformerBlock
class AdaLayerNormScale(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear = nn.Linear(dim, dim * 3)
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
def forward(self, x, c):
x = self.norm(x)
scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
        # Broadcast the (B, C) modulation tensors over any intermediate dims of x.
        shape = [x.shape[0]] + [1] * (x.dim() - 2) + [x.shape[-1]]
        scale, bias, gate = map(lambda t: t.view(*shape), (scale, bias, gate))
x = x * (1 + scale) + bias
return x, gate
class GaussianFourierTimeEmbedding(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x[:, None] * self.weight[None, :] * 2 * torch.pi
x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
return x
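# Illustrative usage of the time embedding (shapes follow from forward() above;
# the numbers are made up, not taken from any config in this repo):
#
#   emb = GaussianFourierTimeEmbedding(dim=64)
#   t = torch.rand(8)   # one flow time per batch element
#   feats = emb(t)      # -> (8, 128): sin features followed by cos features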
LAYER_FACTORIES = {}
def register_flow_layer_factory(name):
def decorator(fn):
LAYER_FACTORIES[name] = fn
return fn
return decorator
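# The registry maps string keys to layer constructors so AdaLNFlowPredictor can be
# configured by name. A hypothetical extension (identifiers are illustrative only):
#
#   @register_flow_layer_factory("my_block")
#   def MyBlockFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
#       return MyBlock(dim)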
@register_flow_layer_factory("convnext")
def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)
@register_flow_layer_factory("mlp")
def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
return AdaLNMLP(dim)
@register_flow_layer_factory("sa_transformer")
def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)
def init_weights(m: nn.Module):
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def init_adaln_weights(m: nn.Module):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.zeros_(m.bias)
def modulate(x, scale, shift):
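    # scale and shift are (B, C); unsqueeze a time axis so they broadcast over (B, T, C).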
return x * (1 + scale[:, None]) + shift[:, None]
class AdaLNFlowPredictor(nn.Module):
def __init__(
self,
feat_dim: int,
dim: int,
n_layer: int,
layer_factory: str,
cond_dim: int | None = None,
is_causal: bool = False,
):
super().__init__()
layer_factory = LAYER_FACTORIES[layer_factory]
self.layers = nn.ModuleList(
[
layer_factory(dim, i, n_layer, is_causal=is_causal)
for i in range(n_layer)
]
)
if cond_dim is None:
cond_dim = feat_dim
self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
self.final_adaln_proj = nn.Linear(dim, dim * 2)
self.out_proj = nn.Linear(dim, feat_dim)
self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
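        # dim // 2 Fourier frequencies yield a dim-sized (sin, cos) time embedding,
        # matching the width expected by the AdaLN projections.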
self.time_emb = GaussianFourierTimeEmbedding(dim // 2)
self.apply(init_weights)
for l in self.adaln_proj:
init_adaln_weights(l)
init_adaln_weights(self.final_adaln_proj)
def forward(
self,
x_t: torch.Tensor,
x_mu: torch.Tensor,
t: torch.Tensor,
):
x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)
t_emb = self.time_emb(t)
for i, (l, adaln) in enumerate(zip(self.layers, self.adaln_proj)):
scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
x = l(x, scale_shift=(scale, shift), gate=gate)
scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
x = self.final_norm(x.transpose(1, 2))
x = modulate(x, scale, shift)
x = self.out_proj(x).transpose(1, 2)
return x
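# A minimal sketch of driving the predictor (hyperparameters are illustrative,
# not taken from any training config in this repo):
#
#   model = AdaLNFlowPredictor(feat_dim=80, dim=256, n_layer=4, layer_factory="convnext")
#   x_t = torch.randn(2, 80, 100)   # noisy features at time t, (B, C, T)
#   x_mu = torch.randn(2, 80, 100)  # conditioning features, same layout
#   t = torch.rand(2)               # one flow time per batch element
#   v = model(x_t, x_mu, t)         # predicted vector field, (2, 80, 100)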
class AdaLNMLP(nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.hidden_dim = hidden_dim
self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, bias=True),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
)
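        # Note: this modulation head is defined but never used in forward();
        # scale_shift and gate are supplied externally by AdaLNFlowPredictor.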
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
)
def forward(self, x, scale_shift, gate):
x = x.transpose(-1, -2)
h = modulate(self.in_ln(x), *scale_shift)
h = self.mlp(h)
return (x + gate[:, None] * h).transpose(-1, -2)
class ConvNeXtBlock(nn.Module):
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
Args:
dim (int): Number of input channels.
intermediate_dim (int): Dimensionality of the intermediate layer.
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
Defaults to None.
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
None means non-conditional LayerNorm. Defaults to None.
"""
def __init__(
self,
dim: int,
intermediate_dim: int | None = None,
layer_scale_init_value: float = 0.0,
elementwise_affine_ln: bool = True,
is_causal: bool = False,
):
super().__init__()
intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
) # depthwise conv
self.norm = nn.LayerNorm(
dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
)
self.pwconv1 = nn.Linear(
dim, intermediate_dim
) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(intermediate_dim, dim)
self.gamma = (
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
if layer_scale_init_value > 0
else None
)
self.is_causal = is_causal
def forward(
self,
x: torch.Tensor,
scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
gate: torch.Tensor | None = None,
) -> torch.Tensor:
residual = x
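        # Causal mode pads 6 frames on the left so the size-7 kernel never looks ahead.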
if self.is_causal:
x = torch.nn.functional.pad(x, (6, 0))
x = self.dwconv(x)
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
x = self.norm(x)
if scale_shift is not None:
scale, shift = scale_shift
x = x * scale[:, None] + shift[:, None]
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
if gate is not None:
x = gate[:, None] * x
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
x = residual + x
return x
class ConvNextNet(nn.Module):
def __init__(
self,
dim: int,
n_layers: int,
intermediate_dim: int | None = None,
is_causal: bool = False,
):
super().__init__()
self.net = nn.Sequential(
*[
ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
for _ in range(n_layers)
]
)
def forward(self, x):
return self.net(x.transpose(1, 2)).transpose(1, 2)
def convnext_factory(dim, n_layers, is_causal=False):
return ConvNextNet(dim, n_layers, is_causal=is_causal)
def convnextformer_factory(
dim, n_layers, n_convnext_per_transformer_block, is_causal=False
):
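    # Interleave runs of n_convnext_per_transformer_block ConvNeXt layers with a
    # single self-attention block; stepping the loop by (n + 1) keeps the stack
    # roughly n_layers modules deep.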
layers = []
for i in range(0, n_layers, n_convnext_per_transformer_block + 1):
layers.append(
ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
)
layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
return nn.Sequential(*layers)
class AutoEncoder(nn.Module):
def __init__(
self,
feat_dim: int,
hidden_dim: int,
num_layers: int,
net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
out_dim: int | None = None,
convnextformer_num_conv_per_transformer: int = 3,
causal_transformer: bool = False,
bottleneck_size: int | None = None,
vae: bool = False,
is_causal: bool = False,
):
super().__init__()
self.embed = nn.Linear(feat_dim, hidden_dim)
if out_dim is None:
out_dim = feat_dim
self.unembed = nn.Linear(hidden_dim, out_dim)
if net_factory == "convnext":
self.encoder_net = convnext_factory(
hidden_dim, num_layers, is_causal=is_causal
)
self.decoder_net = convnext_factory(
hidden_dim, num_layers, is_causal=is_causal
)
elif net_factory == "convnextformer_decoder":
self.encoder_net = convnext_factory(
hidden_dim, num_layers, is_causal=is_causal
)
self.decoder_net = convnextformer_factory(
hidden_dim,
num_layers,
convnextformer_num_conv_per_transformer,
is_causal=is_causal,
)
elif net_factory == "convnextformer":
self.encoder_net = convnextformer_factory(
hidden_dim,
num_layers,
convnextformer_num_conv_per_transformer,
is_causal=is_causal,
)
self.decoder_net = convnextformer_factory(
hidden_dim,
num_layers,
convnextformer_num_conv_per_transformer,
is_causal=is_causal,
            )
        else:
            raise ValueError(f"Unknown net_factory: {net_factory!r}")
        # With vae=True the bottleneck predicts both mu and logvar, hence twice the width.
        self.bottleneck = (
            nn.Linear(hidden_dim, bottleneck_size * (2 if vae else 1))
            if bottleneck_size is not None
            else nn.Identity()
        )
self.unbottleneck = (
nn.Linear(bottleneck_size, hidden_dim)
if bottleneck_size is not None
else nn.Identity()
)
self.vae = vae
def reparameterize(
self,
mu: torch.Tensor,
logvar: torch.Tensor,
deterministic: bool = False,
drop_vae_rate: float = 0.0,
) -> torch.Tensor:
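        # drop_vae_rate zeroes the sampled noise for a random fraction of batch
        # items, so those items pass through the posterior mean deterministically.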
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)
if drop_vae_rate > 0.0:
to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
eps = torch.randn_like(std)
eps[to_drop] = 0.0
else:
if deterministic:
eps = torch.zeros_like(std)
else:
eps = torch.randn_like(std)
return mu + eps * std
def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
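        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over the feature
        # dimension and averaged over the remaining axes.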
kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
return kl.sum(dim=-1).mean()
    def forward(
        self, x: torch.Tensor, drop_vae_rate: float = 0.0
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
# Encode
x = self.embed(x)
x = self.encoder_net(x)
x = self.bottleneck(x)
if self.vae:
mu, logvar = x.chunk(2, dim=-1)
loss = {
"kl_div": self.kl_divergence(mu, logvar),
"_mu_mean": mu.mean(),
"_mu_std": mu.std(),
"_logvar_mean": logvar.mean(),
"_logvar_std": logvar.std(),
}
x = self.reparameterize(
mu,
logvar,
drop_vae_rate=drop_vae_rate,
)
else:
loss = {}
# Decode
x = self.unbottleneck(x)
x = self.decoder_net(x)
x = self.unembed(x)
return x, loss
def encode(self, x: torch.Tensor, deterministic: bool = False):
x = self.embed(x)
x = self.encoder_net(x)
x = self.bottleneck(x)
if self.vae:
x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
return x
    def decode(self, latent: torch.Tensor):
x = self.unbottleneck(latent)
x = self.decoder_net(x)
x = self.unembed(x)
return x
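# A minimal end-to-end sketch (hyperparameters are illustrative, not from this
# file): round-trip a batch of (B, T, feat_dim) features through the VAE bottleneck.
if __name__ == "__main__":
    ae = AutoEncoder(
        feat_dim=80,
        hidden_dim=256,
        num_layers=4,
        net_factory="convnext",
        bottleneck_size=16,
        vae=True,
    )
    feats = torch.randn(2, 100, 80)                # (B, T, feat_dim)
    recon, losses = ae(feats)                      # recon: (2, 100, 80)
    latent = ae.encode(feats, deterministic=True)  # (2, 100, 16)
    print(recon.shape, latent.shape, losses["kl_div"].item())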