diff --git a/codec/__init__.py b/codec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2285f83140d9290b9ad285eabb89efa191432cc --- /dev/null +++ b/codec/__init__.py @@ -0,0 +1,2 @@ +from .train_patchvae import TrainPatchVAE +from .train_wavvae import TrainWavVAE diff --git a/codec/__pycache__/__init__.cpython-312.pyc b/codec/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d164b12040190789bd63d3cdca5def40be95972 Binary files /dev/null and b/codec/__pycache__/__init__.cpython-312.pyc differ diff --git a/codec/__pycache__/train_patchvae.cpython-312.pyc b/codec/__pycache__/train_patchvae.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b8daeb0fa80c52c7dae20156695344c3552fce Binary files /dev/null and b/codec/__pycache__/train_patchvae.cpython-312.pyc differ diff --git a/codec/__pycache__/train_wavvae.cpython-312.pyc b/codec/__pycache__/train_wavvae.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39cc6c30402f689090304bc838f9c7d1cc2ea3d2 Binary files /dev/null and b/codec/__pycache__/train_wavvae.cpython-312.pyc differ diff --git a/codec/__pycache__/train_zflowae.cpython-312.pyc b/codec/__pycache__/train_zflowae.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e456645ae126417fcb45db517aea7ba12413502 Binary files /dev/null and b/codec/__pycache__/train_zflowae.cpython-312.pyc differ diff --git a/codec/datamodules.py b/codec/datamodules.py new file mode 100755 index 0000000000000000000000000000000000000000..11bb24ea7964b834d4ae19cf774529bb822933ee --- /dev/null +++ b/codec/datamodules.py @@ -0,0 +1,249 @@ +import itertools +import random +import time +from dataclasses import dataclass +from functools import partial +from pathlib import Path + +import numpy as np +import pytorch_lightning as ptl +import torch +import torchaudio +from safetensors.torch import safe_open +from sklearn.model_selection import train_test_split +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Dataset + +from datasets import load_dataset, load_from_disk + + +@dataclass +class WavVAEDataConfig: + filelist_path: str + sampling_rate: int + num_samples: int + batch_size: int + num_workers: int + + +class WavVAEDataModule(ptl.LightningDataModule): + def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig): + super().__init__() + self.train_config = train_params + self.val_config = val_params + + def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool): + dataset = WavVAEDataset(cfg, train=train) + dataloader = DataLoader( + dataset, + batch_size=cfg.batch_size, + num_workers=cfg.num_workers, + shuffle=train, + pin_memory=True, + ) + return dataloader + + def train_dataloader(self) -> DataLoader: + return self._get_dataloder(self.train_config, train=True) + + def val_dataloader(self) -> DataLoader: + return self._get_dataloder(self.val_config, train=False) + + +class WavVAEDataset(Dataset): + def __init__(self, cfg: WavVAEDataConfig, train: bool): + with open(cfg.filelist_path) as f: + self.filelist = f.read().splitlines() + self.sampling_rate = cfg.sampling_rate + self.num_samples = cfg.num_samples + self.train = train + + def __len__(self) -> int: + return len(self.filelist) + + def __getitem__(self, index: int) -> torch.Tensor: + audio_path = self.filelist[index] + y, sr = torchaudio.load(audio_path) + if y.size(0) > 1: + # mix to mono + y = y.mean(dim=0, 
keepdim=True) + gain = np.random.uniform(-1, -6) if self.train else -3 + y, _ = torchaudio.sox_effects.apply_effects_tensor( + y, sr, [["norm", f"{gain:.2f}"]] + ) + if sr != self.sampling_rate: + y = torchaudio.functional.resample( + y, orig_freq=sr, new_freq=self.sampling_rate + ) + if y.size(-1) < self.num_samples: + pad_length = self.num_samples - y.size(-1) + padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) + y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) + elif self.train: + start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) + y = y[:, start : start + self.num_samples] + else: + # During validation, take always the first segment for determinism + y = y[:, : self.num_samples] + + return y[0] + + +def pad_tensor_list_raw( + tensor_list: list[tuple[torch.Tensor, torch.Tensor]], pad_idx: int = 0 +) -> dict[str, torch.Tensor | None]: + audio, hubert_maybe = zip(*tensor_list) + audio = torch.cat(audio, dim=0) + if hubert_maybe[0] is not None: + hubert_maybe = torch.stack(hubert_maybe, dim=0) + else: + hubert_maybe = None + return {"audio_z": audio, "hubert": hubert_maybe} + + +class SafeTensorDataset(Dataset): + """ + On __getitem__, opens the safetensor, uses get_slice() to inspect shape, + then either drops too-short files (return None) or returns a random subsequence slice. + """ + + def __init__( + self, + file_paths: list[str], + key: str, + hubert_path: str | None = None, + hubert_key: str = "layer_9", + min_length: int = 1, + subseq_length: int | None = None, + ): + self.file_paths = file_paths + self.key = key + self.min_length = min_length + self.subseq_length = subseq_length + self.hubert_path = hubert_path + self.hubert_key = hubert_key + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx: int) -> torch.Tensor | None: + path = self.file_paths[idx] + # open file, get a slice wrapper for full tensor + with safe_open(path, framework="pt") as f: + tensor_slice = f.get_slice(self.key) + Q, N, D = tensor_slice.get_shape() # full shape [K, N] + + # drop too-short + if N < self.min_length: + return None + + L = self.subseq_length or N + if L < N: + # sample random start + start = torch.randint(0, max(1, N - L - 1), ()).item() + start -= start % 2 + # this yields a torch.Tensor of shape [K, L] + seq = tensor_slice[:, start : start + L] + else: + # full length + start = 0 + seq = tensor_slice[:, :] + + if self.hubert_path is not None: + path = Path(self.hubert_path) / Path(path).name + with safe_open(path, framework="pt") as f: + tensor_slice = f.get_slice(self.hubert_key) + hubert_N, hubert_D = tensor_slice.get_shape() # full shape [K, N] + seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2] + return (seq, seq_hubert) + + return (seq, None) + + +class SafeTensorDataModule(ptl.LightningDataModule): + """ + LightningDataModule using raw .safetensors file list + get_slice inside Dataset. 
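+
+    Example (a minimal usage sketch; the file list path and hyperparameters below are
+    illustrative, not taken from this repository):
+
+        dm = SafeTensorDataModule(
+            train_file_list="train_filelist.txt",  # one .safetensors path per line
+            key="audio_z",
+            batch_size=16,
+            subseq_length=256,
+        )
+        dm.setup()
+        batch = next(iter(dm.train_dataloader()))
+        # batch is a dict: {"audio_z": Tensor, "hubert": Tensor or None}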
+ """ + + def __init__( + self, + train_file_list: str, + val_file_list: str | None = None, + hubert_path: str | None = None, + key: str = "audio_z", + hubert_key: str = "layer_9", + val_split: float = 0.1, + batch_size: int = 32, + num_workers: int = 4, + shuffle: bool = True, + seed: int = 1234, + min_length: int = 1, + subseq_length: int | None = None, + ): + super().__init__() + self.train_file_list = train_file_list + self.val_file_list = val_file_list + self.hubert_path = hubert_path + self.key = key + self.val_split = val_split + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.seed = seed + self.min_length = min_length + self.subseq_length = subseq_length + + def setup(self, stage=None): + with open(self.train_file_list, "r") as f: + train_paths = [line.strip() for line in f if line.strip()] + val_paths = None + if self.val_file_list is not None: + with open(self.train_file_list, "r") as f: + val_paths = [line.strip() for line in f if line.strip()] + # Split into train/val + if ( + isinstance(self.val_split, float) + and 0 < self.val_split < 1 + and val_paths is None + ): + train_paths, val_paths = train_test_split( + train_paths, test_size=self.val_split, random_state=self.seed + ) + + self.train_ds = SafeTensorDataset( + train_paths, + key=self.key, + min_length=self.min_length, + subseq_length=self.subseq_length, + hubert_path=self.hubert_path, + ) + self.val_ds = SafeTensorDataset( + val_paths, + key=self.key, + min_length=self.min_length, + subseq_length=self.subseq_length, + ) + + def _collate_fn( + self, batch: list[torch.Tensor | None] + ) -> tuple[torch.Tensor, torch.BoolTensor]: + seqs = [s for s in batch if s is not None] + return pad_tensor_list_raw(seqs, pad_idx=0) + + def train_dataloader(self): + return DataLoader( + self.train_ds, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + collate_fn=self._collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_ds, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + collate_fn=self._collate_fn, + ) diff --git a/codec/models/__init__.py b/codec/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a92b4b9a8a5f448b2d6f900eb72af1fef16b4308 --- /dev/null +++ b/codec/models/__init__.py @@ -0,0 +1,2 @@ +from .patchvae.model import PatchVAE, PatchVAEConfig +from .wavvae.model import WavVAE, WavVAEConfig diff --git a/codec/models/__pycache__/__init__.cpython-312.pyc b/codec/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f53ceff8a05990134106404052316c1e5731ec17 Binary files /dev/null and b/codec/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/codec/models/components/__init__.py b/codec/models/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/codec/models/components/__pycache__/__init__.cpython-312.pyc b/codec/models/components/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..909df039a1ac1c806d329d13b98c416ff9f32e17 Binary files /dev/null and b/codec/models/components/__pycache__/__init__.cpython-312.pyc differ diff --git a/codec/models/components/__pycache__/convnext.cpython-312.pyc b/codec/models/components/__pycache__/convnext.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..18c6696c356c110c642db2815bb0049c58bf83f3 Binary files /dev/null and b/codec/models/components/__pycache__/convnext.cpython-312.pyc differ diff --git a/codec/models/components/convnext.py b/codec/models/components/convnext.py new file mode 100755 index 0000000000000000000000000000000000000000..8927a3c9f22b2ae22b45b78c41a81b8bb03fea22 --- /dev/null +++ b/codec/models/components/convnext.py @@ -0,0 +1,221 @@ +import torch +from torch import nn + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int | None = None, + layer_scale_init_value: float = 0.0, + elementwise_affine_ln: bool = True, + is_causal: bool = False, + ): + super().__init__() + intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim + ) # depthwise conv + self.norm = nn.LayerNorm( + dim, eps=1e-6, elementwise_affine=elementwise_affine_ln + ) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.is_causal = is_causal + + def forward( + self, + x: torch.Tensor, + scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None, + gate: torch.Tensor | None = None, + ) -> torch.Tensor: + residual = x + if self.is_causal: + x = torch.nn.functional.pad(x, (6, 0)) + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + x = self.norm(x) + if scale_shift is not None: + scale, shift = scale_shift + x = x * scale[:, None] + shift[:, None] + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + if gate is not None: + x = gate[:, None] * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class ConvNextNet(nn.Module): + def __init__(self, n_layers, dim, intermediate_dim: int | None = None): + super().__init__() + self.net = nn.Sequential( + *[ + ConvNeXtBlock( + dim, + intermediate_dim, + ) + for _ in range(n_layers) + ] + ) + + def forward(self, x): + return self.net(x) + + +class ConvNextPatchEncoder(nn.Module): + def __init__( + self, + patch_sizes: list[int], + n_layers_per_patch: int, + patch_expansion_factor: float = 1.5, + is_decoder: bool = False, + ): + super().__init__() + patch_to_dim = [] + convnext = [] + for i, patch_size in enumerate(patch_sizes): + in_dim = int((patch_expansion_factor if i > 0 else 1.0) * patch_size) + out_dim = int(patch_expansion_factor * patch_size) + if is_decoder: + in_dim, out_dim = out_dim, in_dim + patch_to_dim.append( + nn.Linear( + in_dim, + out_dim, + ) + ) + convnext += [ + nn.Sequential( + *[ + ConvNeXtBlock(int(patch_size * patch_expansion_factor)) + for _ in range(n_layers_per_patch) + ] + ) + ] + self.is_decoder = is_decoder + 
self.patch_sizes = patch_sizes + self.patch_expansion_factor = patch_expansion_factor + self.patch_to_dim = nn.ModuleList(patch_to_dim) + self.convnext = nn.ModuleList(convnext) + + def forward(self, x): + if self.is_decoder: + for i, patch_size in reversed(list(enumerate(self.patch_sizes))): + B, P, N = x.shape + patch_expansion_factor_maybe = ( + self.patch_expansion_factor if i > 0 else 1.0 + ) + x = x.reshape(B, int(patch_size * self.patch_expansion_factor), -1) + x = self.convnext[i](x) + x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2) + else: + for i, patch_size in enumerate(self.patch_sizes): + B, P, N = x.shape + patch_expansion_factor_maybe = ( + self.patch_expansion_factor if i > 0 else 1.0 + ) + x = x.reshape(B, int(patch_size * patch_expansion_factor_maybe), -1) + x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2) + x = self.convnext[i](x) + return x + + +class ConvNextEncoder(nn.Module): + def __init__( + self, + in_dim: int, + dim: int, + n_layers: int, + intermediate_dim: int | None = None, + stride: int = 1, + ): + super().__init__() + self.in_proj = nn.Linear(in_dim, dim) + if stride > 1: + self.stride = nn.Conv1d( + in_channels=dim, + out_channels=dim, + kernel_size=(stride * 2) + 1, + stride=stride, + padding=stride // 2, + ) + else: + self.stride = nn.Identity() + self.net = ConvNextNet(n_layers, dim, intermediate_dim) + + def forward(self, x): + x = self.in_proj(x.transpose(1, 2)).transpose(1, 2) + x = self.stride(x) + return self.net(x) + + +class ConvNextDecoder(nn.Module): + def __init__( + self, + out_dim: int, + dim: int, + n_layers: int, + intermediate_dim: int | None = None, + stride: int = 1, + stride_position: str = "before", + ): + super().__init__() + self.out_proj = nn.Linear(dim, out_dim) + if stride > 1: + self.stride = nn.ConvTranspose1d( + in_channels=dim, + out_channels=dim, + kernel_size=(stride * 2) + 1, + stride=stride, + padding=stride // 2, + output_padding=stride // 2, + ) + else: + self.stride = nn.Identity() + self.stride_position = stride_position + + self.net = ConvNextNet(n_layers, dim, intermediate_dim) + + def forward(self, x): + if self.stride_position == "before": + x = self.stride(x) + x = self.net(x) + if self.stride_position == "after": + x = self.stride(x) + return self.out_proj(x.transpose(1, 2)).transpose(1, 2) + + +class SwiGLU(nn.Module): + def __init__(self, d_model: int, ffn_expansion_factor: int = 4): + super().__init__() + self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2) + self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model) + + def forward(self, x): + gate, x = self.p_in(x).chunk(2, dim=-1) + return self.p_out(nn.functional.silu(gate) * x) diff --git a/codec/models/components/transformer.py b/codec/models/components/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..23b8fee796c20881ea2454c80c23a45dd45c54b6 --- /dev/null +++ b/codec/models/components/transformer.py @@ -0,0 +1,224 @@ +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb + + +class LocalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + heads: int, + window_len: int = 32, + rotary: bool = True, + is_causal: bool = False, + ): + super().__init__() + self.heads = heads + assert dim % heads == 0, "dim must be divisible by heads" + self.qkv = nn.Linear(dim, 3 * dim) + self.o = nn.Linear(dim, dim) + 
self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None + self.is_causal = is_causal + self.window_len = window_len + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + cache: Optional[Dict[int, torch.Tensor]] = None, + layer_idx: Optional[int] = None, + time_step: int = 0, + ) -> torch.Tensor: + # x: (batch, seq_len, dim) + b, n, dim = x.shape + b, t_len, hd = x.shape + pad_len = (self.window_len - t_len % self.window_len) % self.window_len + padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim + mask = torch.ones(t_len, dtype=torch.bool, device=x.device) + mask = torch.nn.functional.pad( + mask, (0, pad_len), value=False + ) # False = masked + mask = mask.expand(b, -1) # [b, padded_len] + mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=self.window_len) + qkv = self.qkv(padded_x).chunk(3, dim=-1) + q, k, v = [ + rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=self.window_len) + for t in qkv + ] + if cache is not None: + assert layer_idx is not None, "layer_idx must be set when using cache" + cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2) + cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2) + k, v = cache[layer_idx]["k"], cache[layer_idx]["v"] + + # apply rotary embeddings + if self.rotary is not None: + if pos is not None: + rot = self.rotary(pos) # (b,1,n,head_dim) + q = apply_rotary_emb(rot, q) + k = apply_rotary_emb(rot, k) + else: + q = self.rotary.rotate_queries_or_keys(q, offset=time_step) + k = self.rotary.rotate_queries_or_keys(k, offset=time_step) + + # scaled dot-product attention + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None if self.is_causal else mask, + is_causal=self.is_causal, + ) + y = rearrange(y, "b n h w d -> b (w n) (h d)") + y = self.o(y) + y = y[:, :t_len] + return y + + +class SelfAttention(nn.Module): + def __init__( + self, dim: int, heads: int, rotary: bool = True, is_causal: bool = False + ): + super().__init__() + self.heads = heads + assert dim % heads == 0, "dim must be divisible by heads" + self.qkv = nn.Linear(dim, 3 * dim) + self.o = nn.Linear(dim, dim) + self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None + self.is_causal = is_causal + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + cache: Optional[Dict[int, torch.Tensor]] = None, + layer_idx: Optional[int] = None, + time_step: int = 0, + ) -> torch.Tensor: + # x: (batch, seq_len, dim) + b, n, dim = x.shape + b, t_len, hd = x.shape + pad_len = (32 - t_len % 32) % 32 + padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim + mask = torch.ones(t_len, dtype=torch.bool, device=x.device) + mask = torch.nn.functional.pad( + mask, (0, pad_len), value=False + ) # False = masked + mask = mask.expand(b, -1) # [b, padded_len] + mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=32) + qkv = self.qkv(padded_x).chunk(3, dim=-1) + q, k, v = [ + rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=32) for t in qkv + ] + # caching for fast autoregressive + if cache is not None: + assert layer_idx is not None, "layer_idx must be set when using cache" + # append new keys/values + cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2) + cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2) + k, v = cache[layer_idx]["k"], cache[layer_idx]["v"] + + # apply rotary embeddings + if self.rotary is not None: 
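+            # Rotate queries/keys either from explicit positions (pos) or by a
+            # sequence offset (time_step), e.g. when decoding step by step with the KV cache.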
+ if pos is not None: + rot = self.rotary(pos) # .unsqueeze(1) # (b,1,n,head_dim) + q = apply_rotary_emb(rot, q) + k = apply_rotary_emb(rot, k) + else: + q = self.rotary.rotate_queries_or_keys(q, offset=time_step) + k = self.rotary.rotate_queries_or_keys(k, offset=time_step) + + # scaled dot-product attention + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None if self.is_causal else mask, + is_causal=self.is_causal, + ) + y = rearrange(y, "b n h w d -> b (w n) (h d)") + y = self.o(y) + y = y[:, :t_len] + return y + + +class SwiGLU(nn.Module): + def __init__(self, d_model: int): + super().__init__() + hidden = d_model * 4 // 3 + self.p_in = nn.Linear(d_model, hidden * 2) + self.p_out = nn.Linear(hidden, d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate, data = self.p_in(x).chunk(2, dim=-1) + return self.p_out(F.silu(gate) * data) + + +class TransformerBlock(nn.Module): + """ + Transformer block using custom SelfAttention and SwiGLU FFN. + + Args: + dim: embedding dimension + heads: number of attention heads + rotary: whether to use rotary embeddings + is_causal: whether to apply causal masking + """ + + def __init__( + self, + dim: int, + head_size: int, + rotary: bool = True, + is_causal: bool = False, + elementwise_affine_ln: bool = True, + ): + super().__init__() + assert dim % head_size == 0 + heads = dim // head_size + self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln) + self.attn = LocalSelfAttention(dim, heads, rotary=rotary, is_causal=is_causal) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln) + self.ffn = SwiGLU(dim) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + cache: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, + layer_idx: Optional[int] = None, + scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None, + gate: torch.Tensor = None, + time_step: int = 0, + ) -> torch.Tensor: + # Self-attention block + norm1_x = self.norm1(x) + if scale_shift is not None: + scale, shift = scale_shift + norm1_x = norm1_x * scale[:, None] + shift[:, None] + + attn_out = self.attn( + norm1_x, + mask=mask, + pos=pos, + cache=cache, + layer_idx=layer_idx, + time_step=time_step, + ) + x = x + attn_out + + norm2_x = self.norm2(x) + if gate is not None: + norm2_x = gate[:, None] * norm2_x + + # Feedforward block + ffn_out = self.ffn(norm2_x) + x = x + ffn_out + return x diff --git a/codec/models/pardi_tokenizer.py b/codec/models/pardi_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e72b9d344176b4f73eaed42ef1e35c7653c3b89a --- /dev/null +++ b/codec/models/pardi_tokenizer.py @@ -0,0 +1,10 @@ +import torch + +from zcodec.models import WavVAE, ZFlowAutoEncoder +from zcodec.models.wavvae.model import WavVAEConfig +from zcodec.models.zflowae.model import ZFlowAutoEncoderConfig + + +class PardiTokenizer(nn.Module): + def __init__(self, wavvae_cfg: WavVAEConfig, zflowae_cfg: ZFlowAutoEncoderConfig): + diff --git a/codec/models/patchvae/__pycache__/model.cpython-312.pyc b/codec/models/patchvae/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65688e4a2e0456dada8c1ba35f91cf9d5e0719dd Binary files /dev/null and b/codec/models/patchvae/__pycache__/model.cpython-312.pyc differ diff --git a/codec/models/patchvae/__pycache__/modules.cpython-312.pyc b/codec/models/patchvae/__pycache__/modules.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d9e0f456209503ae551466c3d2791be1c461c131 Binary files /dev/null and b/codec/models/patchvae/__pycache__/modules.cpython-312.pyc differ diff --git a/codec/models/patchvae/model.py b/codec/models/patchvae/model.py new file mode 100755 index 0000000000000000000000000000000000000000..a67f863fe76e8a398de9eaed745fdac757d590d6 --- /dev/null +++ b/codec/models/patchvae/model.py @@ -0,0 +1,262 @@ +import json +import math +import os +import sys +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path + +import torch +from safetensors.torch import load_file +from torch import nn +from torchdyn.core import NeuralODE + +from .modules import AdaLNFlowPredictor, AutoEncoder + + +@contextmanager +def suppress_stdout(): + original_stdout = sys.stdout + try: + sys.stdout = open(os.devnull, "w") + yield + finally: + sys.stdout.close() + sys.stdout = original_stdout + + +def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr): + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr + + return lr_lambda + + +@dataclass +class PatchVAEConfig: + latent_dim: int + hidden_dim: int + latent_scaling: tuple[list[float], list[float]] | None + flow_factory: str + num_flow_layers: int + autoencoder_factory: str + num_autoencoder_layers: int + convnextformer_num_conv_per_transformer: int = 3 + wavvae_path: str | None = None + fsq_levels: list[int] | None = None + bottleneck_size: int | None = None + latent_stride: int = 2 + vae: bool = False + causal_transformer: bool = False + cond_dim: int | None = None + is_causal: bool = False + + +class PatchVAE(nn.Module): + def __init__(self, cfg: PatchVAEConfig): + super().__init__() + self.flow_net = AdaLNFlowPredictor( + feat_dim=cfg.latent_dim * cfg.latent_stride, + dim=cfg.hidden_dim, + n_layer=cfg.num_flow_layers, + layer_factory=cfg.flow_factory, + cond_dim=cfg.cond_dim, + is_causal=cfg.is_causal, + ) + self.autoencoder = AutoEncoder( + cfg.latent_dim * cfg.latent_stride, + cfg.hidden_dim, + cfg.num_autoencoder_layers, + cfg.autoencoder_factory, + out_dim=cfg.cond_dim, + vae=cfg.vae, + bottleneck_size=cfg.bottleneck_size, + convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer, + is_causal=cfg.is_causal, + ) + if cfg.latent_scaling is not None: + mean, std = cfg.latent_scaling + self.register_buffer("mean_latent_scaling", torch.tensor(mean)) + self.register_buffer("std_latent_scaling", torch.tensor(std)) + else: + self.mean_latent_scaling = None + self.std_latent_scaling = None + + self.latent_stride = cfg.latent_stride + self.latent_dim = cfg.latent_dim + self.wavvae = None + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + map_location: str = "cpu", + ): + if Path(pretrained_model_name_or_path).exists(): + path = pretrained_model_name_or_path + else: + from huggingface_hub import snapshot_download + + path = snapshot_download(pretrained_model_name_or_path) + + with open(Path(path) / "config.json", "r") as f: + config = json.load(f) + config = PatchVAEConfig(**config) + model = cls(config).to(map_location) + state_dict = load_file( + Path(path) / "model.st", + device=map_location, + ) + model.load_state_dict(state_dict, assign=True) + if config.wavvae_path is not None: + 
from .. import WavVAE + + model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location) + else: + model.wavvae = None + + return model + + def wavvae_from_pretrained( + self, + pretrained_model_name_or_path: str, + *args, + **kwargs, + ): + from .. import WavVAE + + self.wavvae = WavVAE.from_pretrained( + pretrained_model_name_or_path, + *args, + **kwargs, + ) + + def encode(self, wav: torch.Tensor): + assert self.wavvae is not None, ( + "please provide WavVAE model to encode from waveform" + ) + z = self.wavvae.encode(wav) + zz = self.encode_patch(z) + return zz + + def decode(self, patchvae_latent: torch.Tensor, **kwargs): + assert self.wavvae is not None, ( + "please provide WavVAE model to decode to waveform" + ) + z = self.decode_patch(patchvae_latent, **kwargs) + wav = self.wavvae.decode(z) + return wav + + def normalize_z(self, z: torch.Tensor): + if self.mean_latent_scaling is not None: + z = (z - self.mean_latent_scaling) / self.std_latent_scaling + return z + + def denormalize_z(self, z: torch.Tensor): + if self.std_latent_scaling is not None: + z = z * self.std_latent_scaling + self.mean_latent_scaling + return z + + def encode_patch(self, z: torch.Tensor, deterministic: bool = False): + B, T, D = z.shape + z = self.normalize_z(z) + if self.latent_stride > 1: + z = z[:, : T - T % self.latent_stride] + z = z.reshape(B, T // self.latent_stride, D * self.latent_stride) + return self.autoencoder.encode(z, deterministic=deterministic) + + def decode_patch( + self, + latent: torch.Tensor, + cfg: float = 2.0, + num_steps: int = 15, + solver: str = "euler", + sensitivity: str = "adjoint", + temperature: float = 1.0, + **kwargs, + ): + with torch.no_grad(): + z_cond = self.autoencoder.decode(latent).transpose(1, 2) + if cfg == 1.0: + + def solver_fn(t, Xt, *args, **kwargs): + flow = self.flow_net(Xt, z_cond, t.unsqueeze(0)) + return flow + else: + z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0) + + def solver_fn(t, Xt, *args, **kwargs): + flow = self.flow_net( + Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0) + ) + cond, uncond = flow.chunk(2, dim=0) + + return uncond + cfg * (cond - uncond) + + with suppress_stdout(): + node_ = NeuralODE( + solver_fn, + solver=solver, + sensitivity=sensitivity, + **kwargs, + ) + t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device) + patch_dim = self.latent_dim * self.latent_stride + x0 = torch.randn( + z_cond.shape[0], + patch_dim, + z_cond.shape[2], + device=z_cond.device, + ) + traj = node_.trajectory( + x0 * temperature, + t_span=t_span, + ) + + y_hat = traj[-1] + y_hat = y_hat.transpose(1, 2) + B, T, D = y_hat.shape + y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride) + y_hat = self.denormalize_z(y_hat) + return y_hat + + def forward( + self, + z: torch.Tensor, + t: torch.Tensor, + drop_cond_rate: float = 0.0, + drop_vae_rate: float = 0.0, + sigma: float = 1e-4, + ): + z = self.normalize_z(z) + B, T, D = z.shape + if self.latent_stride > 1: + z = z.reshape(B, T // self.latent_stride, D * self.latent_stride) + + prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate) + + if drop_cond_rate > 0.0: + to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate + prior[to_drop] = 0.0 + + x0 = torch.randn_like(z) + x1 = z + + flow_target = x1 - (1 - sigma) * x0 + + alpha = (1 - (1 - sigma) * t).view(-1, 1, 1) + xt = alpha * x0 + t.view(-1, 1, 1) * x1 + + pred = self.flow_net( + xt.transpose(1, 2), + prior.transpose(1, 2), + t, + ) + + flow_loss = 
nn.functional.mse_loss(flow_target.transpose(1, 2), pred) + + return flow_loss, ae_loss, prior diff --git a/codec/models/patchvae/modules.py b/codec/models/patchvae/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9726a1f97c4b7435af954706c340cb369b567f --- /dev/null +++ b/codec/models/patchvae/modules.py @@ -0,0 +1,396 @@ +from random import random +from typing import Literal + +import torch +import torch.nn.functional as F +from torch import nn +from vector_quantize_pytorch import FSQ + +from zcodec.models.components.transformer import TransformerBlock + + +class AdaLayerNormScale(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.linear = nn.Linear(dim, dim * 3) + self.norm = nn.LayerNorm(dim, elementwise_affine=False) + + def forward(self, x, c): + x = self.norm(x) + scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1) + shape = x.shape[0] + [1] * (x.dim() - 2) + x.shape[-1] + scale, bias, gate = map(lambda x: x.view(*shape), (scale, bias, gate)) + x = x * (1 + scale) + bias + return x, gate + + +class GaussianFourierTimeEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim), requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x[:, None] * self.weight[None, :] * 2 * torch.pi + x = torch.cat((torch.sin(x), torch.cos(x)), dim=1) + return x + + +LAYER_FACTORIES = {} + + +def register_flow_layer_factory(name): + def decorator(fn): + LAYER_FACTORIES[name] = fn + return fn + + return decorator + + +@register_flow_layer_factory("convnext") +def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False): + return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal) + + +@register_flow_layer_factory("mlp") +def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False): + return AdaLNMLP(dim) + + +@register_flow_layer_factory("sa_transformer") +def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False): + return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal) + + +def init_weights(m: nn.Module): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +def init_adaln_weights(m: nn.Module): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.zeros_(m.bias) + + +def modulate(x, scale, shift): + return x * (1 + scale[:, None]) + shift[:, None] + + +class AdaLNFlowPredictor(nn.Module): + def __init__( + self, + feat_dim: int, + dim: int, + n_layer: int, + layer_factory: str, + cond_dim: int | None = None, + is_causal: bool = False, + ): + super().__init__() + + layer_factory = LAYER_FACTORIES[layer_factory] + self.layers = nn.ModuleList( + [ + layer_factory(dim, i, n_layer, is_causal=is_causal) + for i in range(n_layer) + ] + ) + if cond_dim is None: + cond_dim = feat_dim + self.initial_proj = nn.Linear(feat_dim + cond_dim, dim) + self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers]) + self.final_adaln_proj = nn.Linear(dim, dim * 2) + self.out_proj = nn.Linear(dim, feat_dim) + self.final_norm = nn.LayerNorm(dim, elementwise_affine=False) + self.time_emb = GaussianFourierTimeEmbedding(dim // 2) + + self.apply(init_weights) + for l in self.adaln_proj: + init_adaln_weights(l) + init_adaln_weights(self.final_adaln_proj) + + def forward( + self, + x_t: torch.Tensor, + x_mu: torch.Tensor, + t: torch.Tensor, + ): + x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, 
x_mu)) + x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2) + + t_emb = self.time_emb(t) + + for i, (l, adaln) in enumerate(zip(self.layers, self.adaln_proj)): + scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1) + x = l(x, scale_shift=(scale, shift), gate=gate) + + scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1) + x = self.final_norm(x.transpose(1, 2)) + x = modulate(x, scale, shift) + + x = self.out_proj(x).transpose(1, 2) + + return x + + +class AdaLNMLP(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + + self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False) + self.mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim, bias=True), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True) + ) + + def forward(self, x, scale_shift, gate): + x = x.transpose(-1, -2) + h = modulate(self.in_ln(x), *scale_shift) + h = self.mlp(h) + return (x + gate[:, None] * h).transpose(-1, -2) + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int | None = None, + layer_scale_init_value: float = 0.0, + elementwise_affine_ln: bool = True, + is_causal: bool = False, + ): + super().__init__() + intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim + ) # depthwise conv + self.norm = nn.LayerNorm( + dim, eps=1e-6, elementwise_affine=elementwise_affine_ln + ) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.is_causal = is_causal + + def forward( + self, + x: torch.Tensor, + scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None, + gate: torch.Tensor | None = None, + ) -> torch.Tensor: + residual = x + if self.is_causal: + x = torch.nn.functional.pad(x, (6, 0)) + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + x = self.norm(x) + if scale_shift is not None: + scale, shift = scale_shift + x = x * scale[:, None] + shift[:, None] + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + if gate is not None: + x = gate[:, None] * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class ConvNextNet(nn.Module): + def __init__( + self, + dim: int, + n_layers: int, + intermediate_dim: int | None = None, + is_causal: bool = False, + ): + super().__init__() + self.net = nn.Sequential( + *[ + ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal) + for _ in range(n_layers) + ] + ) + + def forward(self, x): + return self.net(x.transpose(1, 2)).transpose(1, 2) + + 
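+# Factory helpers below build the encoder/decoder trunks used by AutoEncoder:
+# "convnextformer" interleaves one TransformerBlock after every
+# n_convnext_per_transformer_block ConvNeXt blocks.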
+def convnext_factory(dim, n_layers, is_causal=False): + return ConvNextNet(dim, n_layers, is_causal=is_causal) + + +def convnextformer_factory( + dim, n_layers, n_convnext_per_transformer_block, is_causal=False +): + layers = [] + for i in range(0, n_layers, n_convnext_per_transformer_block + 1): + layers.append( + ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal) + ) + layers.append(TransformerBlock(dim, 64, is_causal=is_causal)) + return nn.Sequential(*layers) + + +class AutoEncoder(nn.Module): + def __init__( + self, + feat_dim: int, + hidden_dim: int, + num_layers: int, + net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"], + out_dim: int | None = None, + convnextformer_num_conv_per_transformer: int = 3, + causal_transformer: bool = False, + bottleneck_size: int | None = None, + vae: bool = False, + is_causal: bool = False, + ): + super().__init__() + + self.embed = nn.Linear(feat_dim, hidden_dim) + if out_dim is None: + out_dim = feat_dim + self.unembed = nn.Linear(hidden_dim, out_dim) + + if net_factory == "convnext": + self.encoder_net = convnext_factory( + hidden_dim, num_layers, is_causal=is_causal + ) + self.decoder_net = convnext_factory( + hidden_dim, num_layers, is_causal=is_causal + ) + elif net_factory == "convnextformer_decoder": + self.encoder_net = convnext_factory( + hidden_dim, num_layers, is_causal=is_causal + ) + self.decoder_net = convnextformer_factory( + hidden_dim, + num_layers, + convnextformer_num_conv_per_transformer, + is_causal=is_causal, + ) + elif net_factory == "convnextformer": + self.encoder_net = convnextformer_factory( + hidden_dim, + num_layers, + convnextformer_num_conv_per_transformer, + is_causal=is_causal, + ) + self.decoder_net = convnextformer_factory( + hidden_dim, + num_layers, + convnextformer_num_conv_per_transformer, + is_causal=is_causal, + ) + + self.bottleneck = ( + nn.Linear(hidden_dim, bottleneck_size * (1 + vae)) + if bottleneck_size is not None + else nn.Identity() + ) + self.unbottleneck = ( + nn.Linear(bottleneck_size, hidden_dim) + if bottleneck_size is not None + else nn.Identity() + ) + self.vae = vae + + def reparameterize( + self, + mu: torch.Tensor, + logvar: torch.Tensor, + deterministic: bool = False, + drop_vae_rate: float = 0.0, + ) -> torch.Tensor: + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + if drop_vae_rate > 0.0: + to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate + eps = torch.randn_like(std) + eps[to_drop] = 0.0 + else: + if deterministic: + eps = torch.zeros_like(std) + else: + eps = torch.randn_like(std) + return mu + eps * std + + def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + return kl.sum(dim=-1).mean() + + def forward(self, x: torch.Tensor, drop_vae_rate: float = 0.0) -> torch.Tensor: + # Encode + x = self.embed(x) + x = self.encoder_net(x) + x = self.bottleneck(x) + if self.vae: + mu, logvar = x.chunk(2, dim=-1) + loss = { + "kl_div": self.kl_divergence(mu, logvar), + "_mu_mean": mu.mean(), + "_mu_std": mu.std(), + "_logvar_mean": logvar.mean(), + "_logvar_std": logvar.std(), + } + x = self.reparameterize( + mu, + logvar, + drop_vae_rate=drop_vae_rate, + ) + else: + loss = {} + + # Decode + x = self.unbottleneck(x) + x = self.decoder_net(x) + x = self.unembed(x) + + return x, loss + + def encode(self, x: torch.Tensor, deterministic: bool = False): + x = self.embed(x) + x = self.encoder_net(x) + x = self.bottleneck(x) + + 
if self.vae: + x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic) + return x + + def decode( + self, + latent: torch.Tensor | None = None, + ): + x = self.unbottleneck(latent) + x = self.decoder_net(x) + x = self.unembed(x) + return x diff --git a/codec/models/wavvae/__init__.py b/codec/models/wavvae/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/codec/models/wavvae/__pycache__/__init__.cpython-312.pyc b/codec/models/wavvae/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf1547a21222ae0bee9ea4841c041f1f2b27239 Binary files /dev/null and b/codec/models/wavvae/__pycache__/__init__.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc b/codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..914cab529f09300e46c68e097f707acaaaf43015 Binary files /dev/null and b/codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/heads.cpython-312.pyc b/codec/models/wavvae/__pycache__/heads.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2547e2211f7d5a72bc106c8fb715a8ee6e305018 Binary files /dev/null and b/codec/models/wavvae/__pycache__/heads.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/layers.cpython-312.pyc b/codec/models/wavvae/__pycache__/layers.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca73e82c8c07af3678ca2cda950dd1c28aabab06 Binary files /dev/null and b/codec/models/wavvae/__pycache__/layers.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/loss.cpython-312.pyc b/codec/models/wavvae/__pycache__/loss.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c507c5fb6931e3c6cfafae3317a4b47cefb3314 Binary files /dev/null and b/codec/models/wavvae/__pycache__/loss.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/model.cpython-312.pyc b/codec/models/wavvae/__pycache__/model.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef4754059a3498cbc4eddf81c248b952f086bf9c Binary files /dev/null and b/codec/models/wavvae/__pycache__/model.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/modules.cpython-312.pyc b/codec/models/wavvae/__pycache__/modules.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6638610c2e36aaefc76be800057ba4dd37d94d7c Binary files /dev/null and b/codec/models/wavvae/__pycache__/modules.cpython-312.pyc differ diff --git a/codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc b/codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f25c9a37e688787b91530c42cec3021c1a519a2 Binary files /dev/null and b/codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc differ diff --git a/codec/models/wavvae/dataset.py b/codec/models/wavvae/dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..6fab0e0668f1f54027eeebfb16ae22f61c881418 --- /dev/null +++ b/codec/models/wavvae/dataset.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass + +import numpy as np +import torch +import torchaudio +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset + 
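+# Keep each dataloader worker single-threaded (presumably to avoid CPU
+# oversubscription when several workers decode and resample audio in parallel).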
+torch.set_num_threads(1)
+
+
+@dataclass
+class WavVAEDataConfig:
+    filelist_path: str
+    sampling_rate: int
+    num_samples: int
+    batch_size: int
+    num_workers: int
+
+
+class WavVAEDataModule(LightningDataModule):
+    def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
+        super().__init__()
+        self.train_config = train_params
+        self.val_config = val_params
+
+    def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool):
+        dataset = WavVAEDataset(cfg, train=train)
+        dataloader = DataLoader(
+            dataset,
+            batch_size=cfg.batch_size,
+            num_workers=cfg.num_workers,
+            shuffle=train,
+            pin_memory=True,
+        )
+        return dataloader
+
+    def train_dataloader(self) -> DataLoader:
+        return self._get_dataloder(self.train_config, train=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return self._get_dataloder(self.val_config, train=False)
+
+
+class WavVAEDataset(Dataset):
+    def __init__(self, cfg: WavVAEDataConfig, train: bool):
+        with open(cfg.filelist_path) as f:
+            self.filelist = f.read().splitlines()
+        self.sampling_rate = cfg.sampling_rate
+        self.num_samples = cfg.num_samples
+        self.train = train
+
+    def __len__(self) -> int:
+        return len(self.filelist)
+
+    def __getitem__(self, index: int) -> torch.Tensor:
+        audio_path = self.filelist[index]
+        y, sr = torchaudio.load(audio_path)
+        if y.size(0) > 1:
+            # mix to mono
+            y = y.mean(dim=0, keepdim=True)
+        gain = np.random.uniform(-1, -6) if self.train else -3
+        y, _ = torchaudio.sox_effects.apply_effects_tensor(
+            y, sr, [["norm", f"{gain:.2f}"]]
+        )
+        try:
+            if sr != self.sampling_rate:
+                y = torchaudio.functional.resample(
+                    y, orig_freq=sr, new_freq=self.sampling_rate
+                )
+        except Exception:
+            print(audio_path, y.shape)
+        if y.size(-1) < self.num_samples:
+            pad_length = self.num_samples - y.size(-1)
+            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+        elif self.train:
+            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+            y = y[:, start : start + self.num_samples]
+        else:
+            # During validation, always take the first segment for determinism
+            y = y[:, : self.num_samples]
+
+        return y[0]
diff --git a/codec/models/wavvae/discriminators.py b/codec/models/wavvae/discriminators.py
new file mode 100755
index 0000000000000000000000000000000000000000..2c62574a78bc963eb35dedffd6ef13a96f4e4193
--- /dev/null
+++ b/codec/models/wavvae/discriminators.py
@@ -0,0 +1,211 @@
+from typing import List, Optional, Tuple
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import Conv2d
+from torch.nn.utils import weight_norm
+from torchaudio.transforms import Spectrogram
+
+
+class MultiPeriodDiscriminator(nn.Module):
+    """
+    Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
+    Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+    Args:
+        periods (tuple[int]): Tuple of periods for each discriminator.
+        num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+            Defaults to None.
+    """
+
+    def __init__(self, periods: Tuple[int, ...]
= (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None): + super().__init__() + self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods]) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), + weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. 
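+
+        Each sub-discriminator operates on the complex spectrogram at its window length,
+        split into frequency sub-bands that are processed by separate convolutional stacks.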
+ """ + + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), + weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + x = rearrange(x, "b f t c -> b c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/codec/models/wavvae/experiment.py b/codec/models/wavvae/experiment.py new file mode 100755 index 0000000000000000000000000000000000000000..b28b04f643122b019e912540f228c8ed20be9eeb --- /dev/null +++ b/codec/models/wavvae/experiment.py @@ -0,0 +1,3 @@ + + + diff --git a/codec/models/wavvae/heads.py b/codec/models/wavvae/heads.py new file mode 100755 index 0000000000000000000000000000000000000000..92c576aa3f8a72f2b2b148d2d162eab6e838ac31 --- /dev/null +++ 
b/codec/models/wavvae/heads.py @@ -0,0 +1,194 @@ +from typing import Optional + +import torch +from einops import rearrange +from torch import nn +from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz + +from .modules import symexp +from .spectral_ops import IMDCT, ISTFT + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class LinearNoBiasHead(FourierHead): + def __init__(self, dim: int, hop_length: int, n_fft: int): + super().__init__() + self.pre_head = nn.Linear(dim, n_fft + 2) + self.head = nn.Linear(n_fft + 2, hop_length, bias=False) + + def forward(self, x): + y = self.pre_head(x) + y = self.head(y).clamp(min=-1.0, max=1.0) + B, _, _ = y.shape + return y.reshape(B, -1) + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.hop_length = hop_length + self.istft = ISTFT( + n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip( + mag, max=1e2 + ) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + audio = nn.functional.pad(audio, (self.hop_length // 2, self.hop_length // 2)) + return audio + + +class IMDCTSymExpHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with symmetric exponential function + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized + based on perceptual scaling. Defaults to None. + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. 
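+
+    The predicted coefficients pass through a symmetric exponential (symexp) before the
+    inverse MDCT; when `sample_rate` is given, the output weights are initialized with a
+    mel-scale taper so higher-frequency bins start with smaller magnitudes.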
+ """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + sample_rate: Optional[int] = None, + clip_audio: bool = False, + ): + super().__init__() + out_dim = mdct_frame_len // 2 + self.out = nn.Linear(dim, out_dim) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + self.clip_audio = clip_audio + + if sample_rate is not None: + # optionally init the last layer following mel-scale + m_max = _hz_to_mel(sample_rate // 2) + m_pts = torch.linspace(0, m_max, out_dim) + f_pts = _mel_to_hz(m_pts) + scale = 1 - (f_pts / f_pts.max()) + + with torch.no_grad(): + self.out.weight.mul_(scale.view(-1, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTSymExpHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + x = symexp(x) + x = torch.clip( + x, min=-1e2, max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(x) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + + return audio + + +class IMDCTCosHead(FourierHead): + """ + IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p) + + Args: + dim (int): Hidden dimension of the model. + mdct_frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. + """ + + def __init__( + self, + dim: int, + mdct_frame_len: int, + padding: str = "same", + clip_audio: bool = False, + ): + super().__init__() + self.clip_audio = clip_audio + self.out = nn.Linear(dim, mdct_frame_len) + self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the IMDCTCosHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x = self.out(x) + m, p = x.chunk(2, dim=2) + m = torch.exp(m).clip( + max=1e2 + ) # safeguard to prevent excessively large magnitudes + audio = self.imdct(m * torch.cos(p)) + if self.clip_audio: + audio = torch.clip(x, min=-1.0, max=1.0) + return audio diff --git a/codec/models/wavvae/helpers.py b/codec/models/wavvae/helpers.py new file mode 100755 index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36 --- /dev/null +++ b/codec/models/wavvae/helpers.py @@ -0,0 +1,71 @@ +import matplotlib +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback + +matplotlib.use("Agg") + + +def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: + """ + Save a matplotlib figure to a numpy array. + + Args: + fig (Figure): Matplotlib figure object. + + Returns: + ndarray: Numpy array representing the figure. 
+ """ + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: + """ + Plot a spectrogram and convert it to a numpy array. + + Args: + spectrogram (ndarray): Spectrogram data. + + Returns: + ndarray: Numpy array representing the plotted spectrogram. + """ + spectrogram = spectrogram.astype(np.float32) + fig, ax = plt.subplots(figsize=(12, 3)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +class GradNormCallback(Callback): + """ + Callback to log the gradient norm. + """ + + def on_after_backward(self, trainer, model): + model.log("grad_norm", gradient_norm(model)) + + +def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: + """ + Compute the gradient norm. + + Args: + model (Module): PyTorch model. + norm_type (float, optional): Type of the norm. Defaults to 2.0. + + Returns: + Tensor: Gradient norm. + """ + grads = [p.grad for p in model.parameters() if p.grad is not None] + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) + return total_norm diff --git a/codec/models/wavvae/layers.py b/codec/models/wavvae/layers.py new file mode 100755 index 0000000000000000000000000000000000000000..686bcdaa25fcc9f41a42aac6dfcb769cd1517161 --- /dev/null +++ b/codec/models/wavvae/layers.py @@ -0,0 +1,282 @@ +import math + +import torch +import torch.nn as nn +from torch.nn.utils.parametrizations import weight_norm + +class VocosDecoder(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + num_layers: int, + ): + super().__init__() + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=0.0, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int | None = None, + layer_scale_init_value: float = 0.0, + elementwise_affine_ln: bool = True, + ): + super().__init__() + intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = nn.LayerNorm( + dim, eps=1e-6, elementwise_affine=elementwise_affine_ln + ) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, + x: torch.Tensor, + scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None, + gate: torch.Tensor | None = None, + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + x = self.norm(x) + if scale_shift is not None: + scale, shift = scale_shift + x = x * scale[:, None] + shift[:, None] + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + if gate is not None: + x = gate[:, None] * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class Encoder(nn.Module): + def __init__( + self, + d_model=32, + strides=[2, 4, 4, 8], + depthwise=False, + ): + super().__init__() + layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + for stride in strides: + d_model *= 2 + groups = d_model // 2 if depthwise else 1 + layers += [EncoderBlock(output_dim=d_model, stride=stride, groups=groups)] + groups = d_model if depthwise else 1 + layers += [ + WNConv1d(d_model, d_model, kernel_size=7, padding=3, groups=groups), + ] + self.block = nn.Sequential(*layers) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + noise=False, + depthwise=False, + d_out=1, + ): + super().__init__() + if depthwise: + layers = [ + WNConv1d( + input_channel, + input_channel, + kernel_size=7, + padding=3, + groups=input_channel, + ), + WNConv1d(input_channel, channels, kernel_size=1), + ] + else: + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + groups = output_dim if depthwise else 1 + layers.append( + DecoderBlock(input_dim, output_dim, stride, noise, groups=groups) + ) + + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + x = self.model(x) + return x + + +class ResidualUnit(nn.Module): + def __init__(self, dim=16, dilation=1, kernel=7, groups=1): + super().__init__() + pad = ((kernel - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d( + dim, + dim, + kernel_size=kernel, + dilation=dilation, + padding=pad, + groups=groups, + ), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, output_dim=16, input_dim=None, stride=1, groups=1): + super().__init__() + input_dim = input_dim or output_dim // 2 + self.block = 
nn.Sequential( + ResidualUnit(input_dim, dilation=1, groups=groups), + ResidualUnit(input_dim, dilation=3, groups=groups), + ResidualUnit(input_dim, dilation=9, groups=groups), + Snake1d(input_dim), + WNConv1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class NoiseBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.linear = WNConv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + B, C, T = x.shape + noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype) + h = self.linear(x) + n = noise * h + x = x + n + return x + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim=16, output_dim=8, stride=1, noise=False, groups=1): + super().__init__() + layers = [ + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + ), + ] + if noise: + layers.append(NoiseBlock(output_dim)) + layers.extend( + [ + ResidualUnit(output_dim, dilation=1, groups=groups), + ResidualUnit(output_dim, dilation=3, groups=groups), + ResidualUnit(output_dim, dilation=9, groups=groups), + ] + ) + self.block = nn.Sequential(*layers) + + def forward(self, x): + return self.block(x) + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/codec/models/wavvae/loss.py b/codec/models/wavvae/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..ec3268216302151bf20351116ad7c3d886a8aeac --- /dev/null +++ b/codec/models/wavvae/loss.py @@ -0,0 +1,142 @@ +from typing import List, Optional, Tuple + +import torch +import torchaudio +from torch import nn + +from .modules import safe_log + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, + sample_rate: int = 24000, + n_fft: int | None = None, + hop_length: int = 256, + n_mels: int = 100, + f_min: int = 0, + f_max: Optional[int] = None, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=hop_length * 4 if n_fft is None else n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=True, + power=1, + f_min=f_min, + f_max=f_max, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. 
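+
+        Example (sketch; a one-second batch at the default 24 kHz is assumed):
+            >>> loss_fn = MelSpecReconstructionLoss(sample_rate=24000)
+            >>> y_hat, y = torch.randn(2, 24000), torch.randn(2, 24000)
+            >>> loss = loss_fn(y_hat, y)   # scalar L1 distance between log-mel spectrograms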
+ """ + # B, C, Th = y_hat.shape + # B, C, T = y.shape + # crop = (Th - T) // 2 + mel_hat = safe_log(self.mel_spec(y_hat)) + # mel_hat = safe_log(self.mel_spec(y_hat[..., crop:-crop])) + # mel = safe_log(self.mel_spec(y[..., crop:-crop])) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward( + self, disc_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = torch.zeros( + 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype + ) + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, + disc_real_outputs: List[torch.Tensor], + disc_generated_outputs: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. + """ + loss = torch.zeros( + 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype + ) + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss) + g_losses.append(g_loss) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. + """ + + def forward( + self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]] + ) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. 
+ """ + loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss diff --git a/codec/models/wavvae/model.py b/codec/models/wavvae/model.py new file mode 100755 index 0000000000000000000000000000000000000000..b5562a30907bd82d36deb388587024e9ca94a2db --- /dev/null +++ b/codec/models/wavvae/model.py @@ -0,0 +1,140 @@ +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal + +import torch +from safetensors.torch import load_file +from torch import nn + +from .heads import ISTFTHead, LinearNoBiasHead +from .layers import Encoder, VocosDecoder +from .modules import ConvNeXtBlock + + +@dataclass +class WavVAEConfig: + conv_dim: int = 48 + latent_dim: int = 32 + decoder_hidden_dim: int = 768 + decoder_intermediate_dim: int = 1536 + decoder_num_layers: int = 8 + n_fft: int = 1024 + hop_length: int = 256 + padding: str = "center" + head_type: Literal["istft", "linear"] = "istft" + strides: list[int] = field(default_factory=lambda: [2, 4, 4, 8]) + learnable_pre_norm: bool = False + sampling_rate: int = 24000 + + +class WavVAE(nn.Module): + def __init__(self, cfg: WavVAEConfig): + super().__init__() + self.conv_encoder = Encoder(cfg.conv_dim, strides=cfg.strides, depthwise=True) + conv_final_dim = cfg.conv_dim * 2 ** len(cfg.strides) + self.bottleneck = nn.Linear(conv_final_dim, cfg.latent_dim * 2) + self.unbottleneck = nn.Linear(cfg.latent_dim, cfg.decoder_hidden_dim) + self.latent_norm = nn.LayerNorm(conv_final_dim) + self.vocos_decoder = VocosDecoder( + cfg.decoder_hidden_dim, + cfg.decoder_intermediate_dim, + cfg.decoder_num_layers, + ) + if cfg.head_type == "istft": + self.head = ISTFTHead( + cfg.decoder_hidden_dim, + cfg.n_fft, + cfg.hop_length, + padding=cfg.padding, + ) + elif cfg.head_type == "linear": + self.head = LinearNoBiasHead( + cfg.decoder_hidden_dim, + cfg.hop_length, + cfg.n_fft, + ) + + self._sampling_rate = cfg.sampling_rate + self._strides = cfg.strides + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @property + def sampling_rate(self) -> int: + return self._sampling_rate + + @property + def hop_length(self) -> int: + hop_length = 1 + for s in self._strides: + hop_length *= s + return hop_length + + @property + def frame_rate(self) -> float: + return self.sampling_rate / self.hop_length + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + device: str = "cpu", + ): + if Path(pretrained_model_name_or_path).exists(): + path = pretrained_model_name_or_path + else: + from huggingface_hub import snapshot_download + + path = snapshot_download(pretrained_model_name_or_path) + + with open(Path(path) / "config.json", "r") as f: + config = json.load(f) + config = WavVAEConfig(**config) + model = cls(config) + state_dict = load_file( + Path(path) / "model.st", + device=device, + ) + + model.load_state_dict(state_dict, assign=True) + return model + + def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) + return 
kl.sum(dim=-1).mean() + + def encode(self, audio: torch.Tensor) -> torch.Tensor: + y = self.conv_encoder(audio.unsqueeze(1)).transpose(1, 2) + y = self.latent_norm(y) + mu, logvar = self.bottleneck(y).chunk(2, dim=-1) + z = self.reparameterize(mu, logvar) + return z + + def decode(self, z: torch.Tensor) -> torch.Tensor: + y = self.unbottleneck(z) + y = self.vocos_decoder(y) + return self.head(y) + + def forward(self, audio_input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + y = self.conv_encoder(audio_input.unsqueeze(1)).transpose(1, 2) + y = self.latent_norm(y) + mu, logvar = self.bottleneck(y).chunk(2, dim=-1) + kl_div = self.kl_divergence(mu, logvar) + z = self.reparameterize(mu, logvar) + y = self.unbottleneck(z) + y = self.vocos_decoder(y) + audio_output = self.head(y) + + return audio_output, kl_div diff --git a/codec/models/wavvae/modules.py b/codec/models/wavvae/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..af1d6db16e2f10cc9af7bcc64434e2c983d756b7 --- /dev/null +++ b/codec/models/wavvae/modules.py @@ -0,0 +1,213 @@ +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn.utils import weight_norm, remove_weight_norm + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. 
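+
+    Example (sketch; a single conditioning index selects one scale/shift pair):
+        >>> norm = AdaLayerNorm(num_embeddings=4, embedding_dim=512)
+        >>> x = torch.randn(2, 100, 512)
+        >>> y = norm(x, torch.tensor(0))   # (2, 100, 512)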
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 
1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. + """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/codec/models/wavvae/spectral_ops.py b/codec/models/wavvae/spectral_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..a8eda1c8e18a32406aad40415f6a8bf60eb15fea --- /dev/null +++ b/codec/models/wavvae/spectral_ops.py @@ -0,0 +1,192 @@ +import numpy as np +import scipy +import torch +from torch import nn, view_as_real, view_as_complex + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
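+
+        Example (sketch; n_fft=1024 and hop_length=256 are arbitrary choices):
+            >>> istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024)
+            >>> spec = torch.randn(1, 513, 50, dtype=torch.complex64)
+            >>> audio = istft(spec)   # (1, 50 * 256) samples with "same" padding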
+ """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + +class MDCT(nn.Module): + """ + Modified Discrete Cosine Transform (MDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) + post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) + # view_as_real: NCCL Backend does not support ComplexFloat data type + # https://github.com/pytorch/pytorch/issues/71613 + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + """ + Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. + + Args: + audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size + and T is the length of the audio. + + Returns: + Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames + and N is the number of frequency bins. + """ + if self.padding == "center": + audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2)) + elif self.padding == "same": + # hop_length is 1/2 frame_len + audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4)) + else: + raise ValueError("Padding must be 'center' or 'same'.") + + x = audio.unfold(-1, self.frame_len, self.frame_len // 2) + N = self.frame_len // 2 + x = x * self.window.expand(x.shape) + X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N] + res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) + return torch.real(res) * np.sqrt(2) + + +class IMDCT(nn.Module): + """ + Inverse Modified Discrete Cosine Transform (IMDCT) module. + + Args: + frame_len (int): Length of the MDCT frame. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
+ """ + + def __init__(self, frame_len: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.frame_len = frame_len + N = frame_len // 2 + n0 = (N + 1) / 2 + window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() + self.register_buffer("window", window) + + pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) + post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) + self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) + self.register_buffer("post_twiddle", view_as_real(post_twiddle)) + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. + + Args: + X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, + L is the number of frames, and N is the number of frequency bins. + + Returns: + Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. + """ + B, L, N = X.shape + Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) + Y[..., :N] = X + Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) + y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1) + y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2) + result = y * self.window.expand(y.shape) + output_size = (1, (L + 1) * N) + audio = torch.nn.functional.fold( + result.transpose(1, 2), + output_size=output_size, + kernel_size=(1, self.frame_len), + stride=(1, self.frame_len // 2), + )[:, 0, 0, :] + + if self.padding == "center": + pad = self.frame_len // 2 + elif self.padding == "same": + pad = self.frame_len // 4 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + audio = audio[:, pad:-pad] + return audio diff --git a/codec/scripts/compare_codecs.py b/codec/scripts/compare_codecs.py new file mode 100644 index 0000000000000000000000000000000000000000..fbea2f8b7d4356897318ef2da5fa0cbf8c2b6a70 --- /dev/null +++ b/codec/scripts/compare_codecs.py @@ -0,0 +1,441 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Any, Tuple + +import torch +from torchaudio import load as ta_load +from torchaudio.functional import resample as ta_resample +import torchaudio + +# Your libs +from zcodec.models import WavVAE, ZFlowAutoEncoder + + +# ------------------------- +# Data structures +# ------------------------- + + +@dataclass +class DecodeParams: + num_steps: int = 10 + cfg: float = 2.0 + + +@dataclass +class ModelPairSpec: + name: str + wavvae_dir: str + zflowae_dir: str + decode: DecodeParams + + +# ------------------------- +# Utilities +# ------------------------- + + +def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]: + if path.is_file(): + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + return None + + +def read_config_any(checkpoint_dir: str) -> Dict[str, Any]: + """ + Try to read config.json (or a few common fallbacks) from a checkpoint dir. + Returns {} if nothing could be parsed. 
+ """ + cand = [ + Path(checkpoint_dir) / "config.json", + Path(checkpoint_dir) + / "config.yaml", # won't parse yaml here, we only display path + Path(checkpoint_dir) / "model_config.json", + ] + for p in cand: + if p.exists(): + if p.suffix == ".json": + j = load_json_if_exists(p) + if j is not None: + return j + else: + # For YAML or unknown, just show filename rather than failing + return {"_config_file": str(p)} + return {} + + +def sanitize_name(s: str) -> str: + return "".join(c if c.isalnum() or c in "-_." else "_" for c in s) + + +def ensure_mono_and_resample( + wav: torch.Tensor, sr: int, target_sr: int +) -> Tuple[torch.Tensor, int]: + """ + wav: (channels, samples) + returns mono float32 in [-1,1], resampled to target_sr + """ + if wav.ndim != 2: + raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}") + # to mono + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + # resample if needed + if sr != target_sr: + wav = ta_resample(wav, sr, target_sr) + sr = target_sr + return wav.to(torch.float32), sr + + +def save_wav(path: Path, wav: torch.Tensor, sr: int): + path.parent.mkdir(parents=True, exist_ok=True) + # (C, T) + if wav.ndim == 1: + wav = wav.unsqueeze(0) + # Clamp to [-1,1] + wav = wav.clamp(-1, 1).contiguous().cpu() + torchaudio.save( + str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16 + ) + + +# ------------------------- +# Core inference +# ------------------------- + + +@torch.inference_mode() +def reconstruct_full_pipeline( + wav_mono: torch.Tensor, + sr: int, + wavvae: WavVAE, + zflowae: ZFlowAutoEncoder, + decode_params: DecodeParams, + device: str, +) -> torch.Tensor: + """ + Full path: audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat + """ + wav_mono = wav_mono.to(device) + # WavVAE expects (B, C, T); assume C=1 + x = wav_mono.unsqueeze(0) # (1, 1, T) + # Encode to high-framerate latents + z = wavvae.encode(x) + # Compress latents + y = zflowae.encode(z) + # Decompress + z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg) + # Decode to waveform + wav_hat = wavvae.decode(z_hat) # (1, 1, T) + # Return mono 1D + return wav_hat.squeeze(0).squeeze(0).detach() + + +def load_model_pair(spec: ModelPairSpec, device: str): + wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device) + zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device) + # try to get sampling rate from WavVAE + target_sr = getattr(wavvae, "sampling_rate", None) + if target_sr is None: + # reasonable fallback + target_sr = 24000 + return wavvae, zflowae, int(target_sr) + + +def parse_manifest(path: str) -> List[ModelPairSpec]: + """ + Manifest format (JSON list): + [ + { + "name": "zdim32x8", + "wavvae": "/path/to/WavVAE_framerate100_zdim32/", + "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../", + "decode": {"num_steps": 10, "cfg": 2.0} + } + ] + """ + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + out: List[ModelPairSpec] = [] + for item in raw: + name = item["name"] + wavvae_dir = item["wavvae"] + zflowae_dir = item["zflowae"] + d = item.get("decode", {}) + out.append( + ModelPairSpec( + name=name, + wavvae_dir=wavvae_dir, + zflowae_dir=zflowae_dir, + decode=DecodeParams( + num_steps=int(d.get("num_steps", 10)), + cfg=float(d.get("cfg", 2.0)), + ), + ) + ) + return out + + +# ------------------------- +# HTML generation +# ------------------------- + + +def html_escape(s: str) -> str: + return ( + s.replace("&", "&") 
+ .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +def make_html( + output_dir: Path, + audio_files: List[Path], + models: List[ModelPairSpec], + sr_by_model: Dict[str, int], + wavvae_cfg: Dict[str, Dict[str, Any]], + zflow_cfg: Dict[str, Dict[str, Any]], +) -> str: + """ + Build a static HTML page with a table: + Row = input audio file + Col 1 = Original + Col 2..N = each model reconstruction + Also shows minimal model config info above the table. + """ + + def player(src_rel: str, controls: bool = True) -> str: + return f'' + + # Model cards + model_cards = [] + for spec in models: + wcfg = wavvae_cfg.get(spec.name, {}) + zcfg = zflow_cfg.get(spec.name, {}) + w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[ + :1200 + ] + z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[ + :1200 + ] + card = f""" +
+

{html_escape(spec.name)}

+

Sample rate: {sr_by_model.get(spec.name, "N/A")} Hz

+
+ WavVAE config +
{html_escape(w_short)}
+
+
+ ZFlowAE config +
{html_escape(z_short)}
+
+

Decode: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}

+
+ """ + model_cards.append(card) + + # Table header + th = "InputOriginal" + "".join( + f"{html_escape(m.name)}" for m in models + ) + + # Rows + rows = [] + for af in audio_files: + base = af.stem + orig_rel = f"original/{html_escape(af.name)}" + tds = [f"{html_escape(base)}", f"{player(orig_rel)}"] + for m in models: + rec_rel = f"recon/{html_escape(m.name)}/{html_escape(base)}.wav" + tds.append(f"{player(rec_rel)}") + rows.append("" + "".join(tds) + "") + + # Simple CSS to keep it clean + css = """ + body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; } + h1 { margin-bottom: 0.2rem; } + .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; } + .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; } + table { border-collapse: collapse; width: 100%; } + th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; } + th { background: #fafafa; position: sticky; top: 0; } + audio { width: 260px; } + """ + + html = f""" + + + + Codec Comparison + + + +

+  <h1>Codec Comparison</h1>
+  <p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>
+
+  <h2>Models</h2>
+  <div class="cards">
+    {"".join(model_cards)}
+  </div>
+
+  <h2>Audio</h2>
+  <table>
+    <thead>
+      <tr>
+ + +""" + out = output_dir / "index.html" + out.write_text(html, encoding="utf-8") + return str(out) + + +# ------------------------- +# Main +# ------------------------- + + +def main(): + p = argparse.ArgumentParser( + description="Compare Z-Codec configurations and generate a static HTML page." + ) + p.add_argument( + "--manifest", + type=str, + required=True, + help="JSON file listing model pairs. See docstring in parse_manifest().", + ) + p.add_argument( + "--audio", type=str, nargs="+", required=True, help="List of input audio files." + ) + p.add_argument( + "--out", + type=str, + default="codec_compare_out", + help="Output directory for reconstructions and HTML.", + ) + p.add_argument( + "--device", + type=str, + default="cuda", + help="Device to run inference on (cuda or cpu).", + ) + p.add_argument( + "--force", + action="store_true", + help="Recompute even if target wav already exists.", + ) + args = p.parse_args() + + device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu" + out_dir = Path(args.out) + orig_dir = out_dir / "original" + recon_dir = out_dir / "recon" + orig_dir.mkdir(parents=True, exist_ok=True) + recon_dir.mkdir(parents=True, exist_ok=True) + + # Parse models + specs = parse_manifest(args.manifest) + if not specs: + print("No models in manifest.", file=sys.stderr) + sys.exit(1) + + # Load models + loaded: Dict[str, Dict[str, Any]] = {} + sr_by_model: Dict[str, int] = {} + wavvae_cfg: Dict[str, Dict[str, Any]] = {} + zflow_cfg: Dict[str, Dict[str, Any]] = {} + + for spec in specs: + print(f"[Load] {spec.name}") + wavvae, zflowae, target_sr = load_model_pair(spec, device) + loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr} + sr_by_model[spec.name] = target_sr + wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir) + zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir) + + # Process audio files + audio_files = [Path(a) for a in args.audio] + for af in audio_files: + if not af.exists(): + print(f"[Skip] Missing: {af}", file=sys.stderr) + continue + + # copy original (resampled per model? We'll store original as-is) + # Just place the original file for direct playback + # If it's not wav, we still copy a WAV version for compatibility. + # But simplest: if not wav, we re-save as wav 16-bit for the page. 
+ out_orig = orig_dir / af.name + if args.force or not out_orig.exists(): + # Load and resave as wav to ensure browser-compat + wav, sr = ta_load(str(af)) + # make it mono for fair listening + wav_mono, sr = ensure_mono_and_resample(wav, sr, sr) + save_wav(out_orig.with_suffix(".wav"), wav_mono, sr) + # keep the name consistent in the HTML (use .wav) + af = af.with_suffix(".wav") + # rename saved file to matched name + if out_orig.suffix != ".wav": + # Clean: ensure HTML references the .wav filename + out_orig = out_orig.with_suffix(".wav") + + # For each model, run full pipeline and save + base = af.stem + # Re-load from disk to ensure consistent start-point (original .wav in out folder) + wav0, sr0 = ta_load(str(out_orig if out_orig.exists() else orig_dir / af.name)) + # Make mono only once; resample per-model to each target SR + if wav0.size(0) > 1: + wav0 = wav0.mean(dim=0, keepdim=True) + + for spec in specs: + mname = spec.name + target_sr = sr_by_model[mname] + # resample to model's SR + if sr0 != target_sr: + wav_mono = ta_resample(wav0, sr0, target_sr) + else: + wav_mono = wav0 + + # reconstruct + out_path = recon_dir / mname / f"{sanitize_name(base)}.wav" + if args.force or not out_path.exists(): + print(f"[Reconstruct] {mname} ← {base}") + wavvae = loaded[mname]["wavvae"] + zflowae = loaded[mname]["zflowae"] + wav_hat = reconstruct_full_pipeline( + wav_mono, target_sr, wavvae, zflowae, spec.decode, device + ) + save_wav(out_path, wav_hat.unsqueeze(0), target_sr) + + # Build HTML + # Rebuild the list of files actually present in original/ (use .wav names) + actual_audio = sorted([p for p in (orig_dir).glob("*.wav")]) + html_path = make_html( + out_dir, + actual_audio, + specs, + sr_by_model, + wavvae_cfg, + zflow_cfg, + ) + print(f"\nDone. Open: {html_path}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/compare_wavvae.py b/codec/scripts/compare_wavvae.py new file mode 100644 index 0000000000000000000000000000000000000000..c8af4753ca1ed4f3c58575a7c5bbe4a35aa4aaf4 --- /dev/null +++ b/codec/scripts/compare_wavvae.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python3 +import argparse +import json +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torchaudio +from torchaudio import load as ta_load +from torchaudio.functional import resample as ta_resample +from zcodec.models import WavVAE + +# ------------------------- +# Data structures +# ------------------------- + + +@dataclass +class WavVaeSpec: + name: str + wavvae_dir: str + + +# ------------------------- +# Utilities +# ------------------------- + + +def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]: + if path.is_file(): + try: + return json.load(path.open("r", encoding="utf-8")) + except Exception: + return None + return None + + +def read_config_any(checkpoint_dir: str) -> Dict[str, Any]: + cand = [ + Path(checkpoint_dir) / "config.json", + Path(checkpoint_dir) / "model_config.json", + Path(checkpoint_dir) / "config.yaml", # shown as path only + ] + for p in cand: + if p.exists(): + if p.suffix == ".json": + j = load_json_if_exists(p) + if j is not None: + return j + else: + return {"_config_file": str(p)} + return {} + + +def sanitize_name(s: str) -> str: + return "".join(c if c.isalnum() or c in "-_." 
else "_" for c in s) + + +def ensure_mono_and_resample( + wav: torch.Tensor, sr: int, target_sr: int +) -> Tuple[torch.Tensor, int]: + if wav.ndim != 2: + raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}") + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + if sr != target_sr: + wav = ta_resample(wav, sr, target_sr) + sr = target_sr + return wav.to(torch.float32), sr + + +def save_wav(path: Path, wav: torch.Tensor, sr: int): + path.parent.mkdir(parents=True, exist_ok=True) + if wav.ndim == 1: + wav = wav.unsqueeze(0) + wav = wav.clamp(-1, 1).contiguous().cpu() + torchaudio.save( + str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16 + ) + + +def read_audio_manifest(txt_path: str) -> List[Path]: + lines = Path(txt_path).read_text(encoding="utf-8").splitlines() + files = [ + Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#") + ] + return files + + +def html_escape(s: str) -> str: + return ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +def make_html( + output_dir: Path, + audio_files: List[Path], + specs: List[WavVaeSpec], + sr_by_model: Dict[str, int], + wavvae_cfg: Dict[str, Dict[str, Any]], +) -> str: + def player(src_rel: str) -> str: + return f'' + + # cards + cards = [] + for s in specs: + cfg = wavvae_cfg.get(s.name, {}) + cfg_short = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)[ + :1200 + ] + card = f""" +
+        <div class="model-card">
+          <h3>{html_escape(s.name)}</h3>
+          <p>Sample rate: {sr_by_model.get(s.name, "N/A")} Hz</p>
+          <details>
+            <summary>WavVAE config</summary>
+            <pre>{html_escape(cfg_short)}</pre>
+          </details>
+        </div>
+ """ + cards.append(card) + + css = """ + body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; } + .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; } + .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; } + table { border-collapse: collapse; width: 100%; } + th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; } + th { background: #fafafa; position: sticky; top: 0; } + audio { width: 260px; } + """ + + th = "InputOriginal" + "".join( + f"{html_escape(s.name)}" for s in specs + ) + rows = [] + for af in audio_files: + base = af.stem + orig_rel = f"original/{html_escape(af.name)}" + tds = [f"{html_escape(base)}", f"{player(orig_rel)}"] + for s in specs: + rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav" + tds.append(f"{player(rec_rel)}") + rows.append("" + "".join(tds) + "") + + html = f""" + + WavVAE Comparison + +

+  <h1>WavVAE Comparison</h1>
+  <div class="cards">
+    {"".join(cards)}
+  </div>
+  <table>
+    <tr>
+ + +""" + out = output_dir / "index.html" + out.write_text(html, encoding="utf-8") + return str(out) + + +# ------------------------- +# Core +# ------------------------- + + +@torch.inference_mode() +def reconstruct_wavvae( + wav_mono: torch.Tensor, wavvae: WavVAE, device: str +) -> torch.Tensor: + x = wav_mono.to(device) # (1,T) + z = wavvae.encode(x) + wav_hat = wavvae.decode(z) # (1,1,T) + return wav_hat.squeeze(0).squeeze(0).detach() + + +def parse_models_manifest(path: str) -> List[WavVaeSpec]: + """ + JSON list of: + {"name": "...", "wavvae": "/path/to/WavVAE_dir"} + """ + raw = json.loads(Path(path).read_text(encoding="utf-8")) + specs = [] + for it in raw: + specs.append(WavVaeSpec(name=it["name"], wavvae_dir=it["wavvae"])) + return specs + + +def main(): + ap = argparse.ArgumentParser( + description="Compare WavVAE checkpoints and generate a static HTML page." + ) + ap.add_argument("--models", required=True, help="JSON manifest of WavVAE models.") + ap.add_argument( + "--audio_manifest", required=True, help="TXT file: one audio path per line." + ) + ap.add_argument("--out", default="compare_wavvae_out") + ap.add_argument("--device", default="cuda") + ap.add_argument("--force", action="store_true") + args = ap.parse_args() + + device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu" + out_dir = Path(args.out) + (out_dir / "original").mkdir(parents=True, exist_ok=True) + recon_dir = out_dir / "recon" + recon_dir.mkdir(parents=True, exist_ok=True) + + specs = parse_models_manifest(args.models) + if not specs: + print("No models.", file=sys.stderr) + sys.exit(1) + + # load models + wavvae_by_name: Dict[str, WavVAE] = {} + sr_by_model: Dict[str, int] = {} + wavvae_cfg: Dict[str, Dict[str, Any]] = {} + for s in specs: + print(f"[Load] {s.name}") + w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device) + wavvae_by_name[s.name] = w + sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000)) + wavvae_cfg[s.name] = read_config_any(s.wavvae_dir) + + audio_paths = read_audio_manifest(args.audio_manifest) + # normalize originals to wav+mono (browser-friendly); keep native sr for original column + actual_audio = [] + for ap in audio_paths: + if not ap.exists(): + print(f"[Skip missing] {ap}", file=sys.stderr) + continue + wav, sr = ta_load(str(ap)) + wav_mono, sr = ensure_mono_and_resample(wav, sr, sr) + out_orig = out_dir / "original" / (ap.stem + ".wav") + if args.force or not out_orig.exists(): + save_wav(out_orig, wav_mono, sr) + actual_audio.append(out_orig) + + # recon per model + for out_orig in actual_audio: + wav0, sr0 = ta_load(str(out_orig)) + if wav0.size(0) > 1: + wav0 = wav0.mean(dim=0, keepdim=True) + for s in specs: + target_sr = sr_by_model[s.name] + wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0 + out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav" + if args.force or not out_path.exists(): + print(f"[Reconstruct] {s.name} ← {out_orig.name}") + wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device) + save_wav(out_path, wav_hat, target_sr) + + html_path = make_html(out_dir, actual_audio, specs, sr_by_model, wavvae_cfg) + print(f"Done. 
Open: {html_path}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/compare_zcodec.py b/codec/scripts/compare_zcodec.py new file mode 100644 index 0000000000000000000000000000000000000000..26042fde831a2ab969d00b870c05a51fec26746c --- /dev/null +++ b/codec/scripts/compare_zcodec.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +import argparse +import json +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torchaudio +from torchaudio import load as ta_load +from torchaudio.functional import resample as ta_resample +from zcodec.models import WavVAE, ZFlowAutoEncoder + +# ------------------------- +# Data structures +# ------------------------- + + +@dataclass +class DecodeParams: + num_steps: int = 10 + cfg: float = 2.0 + + +@dataclass +class StackSpec: + name: str + wavvae_dir: str + zflowae_dir: str + decode: DecodeParams + + +# ------------------------- +# Utilities (same helpers) +# ------------------------- + + +def load_json_if_exists(path: Path): + if path.is_file(): + try: + return json.load(path.open("r", encoding="utf-8")) + except Exception: + return None + return None + + +def read_config_any(checkpoint_dir: str) -> Dict[str, Any]: + cand = [ + Path(checkpoint_dir) / "config.json", + Path(checkpoint_dir) / "model_config.json", + Path(checkpoint_dir) / "config.yaml", + ] + for p in cand: + if p.exists(): + if p.suffix == ".json": + j = load_json_if_exists(p) + if j is not None: + return j + else: + return {"_config_file": str(p)} + return {} + + +def sanitize_name(s: str) -> str: + return "".join(c if c.isalnum() or c in "-_." else "_" for c in s) + + +def ensure_mono_and_resample( + wav: torch.Tensor, sr: int, target_sr: int +) -> Tuple[torch.Tensor, int]: + if wav.ndim != 2: + raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}") + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + if sr != target_sr: + wav = ta_resample(wav, sr, target_sr) + sr = target_sr + return wav.to(torch.float32), sr + + +def save_wav(path: Path, wav: torch.Tensor, sr: int): + path.parent.mkdir(parents=True, exist_ok=True) + if wav.ndim == 1: + wav = wav.unsqueeze(0) + wav = wav.clamp(-1, 1).contiguous().cpu() + torchaudio.save( + str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16 + ) + + +def read_audio_manifest(txt_path: str) -> List[Path]: + lines = Path(txt_path).read_text(encoding="utf-8").splitlines() + return [ + Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#") + ] + + +def html_escape(s: str) -> str: + return ( + s.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) + + +def make_html( + output_dir: Path, + audio_files: List[Path], + specs: List[StackSpec], + sr_by_model: Dict[str, int], + wavvae_cfg: Dict[str, Dict[str, Any]], + zflow_cfg: Dict[str, Dict[str, Any]], +) -> str: + def player(src_rel: str) -> str: + return f'' + + cards = [] + for s in specs: + wcfg = wavvae_cfg.get(s.name, {}) + zcfg = zflow_cfg.get(s.name, {}) + w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[ + :1200 + ] + z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[ + :1200 + ] + card = f""" +
+        <div class="model-card">
+          <h3>{html_escape(s.name)}</h3>
+          <p>Sample rate: {sr_by_model.get(s.name, "N/A")} Hz</p>
+          <p>Decode: steps={s.decode.num_steps}, cfg={s.decode.cfg}</p>
+          <details>
+            <summary>WavVAE config</summary>
+            <pre>{html_escape(w_short)}</pre>
+          </details>
+          <details>
+            <summary>ZFlowAE config</summary>
+            <pre>{html_escape(z_short)}</pre>
+          </details>
+        </div>
+ """ + cards.append(card) + + css = """ + body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; } + .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; } + .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; } + table { border-collapse: collapse; width: 100%; } + th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; } + th { background: #fafafa; position: sticky; top: 0; } + audio { width: 260px; } + """ + + th = "InputOriginal" + "".join( + f"{html_escape(s.name)}" for s in specs + ) + rows = [] + for af in audio_files: + base = af.stem + orig_rel = f"original/{html_escape(af.name)}" + tds = [f"{html_escape(base)}", f"{player(orig_rel)}"] + for s in specs: + rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav" + tds.append(f"{player(rec_rel)}") + rows.append("" + "".join(tds) + "") + + html = f""" + + Stacked Codec Comparison + +

WavVAE + ZFlowAE Comparison

+
{"".join(cards)}
+ + {th} + {"".join(rows)} +
+ + +""" + out = output_dir / "index.html" + out.write_text(html, encoding="utf-8") + return str(out) + + +# ------------------------- +# Core +# ------------------------- + + +@torch.inference_mode() +def reconstruct_stack( + wav_mono: torch.Tensor, + wavvae: WavVAE, + zflow: ZFlowAutoEncoder, + steps: int, + cfg: float, + device: str, +) -> torch.Tensor: + x = wav_mono.to(device) # (1,T) + z = wavvae.encode(x) # high-framerate latents + y, _ = zflow.encode(z) # compressed latents + z_hat = zflow.decode(y, num_steps=steps, cfg=cfg) + wav_hat = wavvae.decode(z_hat) # (1,1,T) + return wav_hat.squeeze(0).squeeze(0).detach() + + +def parse_models_manifest(path: str) -> List[StackSpec]: + """ + JSON list of: + { + "name": "...", + "wavvae": "/path/to/WavVAE_dir", + "zflowae": "/path/to/ZFlowAE_dir", + "decode": {"num_steps": 10, "cfg": 2.0} + } + """ + raw = json.loads(Path(path).read_text(encoding="utf-8")) + specs = [] + for it in raw: + d = it.get("decode", {}) + specs.append( + StackSpec( + name=it["name"], + wavvae_dir=it["wavvae"], + zflowae_dir=it["zflowae"], + decode=DecodeParams( + num_steps=int(d.get("num_steps", 10)), cfg=float(d.get("cfg", 2.0)) + ), + ) + ) + return specs + + +def main(): + ap = argparse.ArgumentParser( + description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page." + ) + ap.add_argument("--models", required=True, help="JSON manifest of stacks.") + ap.add_argument( + "--audio_manifest", required=True, help="TXT file: one audio path per line." + ) + ap.add_argument("--out", default="compare_stack_out") + ap.add_argument("--device", default="cuda") + ap.add_argument("--force", action="store_true") + args = ap.parse_args() + + device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu" + out_dir = Path(args.out) + (out_dir / "original").mkdir(parents=True, exist_ok=True) + recon_dir = out_dir / "recon" + recon_dir.mkdir(parents=True, exist_ok=True) + + specs = parse_models_manifest(args.models) + if not specs: + print("No models.", file=sys.stderr) + sys.exit(1) + + # load models + wavvae_by_name: Dict[str, WavVAE] = {} + zflow_by_name: Dict[str, ZFlowAutoEncoder] = {} + sr_by_model: Dict[str, int] = {} + wavvae_cfg: Dict[str, Dict[str, Any]] = {} + zflow_cfg: Dict[str, Dict[str, Any]] = {} + for s in specs: + print(f"[Load] {s.name}") + w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device) + z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device) + wavvae_by_name[s.name] = w + zflow_by_name[s.name] = z + sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000)) + wavvae_cfg[s.name] = read_config_any(s.wavvae_dir) + zflow_cfg[s.name] = read_config_any(s.zflowae_dir) + + audio_paths = read_audio_manifest(args.audio_manifest) + + actual_audio = [] + for ap in audio_paths: + if not ap.exists(): + print(f"[Skip missing] {ap}", file=sys.stderr) + continue + wav, sr = ta_load(str(ap)) + wav_mono, sr = ensure_mono_and_resample(wav, sr, sr) + out_orig = out_dir / "original" / (ap.stem + ".wav") + if args.force or not out_orig.exists(): + save_wav(out_orig, wav_mono, sr) + actual_audio.append(out_orig) + + for out_orig in actual_audio: + wav0, sr0 = ta_load(str(out_orig)) + if wav0.size(0) > 1: + wav0 = wav0.mean(dim=0, keepdim=True) + for s in specs: + target_sr = sr_by_model[s.name] + wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0 + out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav" + if args.force or not out_path.exists(): + print(f"[Reconstruct] {s.name} ← 
{out_orig.name}") + wav_hat = reconstruct_stack( + wav_in, + wavvae_by_name[s.name], + zflow_by_name[s.name], + s.decode.num_steps, + s.decode.cfg, + device, + ) + save_wav(out_path, wav_hat, target_sr) + + html_path = make_html( + out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg + ) + print(f"Done. Open: {html_path}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/compute_stats.py b/codec/scripts/compute_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..2de69b5ced2435ac0a3d596de617f6ea8f145b82 --- /dev/null +++ b/codec/scripts/compute_stats.py @@ -0,0 +1,76 @@ +import argparse +import random + +import torch +from safetensors.torch import safe_open, save_file +from tqdm import tqdm + + +def load_tensor(path: str, key: str = "embedding") -> torch.Tensor: + with safe_open(path, framework="pt", device="cpu") as f: + return f.get_tensor(key) + + +def compute_global_stats(file_list, key="embedding", length_weighted=True): + sum_all = None + sum_sq_all = None + count_all = 0 + + for path in tqdm(file_list, desc="Computing stats"): + tensor = load_tensor(path, key) # shape: [B, T, D] + flat = tensor.reshape(-1, tensor.shape[-1]) # [B*T, D] + + sum_ = flat.sum(dim=0) # [D] + sum_sq = (flat**2).sum(dim=0) # [D] + count = flat.shape[0] # B*T + + if sum_all is None: + sum_all = sum_ + sum_sq_all = sum_sq + else: + sum_all += sum_ + sum_sq_all += sum_sq + + count_all += count + + mean = sum_all / count_all + var = sum_sq_all / count_all - mean**2 + std = torch.sqrt(torch.clamp(var, min=1e-8)) + + return mean, std + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "filelist", type=str, help="Text file with list of safetensors paths" + ) + parser.add_argument("output", type=str, help="Path to output stats.safetensors") + parser.add_argument( + "--key", type=str, default="audio_z", help="Key of tensor in safetensors file" + ) + parser.add_argument( + "--max-files", type=int, default=None, help="Max number of files to process" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for shuffling" + ) + + args = parser.parse_args() + + with open(args.filelist) as f: + files = [line.strip() for line in f if line.strip()] + + if args.max_files: + random.seed(args.seed) + files = random.sample(files, k=min(args.max_files, len(files))) + + mean, std = compute_global_stats(files, key=args.key) + + save_file({"mean": mean, "std": std}, args.output) + print(f"✅ Saved to {args.output}") + print("Example mean/std:", mean[:5], std[:5]) + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/compute_wer.py b/codec/scripts/compute_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1f8398d08892c17e461450505979649393bdd8 --- /dev/null +++ b/codec/scripts/compute_wer.py @@ -0,0 +1,48 @@ +import argparse +import json +import string + +from jiwer import wer + + +def normalize_text(text: str) -> str: + """ + Lowercase and remove punctuation from a string. 
+ + Args: + text (str): Input string + + Returns: + str: Normalized string + """ + # Lowercase + text = text.lower() + # Remove punctuation + text = text.translate(str.maketrans("", "", string.punctuation)) + return text + + +def load_transcripts(jsonl_path): + originals = [] + reconstructions = [] + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + originals.append(data["original_text"]) + reconstructions.append(data["reconstructed_text"]) + return originals, reconstructions + + +def main(args): + originals, reconstructions = map(normalize_text, load_transcripts(args.jsonl)) + score = wer(originals, reconstructions) + print(f"WER: {score:.3%}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--jsonl", type=str, required=True, help="Path to the transcript JSONL file" + ) + args = parser.parse_args() + main(args) diff --git a/codec/scripts/compute_wer_from_refs.py b/codec/scripts/compute_wer_from_refs.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4304a0134a1ce81ab32a7376332f7c260c3efb --- /dev/null +++ b/codec/scripts/compute_wer_from_refs.py @@ -0,0 +1,64 @@ +import argparse +import json +import string +from pathlib import Path + +from jiwer import cer, wer + + +def normalize_text(text: str) -> str: + """ + Lowercase and remove punctuation from a string. + + Args: + text (str): Input string + + Returns: + str: Normalized string + """ + # Lowercase + text = text.lower() + # Remove punctuation + text = text.translate(str.maketrans("", "", string.punctuation)) + return text + + +def load_jsonl_dict(path): + transcripts = {} + with open(path, "r", encoding="utf-8") as f: + for line in f: + data = json.loads(line) + transcripts[Path(data["file"]).name] = data["transcript"] + return transcripts + + +def main(args): + ref_dict = load_jsonl_dict(args.reference) + hyp_dict = load_jsonl_dict(args.hypothesis) + + common_files = set(ref_dict.keys()) & set(hyp_dict.keys()) + + if not common_files: + print("No common files between reference and hypothesis.") + return + + refs = [normalize_text(ref_dict[f]) for f in sorted(common_files)] + hyps = [normalize_text(hyp_dict[f]) for f in sorted(common_files)] + + cer_score = cer(refs, hyps) + wer_score = wer(refs, hyps) + print(f"CER: {cer_score:.3%}") + print(f"WER: {wer_score:.3%}") + print(f"Evaluated on {len(common_files)} files.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--reference", type=str, required=True, help="Path to reference JSONL" + ) + parser.add_argument( + "--hypothesis", type=str, required=True, help="Path to hypothesis JSONL" + ) + args = parser.parse_args() + main(args) diff --git a/codec/scripts/download_expresso.py b/codec/scripts/download_expresso.py new file mode 100644 index 0000000000000000000000000000000000000000..08c3683e7732f3b62629ffaced648966ae28a28d --- /dev/null +++ b/codec/scripts/download_expresso.py @@ -0,0 +1,10 @@ +import soundfile as sf + +from datasets import load_dataset + +dataset = load_dataset("ylacombe/expresso", split="train") +print(dataset) +for i, x in enumerate(dataset): + audio = x["audio"] + wav, sr = audio["array"], audio["sampling_rate"] + sf.write(f"expresso/org/{i}.wav", wav, sr) diff --git a/codec/scripts/download_gigaspeech.py b/codec/scripts/download_gigaspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3e7dbef9107f03e401687a5ad3cfb532146f57 --- /dev/null +++ b/codec/scripts/download_gigaspeech.py 
@@ -0,0 +1,14 @@ +from random import sample + +import soundfile as sf +from datasets import load_dataset + +# dataset = load_dataset("keithito/lj_speech", split="train") +#dataset = load_dataset("parler-tts/mls_eng", split="train") +dataset = load_dataset("speechcolab/gigaspeech", "xl", split="train", token=True) +Is = sample(list(range(len(dataset))), k=100000) +print(dataset) +for i, I in enumerate(Is): + audio = dataset[I]["audio"] + wav, sr = audio["array"], audio["sampling_rate"] + sf.write(f"gigaspeech/{I}.wav", wav, sr) diff --git a/codec/scripts/download_lj.py b/codec/scripts/download_lj.py new file mode 100644 index 0000000000000000000000000000000000000000..b702a9764e78297294b60f610cf3e2606569bc64 --- /dev/null +++ b/codec/scripts/download_lj.py @@ -0,0 +1,9 @@ +import soundfile as sf +from datasets import load_dataset + +dataset = load_dataset("keithito/lj_speech", split="train") +print(dataset) +for i, x in enumerate(dataset): + audio = x["audio"] + wav, sr = audio["array"], audio["sampling_rate"] + sf.write(f"ljspeech/{i}.wav", wav, sr) diff --git a/codec/scripts/download_ltts.py b/codec/scripts/download_ltts.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad7268a5eceb4554e095607f6831346d3fba3fd --- /dev/null +++ b/codec/scripts/download_ltts.py @@ -0,0 +1,16 @@ +from pathlib import Path + +import soundfile as sf + +from datasets import load_dataset + +dataset = load_dataset("mythicinfinity/libritts", "clean") +for split in dataset.keys(): + Path(f"libritts/{split}").mkdir(exist_ok=True) + for i, x in enumerate(dataset[split]): + # audio = x["audio"] + text = x["text_normalized"] + # wav, sr = audio["array"], audio["sampling_rate"] + # sf.write(f"libritts/{split}/{i}.wav", wav, sr) + with open(f"libritts/{split}/{i}.txt", "w") as f: + f.write(text) diff --git a/codec/scripts/download_mlseng10k.py b/codec/scripts/download_mlseng10k.py new file mode 100644 index 0000000000000000000000000000000000000000..3b25809a14587330fef9b6cba7d833fab77e3e52 --- /dev/null +++ b/codec/scripts/download_mlseng10k.py @@ -0,0 +1,13 @@ +from random import sample + +import soundfile as sf +from datasets import load_dataset + +# dataset = load_dataset("keithito/lj_speech", split="train") +dataset = load_dataset("parler-tts/mls_eng", split="train") +Is = sample(list(range(len(dataset))), k=100000) +print(dataset) +for i, I in enumerate(Is): + audio = dataset[I]["audio"] + wav, sr = audio["array"], audio["sampling_rate"] + sf.write(f"mls10keng/{i}.wav", wav, sr) diff --git a/codec/scripts/eval_asr.py b/codec/scripts/eval_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc131579e9580e0cf93c45949cf86f8cfd339c8 --- /dev/null +++ b/codec/scripts/eval_asr.py @@ -0,0 +1,100 @@ +import argparse +import json +from pathlib import Path + +import nemo.collections.asr as nemo_asr +import torch +import yaml +from jiwer import wer +from torchaudio import load +from torchaudio.functional import resample +from tqdm import tqdm + +from zcodec.models import WavVAE, ZFlowAutoEncoder + + +def load_config(config_path): + with open(config_path, "r") as f: + return yaml.safe_load(f) + + +def transcribe(audio: torch.Tensor, asr_model) -> str: + audio = audio.cpu().numpy(force=True) + with torch.inference_mode(): + return asr_model.transcribe([audio[0]])[0].text + + +def main(args): + config = load_config(args.config) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load models + wavvae = 
WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval() + zflowae = ( + ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval() + ) + + # Load ASR model + asr_model = nemo_asr.models.ASRModel.from_pretrained( + model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2") + ) + + # Read file list + with open(config["file_list"], "r") as f: + wav_files = [line.strip() for line in f if line.strip()] + + results = [] + + for wav_path in tqdm(wav_files, desc="Processing files"): + wav, sr = load(wav_path) + wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device) + + with torch.inference_mode(): + # Transcribe original + original_text = transcribe(wav, asr_model) + + # Compress and decompress + z = wavvae.encode(wav) + zz, _ = zflowae.encode(z) + z_hat = zflowae.decode( + zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0) + ) + wav_hat = wavvae.decode(z_hat) + + # Transcribe reconstructed + reconstructed_text = transcribe(wav_hat, asr_model) + + results.append( + { + "file": wav_path, + "original_text": original_text, + "reconstructed_text": reconstructed_text, + } + ) + + # Save output + out_path = Path(config.get("output_jsonl", "transcripts.jsonl")) + with out_path.open("w") as f: + for entry in results: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + print(f"\nSaved {len(results)} transcript pairs to {out_path}") + + # Optionally compute WER + if args.compute_wer: + original_texts = [r["original_text"] for r in results] + reconstructed_texts = [r["reconstructed_text"] for r in results] + score = wer(original_texts, reconstructed_texts) + print(f"WER: {score:.3%}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to YAML config") + parser.add_argument( + "--compute_wer", action="store_true", help="Compute WER after decoding" + ) + args = parser.parse_args() + + main(args) diff --git a/codec/scripts/eval_asr_from_filelist.py b/codec/scripts/eval_asr_from_filelist.py new file mode 100644 index 0000000000000000000000000000000000000000..8764aa43ba1878c63ab9e370181556c66fed27d3 --- /dev/null +++ b/codec/scripts/eval_asr_from_filelist.py @@ -0,0 +1,60 @@ +import argparse +import json +from pathlib import Path + +import nemo.collections.asr as nemo_asr +import torch +import yaml +from torchaudio import load +from torchaudio.functional import resample +from tqdm import tqdm + + +def load_config(config_path): + with open(config_path, "r") as f: + return yaml.safe_load(f) + + +def transcribe(audio: torch.Tensor, asr_model) -> str: + audio = audio.cpu().numpy(force=True) + with torch.inference_mode(): + return asr_model.transcribe([audio[0]])[0].text + + +def main(args): + config = load_config(args.config) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load ASR model + asr_model = nemo_asr.models.ASRModel.from_pretrained( + model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2") + ) + + # Read file list + with open(config["file_list"], "r") as f: + wav_files = [line.strip() for line in f if line.strip()] + + results = [] + + for wav_path in tqdm(wav_files, desc="Transcribing"): + wav, sr = load(wav_path) + wav = resample(wav, orig_freq=sr, new_freq=16000).to(device) + + transcript = transcribe(wav, asr_model) + results.append({"file": wav_path, "transcript": transcript}) + + # Save output + out_path = Path(config.get("output_jsonl", "asr_transcripts.jsonl")) + with 
out_path.open("w") as f: + for entry in results: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + print(f"\nSaved {len(results)} transcripts to {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to YAML config") + args = parser.parse_args() + main(args) diff --git a/codec/scripts/eval_asr_only.py b/codec/scripts/eval_asr_only.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3d9e15b37c0ac0604ab5d461d0f1b560c07d75 --- /dev/null +++ b/codec/scripts/eval_asr_only.py @@ -0,0 +1,84 @@ +import argparse +import json +from pathlib import Path + +import nemo.collections.asr as nemo_asr +import torch +import yaml +from torchaudio import load +from torchaudio.functional import resample +from tqdm import tqdm + +from zcodec.models import WavVAE, ZFlowAutoEncoder + + +def load_config(config_path): + with open(config_path, "r") as f: + return yaml.safe_load(f) + + +def transcribe(audio: torch.Tensor, asr_model) -> str: + audio = audio.cpu().numpy(force=True) + with torch.inference_mode(): + return asr_model.transcribe([audio[0]])[0].text + + +def main(args): + config = load_config(args.config) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load models + wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval() + zflowae = ( + ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval() + ) + + # Load ASR model + asr_model = nemo_asr.models.ASRModel.from_pretrained( + model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2") + ) + + # Read file list + with open(config["file_list"], "r") as f: + wav_files = [line.strip() for line in f if line.strip()] + + results = [] + + for wav_path in tqdm(wav_files, desc="ASR on reconstructed audio"): + wav, sr = load(wav_path) + wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device) + + with torch.inference_mode(): + # Compress and decompress + z = wavvae.encode(wav) + zz, _ = zflowae.encode(z) + z_hat = zflowae.decode( + zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0) + ) + wav_hat = wavvae.decode(z_hat) + + # Transcribe + wav_hat = resample(wav_hat, orig_freq=wavvae.sampling_rate, new_freq=16000) + reconstructed_text = transcribe(wav_hat, asr_model) + + results.append( + { + "file": wav_path, + "transcript": reconstructed_text, + } + ) + + # Save output + out_path = Path(config.get("output_jsonl", "asr_reconstructed.jsonl")) + with out_path.open("w") as f: + for entry in results: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + print(f"\nSaved {len(results)} reconstructed ASR results to {out_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to YAML config") + args = parser.parse_args() + main(args) diff --git a/codec/scripts/export_wavvae.py b/codec/scripts/export_wavvae.py new file mode 100644 index 0000000000000000000000000000000000000000..62454ebea83500efee45fa89f69a4ac511935163 --- /dev/null +++ b/codec/scripts/export_wavvae.py @@ -0,0 +1,53 @@ +import argparse +import hashlib +from pathlib import Path + +from zcodec.trainers import TrainWavVAE + + +def hash_checkpoint_file(path, method="sha256", length=8): + h = hashlib.new(method) + with open(path, "rb") as f: + while chunk := f.read(8192): + h.update(chunk) + return h.hexdigest()[:length] + + +def main(): + parser = 
argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.") + parser.add_argument( + "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)" + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Optional output directory (default is based on config)", + ) + args = parser.parse_args() + + # Load Lightning module + wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint) + config = wavvae.hparams.config + + # Compute framerate and z_dim + frame_rate = wavvae.wavvae.frame_rate + z_dim = config.latent_dim + + checkpoint_hash = hash_checkpoint_file(args.checkpoint) + # Determine output directory + if args.output is None: + out_dir = Path( + f"checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}" + ) + else: + out_dir = args.output + out_dir.mkdir(parents=True, exist_ok=True) + + # Save weights and config + wavvae.save_model_weights_and_config(str(out_dir)) + print(f"✅ Exported model to {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/export_zflowae.py b/codec/scripts/export_zflowae.py new file mode 100644 index 0000000000000000000000000000000000000000..55f858e265e2e6164170c3651d932cb8d1cc17e0 --- /dev/null +++ b/codec/scripts/export_zflowae.py @@ -0,0 +1,63 @@ +import argparse +import hashlib +from pathlib import Path + +from zcodec.trainers import TrainZFlowAutoEncoder + + +def hash_checkpoint_file(path, method="sha256", length=8): + h = hashlib.new(method) + with open(path, "rb") as f: + while chunk := f.read(8192): + h.update(chunk) + return h.hexdigest()[:length] + + +def main(): + parser = argparse.ArgumentParser( + description="Export ZFlowAutoEncoder pretrained checkpoint." + ) + parser.add_argument( + "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)" + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Optional output directory (default is based on config)", + ) + args = parser.parse_args() + + zflowae = TrainZFlowAutoEncoder.load_from_checkpoint(args.checkpoint) + checkpoint_hash = hash_checkpoint_file(args.checkpoint) + config = zflowae.hparams.config + repa = False + if hasattr(zflowae.hparams, "ssl_repa_head"): + if zflowae.hparams.ssl_repa_head: + repa = True + stride = config.latent_stride + zdim = config.latent_dim + ae_factory = config.autoencoder_factory + if config.fsq_levels is not None: + codebook_size = 1 + for c in config.fsq_levels: + codebook_size *= c + type = f"fsq{codebook_size}" + elif config.vae: + type = f"vae{config.bottleneck_size}" + + if args.output is None: + out_dir = Path( + f"checkpoints/zflowae/pretrained/ZFlowAutoEncoder_stride{stride}_zdim{zdim}_{type}_{ae_factory}_repa{repa}_{checkpoint_hash}" + ) + else: + out_dir = args.output + out_dir.mkdir(parents=True, exist_ok=True) + + # Save weights and config + zflowae.save_model_weights_and_config(str(out_dir)) + print(f"✅ Exported model to {out_dir}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_hubert.py b/codec/scripts/infer_hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..35c89b3b863fb0c0d24e14d2455c3cb74826136b --- /dev/null +++ b/codec/scripts/infer_hubert.py @@ -0,0 +1,101 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from transformers import HubertModel + + +class AudioDataset(Dataset): + def 
__init__(self, file_list, target_sr=16000): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).float().unsqueeze(0) # shape: [1, T] + return wav, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir, keep_layers): + wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + outputs = model(wav, output_hidden_states=True) + hidden_states = outputs.hidden_states # tuple of 13 tensors: [1, T', D] + selected = { + f"layer_{i}": hs.squeeze(0).cpu() + for i, hs in enumerate(hidden_states) + if i in keep_layers + } + out_path = out_dir / (path.stem + ".st") + save_file(selected, str(out_path)) + + +def parse_layers(layer_str, max_layers): + if layer_str.strip().lower() == "all": + return set(range(max_layers)) + return set(int(idx) for idx in layer_str.split(",") if idx.strip().isdigit()) + + +def main(): + parser = argparse.ArgumentParser(description="Infer HuBERT hidden states.") + parser.add_argument( + "file_list", type=Path, help="Text file with paths to audio files" + ) + parser.add_argument("output_dir", type=Path, help="Directory to save .st files") + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--num_workers", type=int, default=2) + parser.add_argument( + "--layers", + type=str, + default="all", + help="Comma-separated layer indices or 'all'", + ) + + args = parser.parse_args() + device = torch.device(args.device) + + model = HubertModel.from_pretrained("facebook/hubert-base-ls960").to(device).eval() + num_layers = ( + len(model.config.hidden_layers) + if hasattr(model.config, "hidden_layers") + else 13 + ) + keep_layers = parse_layers(args.layers, num_layers) + + with open(args.file_list, "r") as f: + paths = [Path(line.strip()) for line in f if line.strip()] + + dataset = AudioDataset(paths) + dataloader = DataLoader( + dataset, + batch_size=1, + num_workers=args.num_workers, + collate_fn=lambda x: list(zip(*x)), + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + for batch in tqdm(dataloader): + try: + encode_batch(model, batch, device, args.output_dir, keep_layers) + except Exception as e: + print(f"❌ Failed on batch: {e}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_wav2vec.py b/codec/scripts/infer_wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..e973efe546c16958b9d3c69360f613ff66b57f6a --- /dev/null +++ b/codec/scripts/infer_wav2vec.py @@ -0,0 +1,100 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from transformers import Wav2Vec2Model + + +class AudioDataset(Dataset): + def __init__(self, file_list, target_sr=16000): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).float().unsqueeze(0) # shape: [1, T] + return wav, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir, keep_layers): + 
wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + outputs = model(wav, output_hidden_states=True) + hidden_states = outputs.hidden_states # tuple of 25 tensors: [1, T', D] + selected = { + f"layer_{i}": hs.squeeze(0).cpu() + for i, hs in enumerate(hidden_states) + if i in keep_layers + } + out_path = out_dir / (path.stem + ".st") + save_file(selected, str(out_path)) + + +def parse_layers(layer_str): + if layer_str.strip().lower() == "all": + return set(range(25)) + return set(int(idx) for idx in layer_str.split(",") if idx.strip().isdigit()) + + +def main(): + parser = argparse.ArgumentParser(description="Infer Wav2Vec2 hidden states.") + parser.add_argument( + "file_list", type=Path, help="Text file with paths to audio files" + ) + parser.add_argument("output_dir", type=Path, help="Directory to save .st files") + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--num_workers", type=int, default=2) + parser.add_argument( + "--layers", + type=str, + default="all", + help="Comma-separated layer indices or 'all'", + ) + + args = parser.parse_args() + keep_layers = parse_layers(args.layers) + device = torch.device(args.device) + + model = ( + Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53") + .to(device) + .eval() + ) + + with open(args.file_list, "r") as f: + paths = [Path(line.strip()) for line in f if line.strip()] + + dataset = AudioDataset(paths) + dataloader = DataLoader( + dataset, + batch_size=1, + num_workers=args.num_workers, + collate_fn=lambda x: list(zip(*x)), + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + for batch in tqdm(dataloader): + try: + encode_batch(model, batch, device, args.output_dir, keep_layers) + except Exception as e: + print(f"❌ Failed on batch: {e}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_wavvae.py b/codec/scripts/infer_wavvae.py new file mode 100644 index 0000000000000000000000000000000000000000..6de6a280e15aaa1549a074cf137369e789ea9348 --- /dev/null +++ b/codec/scripts/infer_wavvae.py @@ -0,0 +1,93 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from zcodec.models import WavVAE + + +class AudioDataset(Dataset): + def __init__(self, file_list, target_sr): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T] + return wav, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir): + wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + latent = model.encode(wav).cpu() + out_path = out_dir / (path.stem + ".st") + save_file({"audio_z": latent}, str(out_path)) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch encode audio files with WavVAE." 
+ ) + parser.add_argument( + "file_list", + type=Path, + help="Text file listing paths to audio files (one per line)", + ) + parser.add_argument( + "checkpoint", type=Path, help="Path to WavVAE checkpoint directory" + ) + parser.add_argument( + "output_dir", type=Path, help="Directory to save Safetensors latents" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument( + "--num_workers", type=int, default=3, help="Number of DataLoader workers" + ) + + args = parser.parse_args() + device = torch.device(args.device) + + # Load model + wavvae = WavVAE.from_pretrained_local(args.checkpoint) + wavvae = wavvae.to(device).eval() + target_sr = wavvae.sampling_rate + + # Prepare dataset and dataloader + with open(args.file_list, "r") as f: + file_paths = [Path(line.strip()) for line in f if line.strip()] + dataset = AudioDataset(file_paths, target_sr) + dataloader = DataLoader( + dataset, + batch_size=1, + num_workers=args.num_workers, + collate_fn=lambda x: list(zip(*x)), + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Inference loop + for batch in tqdm(dataloader): + try: + encode_batch(wavvae, batch, device, args.output_dir) + except Exception as e: + print(f"❌ Batch failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_wavvae_audiocite.py b/codec/scripts/infer_wavvae_audiocite.py new file mode 100644 index 0000000000000000000000000000000000000000..a91f901a149e5b595301019f74702e201c2bdfce --- /dev/null +++ b/codec/scripts/infer_wavvae_audiocite.py @@ -0,0 +1,73 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch + +from datasets import load_dataset, load_from_disk +from zcodec.models import WavVAE + + +def load_and_resample(path, target_sr): + wav, sr = sf.read(str(path)) + if sr != target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr) + wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T] + return wav + + +def main(): + parser = argparse.ArgumentParser( + description="Encode HF dataset audio with WavVAE using map() (non-batched)." 
+ ) + parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset") + parser.add_argument("path_column", type=str, help="Column name with wav file paths") + parser.add_argument( + "checkpoint", type=Path, help="Path to WavVAE checkpoint directory" + ) + parser.add_argument( + "--split", type=str, default=None, help="Dataset split (if loading from hub)" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument( + "--num_proc", type=int, default=1, help="Number of processes for map()" + ) + args = parser.parse_args() + + device = torch.device(args.device) + + # Load model + wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval() + target_sr = wavvae.sampling_rate + + # Load dataset + if Path(args.dataset).exists(): + ds = load_from_disk(args.dataset) + else: + ds = load_dataset(args.dataset, split=args.split or "train") + + ds = ds.filter(lambda x: x > 1.0, input_columns="duration") + + # Mapping function (non-batched) + @torch.no_grad() + def encode_example(example): + wav = load_and_resample(example[args.path_column], target_sr).to(device) + latent = wavvae.encode(wav).cpu().numpy() + example["audio_z"] = latent + return example + + # Apply map without batching + ds = ds.map( + encode_example, + num_proc=args.num_proc, + ) + + # Save dataset with new column + ds.save_to_disk(str(Path(args.dataset) + "_with_latents")) + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_wavvae_hf.py b/codec/scripts/infer_wavvae_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..58415077977d545201c6d42d27e6244ccf12bd9d --- /dev/null +++ b/codec/scripts/infer_wavvae_hf.py @@ -0,0 +1,146 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +import torchaudio +from torchaudio.functional import resample +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from datasets import Audio, load_dataset, load_from_disk +from zcodec.models import WavVAE + + +class AudioDataset(Dataset): + def __init__(self, file_list, target_sr): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T] + return wav, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir): + wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + latent = model.encode(wav).cpu() + out_path = out_dir / (path.stem + ".st") + save_file({"audio_z": latent}, str(out_path)) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch encode audio files with WavVAE." 
+ ) + parser.add_argument( + "input_dataset", + type=Path, + help="Text file listing paths to audio files (one per line)", + ) + parser.add_argument( + "checkpoint", type=Path, help="Path to WavVAE checkpoint directory" + ) + parser.add_argument( + "output_dataset", type=Path, help="Directory to save Safetensors latents" + ) + parser.add_argument("--in_column", type=str, default="audio") + parser.add_argument("--out_column", type=str, default="audio_z") + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--split", type=str, default=None) + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of DataLoader workers" + ) + parser.add_argument( + "--num_shards", type=int, default=1, help="Number of DataLoader workers" + ) + parser.add_argument( + "--shard_index", type=int, default=0, help="Number of DataLoader workers" + ) + parser.add_argument("--from_file", action="store_true") + parser.add_argument("--from_files", action="store_true") + parser.add_argument("--no_resample", action="store_false") + parser.add_argument("--from_hub", action="store_true") + parser.add_argument("--file_prefix", type=str, default=None) + args = parser.parse_args() + device = torch.device(args.device) + + # Load model + wavvae = WavVAE.from_pretrained_local(args.checkpoint) + wavvae = wavvae.to(device).eval() + target_sr = wavvae.sampling_rate + # Prepare dataset and dataloader + if args.from_hub: + dataset = load_dataset(str(args.input_dataset), args.split) + else: + dataset = load_from_disk(str(args.input_dataset), args.split) + # + + if args.num_shards > 1: + dataset = dataset.shard(num_shards=args.num_shards, index=args.shard_index) + if args.from_file: + def map_fn(audio_file_path): + if args.file_prefix is not None: + audio_file_path = args.file_prefix + "/" + audio_file_path + wav, sr = torchaudio.load(audio_file_path) + wav = resample(wav, sr, target_sr) + wav = wav.mean(dim=0, keepdim=True) + if not args.no_resample: + wav = resample(wav, sr, target_sr) + with torch.inference_mode(): + latent = wavvae.encode(wav.to(device)) + return {"audio_z": latent} + + dataset = dataset.map(map_fn, input_columns=args.in_column) + elif args.from_files: + def map_fn(audio_file_paths): + if args.file_prefix is not None: + audio_file_paths = [args.file_prefix + "/" + x for x in audio_file_paths] + wav, sr = torchaudio.load(audio_file_paths[0]) + wavs = [wav.mean(dim=0, keepdim=True)] + wavs = wavs + [torchaudio.load(x)[0].mean(dim=0, keepdim=True) for x in audio_file_paths[1:]] + wav = torch.cat(wavs, dim=1) + if not args.no_resample: + wav = resample(wav, sr, target_sr) + with torch.inference_mode(): + latent = wavvae.encode(wav.to(device)) + return {"audio_z": latent} + + dataset = dataset.map(map_fn, input_columns=args.in_column) + + else: + dataset = dataset.cast_column(args.in_column, Audio(sampling_rate=target_sr)) + dataset = dataset.with_format( + "torch", + columns=args.in_column, + ) + + def map_fn(audio): + with torch.inference_mode(): + wav = audio["array"].unsqueeze(0).to(device) + latent = wavvae.encode(wav) + return {"audio_z": latent} + + dataset = dataset.map( + map_fn, input_columns=args.in_column, remove_columns=args.in_column + ) + + dataset.save_to_disk(args.output_dataset) + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_zflowae.py b/codec/scripts/infer_zflowae.py new file mode 100644 index 0000000000000000000000000000000000000000..16fddaebea1f848976f12522ceac0526ce014cf5 --- 
/dev/null +++ b/codec/scripts/infer_zflowae.py @@ -0,0 +1,121 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from zcodec.models import ZFlowAutoEncoder + + +class AudioDataset(Dataset): + def __init__(self, file_list, target_sr): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T] + return wav, path + + +class SafeTensorDataset(Dataset): + """ + On __getitem__, opens the safetensor, uses get_slice() to inspect shape, + then either drops too-short files (return None) or returns a random subsequence slice. + """ + + def __init__( + self, + file_paths: list[str], + key: str = "audio_z", + ): + self.file_paths = file_paths + self.key = key + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx: int) -> torch.Tensor | None: + path = self.file_paths[idx] + # open file, get a slice wrapper for full tensor + with safe_open(path, framework="pt") as f: + tensor = f.get_tensor(self.key) + + return tensor, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir, save_latent=False): + wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + latent, indices = model.encode(wav) + if save_latent: + to_save = latent.cpu() + else: + to_save = indices.cpu() + out_path = out_dir / (path.stem + ".st") + save_file({"audio_z": to_save}, str(out_path)) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch encode audio files with ZFlowAutoEncoder." + ) + parser.add_argument( + "file_list", + type=Path, + help="Text file listing paths to audio files (one per line)", + ) + parser.add_argument( + "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory" + ) + parser.add_argument( + "output_dir", type=Path, help="Directory to save Safetensors latents" + ) + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + + parser.add_argument("--num_workers", type=int, default=4, help="Num. 
workers") + args = parser.parse_args() + device = torch.device(args.device) + + # Load model + zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint) + zflowae = zflowae.to(device).eval() + + # Prepare dataset and dataloader + with open(args.file_list, "r") as f: + file_paths = [Path(line.strip()) for line in f if line.strip()] + dataset = SafeTensorDataset(file_paths) + dataloader = DataLoader( + dataset, + batch_size=1, + num_workers=args.num_workers, + collate_fn=lambda x: list(zip(*x)), + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + + # Inference loop + for batch in tqdm(dataloader): + try: + encode_batch(zflowae, batch, device, args.output_dir) + except Exception as e: + print(f"❌ Batch failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/codec/scripts/infer_zflowae_hf.py b/codec/scripts/infer_zflowae_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..911f203422320b39336576006797b4f382ba8c8e --- /dev/null +++ b/codec/scripts/infer_zflowae_hf.py @@ -0,0 +1,121 @@ +import argparse +from pathlib import Path + +import librosa +import soundfile as sf +import torch +import torchaudio +from safetensors.torch import save_file +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm + +from datasets import Audio, load_dataset, load_from_disk +from zcodec.models import WavVAE, ZFlowAutoEncoder + + +class AudioDataset(Dataset): + def __init__(self, file_list, target_sr): + self.paths = file_list + self.target_sr = target_sr + + def __len__(self): + return len(self.paths) + + def __getitem__(self, idx): + path = self.paths[idx] + wav, sr = sf.read(str(path)) + if sr != self.target_sr: + wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr) + wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T] + return wav, path + + +@torch.no_grad() +def encode_batch(model, batch, device, out_dir, save_latent=True): + wavs, paths = batch + for wav, path in zip(wavs, paths): + wav = wav.to(device) + latent, indices = model.encode(wav) + if save_latent: + to_save = latent.cpu() + else: + save_latent = indices.cpu() + out_path = out_dir / (path.stem + ".st") + save_file({"audio_z": to_save}, str(out_path)) + + +def main(): + parser = argparse.ArgumentParser( + description="Batch encode audio files with WavVAE." 
+ ) + parser.add_argument( + "input_dataset", + type=Path, + help="Text file listing paths to audio files (one per line)", + ) + parser.add_argument( + "checkpoint", type=Path, help="Path to zflowae checkpoint directory" + ) + parser.add_argument( + "output_dataset", type=Path, help="Directory to save Safetensors latents" + ) + parser.add_argument("--in_column", type=str, default="audio_z") + parser.add_argument("--out_column", type=str, default="audio_latent") + parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--split", type=str, default="all") + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of DataLoader workers" + ) + parser.add_argument("--from_file", action="store_true") + parser.add_argument("--from_hub", action="store_true") + parser.add_argument("--file_prefix", type=str, default=None) + args = parser.parse_args() + device = torch.device(args.device) + + # Load model + zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint) + zflowae = zflowae.to(device).eval() + + # Prepare dataset and dataloader + if args.from_hub: + dataset = load_dataset(str(args.input_dataset), args.split) + else: + dataset = load_from_disk(str(args.input_dataset), args.split) + # + + if args.from_file: + raise NotImplemented + + def map_fn(audio_file_path): + if args.file_prefix is not None: + audio_file_path = args.file_prefix+"/"+audio_file_path + wav, sr = torchaudio.load(audio_file_path) + wav = wav.mean(dim=0, keepdim=True) + with torch.inference_mode(): + latent = zflowae.encode(wav.to(device)) + return {"audio_z": latent} + + dataset = dataset.map(map_fn, input_columns=args.in_column) + else: + dataset = dataset.with_format( + "torch", + columns=args.in_column, + ) + + def map_fn(audio): + with torch.inference_mode(): + audio_z = audio.to(device) + latent, _ = zflowae.encode(audio_z) + return {args.out_column: latent} + + dataset = dataset.map( + map_fn, input_columns=args.in_column, remove_columns=args.in_column + ) + + dataset.save_to_disk(args.output_dataset) + + +if __name__ == "__main__": + main() diff --git a/codec/train_patchvae.py b/codec/train_patchvae.py new file mode 100755 index 0000000000000000000000000000000000000000..40bfbf22495f6a30cbd9a4d967e385587b8ca723 --- /dev/null +++ b/codec/train_patchvae.py @@ -0,0 +1,232 @@ +import json +import math +from dataclasses import asdict +from pathlib import Path +from typing import Literal + +import pytorch_lightning as pl +import torch +from safetensors.torch import safe_open, save_file +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torchaudio.transforms import Resample +from transformers import WavLMModel + +from .models import PatchVAE, PatchVAEConfig, WavVAE +from .models.components.convnext import SwiGLU +from .models.patchvae.modules import convnext_factory + + +def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr): + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr + + return lr_lambda + + +class TrainPatchVAE(pl.LightningModule): + def __init__( + self, + config: PatchVAEConfig, + lr: float = 1e-4, + end_lr: float | None = None, + weight_decay: float = 0.01, + cfg_drop_rate: float = 0.0, + cfg_drop_rate_per_sample: float = 0.1, + total_steps: 
int = 2_000_000, + warmup_steps: int = 0, + mean_std_from_file: str | None = None, + train_sigma: float = 1e-4, + kl_div_factor: float = 1e-4, + wavvae_pretrained_path: str | None = None, + drop_vae_rate: float = 0.0, + ssl_repa_factor: float = 1.0, + ssl_repa_head: bool = False, + ssl_repa_head_type: Literal["convnext", "swiglu"] = "swiglu", + ssl_repa_head_target: Literal["flow", "prior"] = "prior", + ssl_repa_head_num_layers: int = 4, + ssl_repa_source: Literal["wavlm-base-plus"] = "wavlm-base-plus", + ssl_feat_dim: int = 768, + ): + super().__init__() + self.save_hyperparameters() + + if mean_std_from_file is not None: + with safe_open(mean_std_from_file, framework="pt") as f: + config.latent_scaling = ( + f.get_tensor("mean").tolist(), + f.get_tensor("std").tolist(), + ) + + self.patchvae = PatchVAE(config) + self.apply(self._init_weights) + + self.wavvae = None + if wavvae_pretrained_path is not None: + self.wavvae = WavVAE.from_pretrained_local(wavvae_pretrained_path).eval() + + self.ssl_model = None + if ssl_repa_head: + if ssl_repa_source == "wavlm-base-plus": + self.ssl_model = WavLMModel.from_pretrained( + "microsoft/wavlm-base-plus" + ).eval() + self.ssl_resampling = Resample( + orig_freq=self.wavvae.sampling_rate, new_freq=16000 + ) + + for p in self.ssl_model.parameters(): + p.requires_grad = False + + if ssl_repa_head_type == "convnext": + self.ssl_repa_head = nn.Sequential( + convnext_factory( + config.hidden_dim, + ssl_repa_head_num_layers, + ), + nn.Linear(config.hidden_dim, ssl_feat_dim), + ) + elif ssl_repa_head_type == "swiglu": + self.ssl_repa_head = nn.Sequential( + *[ + SwiGLU(config.hidden_dim) + for _ in range(ssl_repa_head_num_layers) + ], + nn.Linear(config.hidden_dim, ssl_feat_dim), + ) + + else: + self.ssl_repa_head = None + + def _init_weights(self, m: nn.Module): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def configure_optimizers(self): + params = [{"params": self.patchvae.parameters()}] + if self.ssl_repa_head is not None: + params += [{"params": self.ssl_repa_head.parameters()}] + opt = torch.optim.AdamW( + params, + lr=self.hparams.lr, + weight_decay=self.hparams.weight_decay, + ) + + if self.hparams.end_lr is not None: + scheduler = LambdaLR( + opt, + lr_lambda=cosine_schedule_with_warmup( + warmup_steps=self.hparams.warmup_steps, + total_steps=self.hparams.total_steps, + start_lr=self.hparams.lr, + end_lr=self.hparams.end_lr, + ), + ) + scheduler = {"scheduler": scheduler, "interval": "step"} + return [opt], [scheduler] + return opt + + def save_model_weights_and_config( + self, + dir: str | None, + model_filename: str = "model.st", + config_filename: str = "config.json", + ): + cfg = self.hparams.config + model_path = Path(dir) / model_filename + save_file(self.patchvae.state_dict(), model_path.with_suffix(".st")) + with open(Path(dir) / config_filename, "w") as f: + json.dump(asdict(cfg), f, indent=2) + + def training_step(self, batch: dict[str, torch.Tensor], batch_idx): + z = batch["audio_z"] + t = torch.rand(z.shape[0], device=z.device) + drop_cond_rate = self.hparams.cfg_drop_rate_per_sample + drop_vae_rate = self.hparams.drop_vae_rate + flow_loss, ae_loss, prior = self.patchvae( + z, + t, + sigma=self.hparams.train_sigma, + drop_vae_rate=drop_vae_rate, + drop_cond_rate=drop_cond_rate, + ) + + self.log("train_flow_loss", flow_loss, prog_bar=True) + + total_loss = flow_loss + if ae_loss.get("kl_div") is not None: + kl_div = ae_loss.get("kl_div") + 
self.log("train_kl_div", kl_div, prog_bar=True) + total_loss += self.hparams.kl_div_factor * kl_div + + for x in ["_mu_mean", "_mu_std", "_logvar_mean", "_logvar_std"]: + stat = ae_loss.get(x) + if stat is not None: + self.log(x, stat, prog_bar=False) + + if self.ssl_repa_head is not None: + target = self.hparams.ssl_repa_head_target + if target == "prior": + target = prior + elif target == "flow": + raise NotImplementedError + with torch.inference_mode(): + wav = self.wavvae.decode(z) + wav = self.ssl_resampling(wav) + wav = torch.nn.functional.pad(wav, (40, 40)) + ssl_feats = self.ssl_model(wav, output_hidden_states=True).hidden_states + ssl_feat = ssl_feats[10] + ssl_feat = torch.nn.functional.avg_pool1d( + ssl_feat.transpose(-1, -2), + kernel_size=8, + stride=4, + padding=2, + ).transpose(-1, -2) + + ssl_feat = ssl_feat.clone() + + B, N, D = ssl_feat.shape + + repa_pred = self.ssl_repa_head(target) + ssl_repa_loss = nn.functional.cosine_embedding_loss( + repa_pred.reshape(-1, D), + ssl_feat.reshape(-1, D), + torch.ones(1).to(repa_pred), + ) + total_loss += self.hparams.ssl_repa_factor * ssl_repa_loss + self.log("train_repa_loss", ssl_repa_loss, prog_bar=True) + + return total_loss + + def validation_step(self, batch: dict[str, torch.Tensor], batch_idx): + z = batch["audio_z"] + t = ( + torch.ones(z.shape[0], device=z.device) + * batch_idx + / self.trainer.num_val_batches[0] + ) + flow_loss, ae_loss, prior = self.patchvae( + z, t, sigma=self.hparams.train_sigma, drop_cond=False + ) + self.log("val_flow_loss", flow_loss, prog_bar=True) + + total_loss = flow_loss + if ae_loss.get("kl_div") is not None: + kl_div = ae_loss.get("kl_div") + self.log("val_kl_div", kl_div, prog_bar=True) + total_loss += self.hparams.kl_div_factor * kl_div + return total_loss + +if __name__ == "__main__": + from pytorch_lightning.cli import LightningCLI + + LightningCLI( + TrainPatchVAE, + save_config_callback=None, + parser_kwargs={"parser_mode": "omegaconf"}, + ) diff --git a/codec/train_wavvae.py b/codec/train_wavvae.py new file mode 100755 index 0000000000000000000000000000000000000000..0392d3b846ef8760f306363a701d093a28e818e2 --- /dev/null +++ b/codec/train_wavvae.py @@ -0,0 +1,304 @@ +import json +import math +from dataclasses import asdict +from pathlib import Path +from typing import Optional + +import pytorch_lightning as pl +import torch +import transformers +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.cli import LightningCLI +from pytorch_lightning.loggers.wandb import WandbLogger +from safetensors.torch import save_file +from torch.nn.utils import clip_grad_norm_ + +from codec.models import WavVAE, WavVAEConfig +from codec.models.wavvae.discriminators import (MultiPeriodDiscriminator, + MultiResolutionDiscriminator) +from codec.models.wavvae.loss import (DiscriminatorLoss, FeatureMatchingLoss, + GeneratorLoss, + MelSpecReconstructionLoss) + + +class TrainWavVAE(pl.LightningModule): + def __init__( + self, + config: WavVAEConfig, + sample_rate: int, + initial_learning_rate: float, + num_warmup_steps: int = 0, + mel_loss_coeff: float = 45, + mrd_loss_coeff: float = 1.0, + kl_div_coeff: float = 1e-5, + pretrain_mel_steps: int = 0, + decay_mel_coeff: bool = False, + clip_grad_norm: float | None = None, + f_min: int = 0, + f_max: Optional[int] = None, + mrd_fft_sizes: tuple[int, int, int] = (2048, 1024, 512), + mel_hop_length: int = 256, + log_audio_every_n_epoch: int = 5, + log_n_audio_batches: int = 32, + ): + super().__init__() + + self.save_hyperparameters() + 
self.wavvae = WavVAE(config) + self.multiperioddisc = MultiPeriodDiscriminator() + self.multiresddisc = MultiResolutionDiscriminator( + fft_sizes=tuple(mrd_fft_sizes) + ) + + self.disc_loss = DiscriminatorLoss() + self.gen_loss = GeneratorLoss() + self.feat_matching_loss = FeatureMatchingLoss() + self.melspec_loss = MelSpecReconstructionLoss( + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, + hop_length=mel_hop_length, + ) + + self.train_discriminator = False + self.automatic_optimization = False + self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff + + def save_model_weights_and_config( + self, + dir: str | None, + model_filename: str = "model.st", + config_filename: str = "config.json", + ): + cfg = self.hparams.config + model_path = Path(dir) / model_filename + save_file(self.wavvae.state_dict(), model_path) + with open(Path(dir) / config_filename, "w") as f: + json.dump(asdict(cfg), f, indent=2) + + def configure_optimizers(self): + disc_params = [ + {"params": self.multiperioddisc.parameters()}, + {"params": self.multiresddisc.parameters()}, + ] + gen_params = [ + {"params": self.wavvae.parameters()}, + ] + + opt_disc = torch.optim.AdamW( + disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9) + ) + opt_gen = torch.optim.AdamW( + gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9) + ) + + max_steps = self.trainer.max_steps // 2 + scheduler_disc = transformers.get_cosine_schedule_with_warmup( + opt_disc, + num_warmup_steps=self.hparams.num_warmup_steps, + num_training_steps=max_steps, + ) + scheduler_gen = transformers.get_cosine_schedule_with_warmup( + opt_gen, + num_warmup_steps=self.hparams.num_warmup_steps, + num_training_steps=max_steps, + ) + + return ( + [opt_disc, opt_gen], + [ + {"scheduler": scheduler_disc, "interval": "step"}, + {"scheduler": scheduler_gen, "interval": "step"}, + ], + ) + + def forward(self, audio_input, **kwargs): + audio_output, kl_div = self.wavvae(audio_input) + + return audio_output, kl_div + + def training_step(self, batch, batch_idx, **kwargs): + audio_input = batch + + opt_disc, opt_gen = self.optimizers() + + if self.train_discriminator: + with torch.no_grad(): + audio_hat, kl_div = self(audio_input, **kwargs) + + real_score_mp, gen_score_mp, _, _ = self.multiperioddisc( + y=audio_input, + y_hat=audio_hat, + **kwargs, + ) + real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc( + y=audio_input, + y_hat=audio_hat, + **kwargs, + ) + loss_mp, loss_mp_real, _ = self.disc_loss( + disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp + ) + loss_mrd, loss_mrd_real, _ = self.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mp /= len(loss_mp_real) + loss_mrd /= len(loss_mrd_real) + loss_disc = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + + self.log("discriminator/total", loss_disc, prog_bar=True) + self.log("discriminator/multi_period_loss", loss_mp) + self.log("discriminator/multi_res_loss", loss_mrd) + + opt_disc.zero_grad() + self.manual_backward(loss_disc) + if self.hparams.clip_grad_norm is not None: + max_norm = self.hparams.clip_grad_norm + clip_grad_norm_(self.multiperioddisc.parameters(), max_norm=max_norm) + clip_grad_norm_(self.multiresddisc.parameters(), max_norm=max_norm) + opt_disc.step() + + audio_hat, kl_div = self(audio_input, **kwargs) + if self.train_discriminator: + _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc( + y=audio_input, + y_hat=audio_hat, + **kwargs, + ) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd 
= self.multiresddisc( + y=audio_input, + y_hat=audio_hat, + **kwargs, + ) + loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp) + loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd) + loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + loss_fm_mp = self.feat_matching_loss( + fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp + ) / len(fmap_rs_mp) + loss_fm_mrd = self.feat_matching_loss( + fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd + ) / len(fmap_rs_mrd) + + self.log("generator/multi_period_loss", loss_gen_mp) + self.log("generator/multi_res_loss", loss_gen_mrd) + self.log("generator/feature_matching_mp", loss_fm_mp) + self.log("generator/feature_matching_mrd", loss_fm_mrd) + + self.log("generator/kl_div", kl_div) + + mel_loss = self.melspec_loss(audio_hat, audio_input) + loss = ( + loss_gen_mp + + self.hparams.mrd_loss_coeff * loss_gen_mrd + + loss_fm_mp + + self.hparams.mrd_loss_coeff * loss_fm_mrd + + self.mel_loss_coeff * mel_loss + + self.hparams.kl_div_coeff * kl_div + ) + + self.log("generator/total_loss", loss, prog_bar=True) + self.log("mel_loss_coeff", self.mel_loss_coeff) + self.log("generator/mel_loss", mel_loss) + + opt_gen.zero_grad() + self.manual_backward(loss) + if self.hparams.clip_grad_norm is not None: + max_norm = self.hparams.clip_grad_norm + clip_grad_norm_(self.wavvae.parameters(), max_norm=max_norm) + opt_gen.step() + + def validation_step(self, batch, batch_idx, **kwargs): + audio_input = batch + audio_hat, _ = self(audio_input, **kwargs) + + if self.current_epoch % self.hparams.log_audio_every_n_epoch == 0: + wavs = [x.numpy(force=True) for x in audio_hat.unbind(0)] + if batch_idx == 0: + self._audios_to_log = wavs + if batch_idx < self.hparams.log_n_audio_batches: + self._audios_to_log += wavs + elif batch_idx == self.hparams.log_n_audio_batches: + self.logger.log_audio( + "audio", + self._audios_to_log, + step=self.global_step, + sample_rate=[ + self.wavvae.sampling_rate + for _ in range(len(self._audios_to_log)) + ], + ) + + mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1)) + total_loss = mel_loss + + return { + "val_loss": total_loss, + "mel_loss": mel_loss, + "audio_input": audio_input[0], + "audio_pred": audio_hat[0], + } + + @property + def global_step(self): + """ + Override global_step so that it returns the total number of batches processed + """ + return self.trainer.fit_loop.epoch_loop.total_batch_idx + + def on_train_batch_start(self, *args): + if self.global_step >= self.hparams.pretrain_mel_steps: + self.train_discriminator = True + else: + self.train_discriminator = False + + def on_train_batch_end(self, *args): + def mel_loss_coeff_decay(current_step, num_cycles=0.5): + max_steps = self.trainer.max_steps // 2 + if current_step < self.hparams.num_warmup_steps: + return 1.0 + progress = float(current_step - self.hparams.num_warmup_steps) / float( + max(1, max_steps - self.hparams.num_warmup_steps) + ) + return max( + 0.0, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + if self.hparams.decay_mel_coeff: + self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay( + self.global_step + 1 + ) + +if __name__ == "__main__": + class WavVAECLI(LightningCLI): + def after_instantiate_classes(self): + hparams = self.model.hparams + kl_factor = "{:.1e}".format(hparams.kl_div_coeff) + latent_dim = hparams.config["latent_dim"] + frame_rate = self.model.wavvae.frame_rate + dataset_name = ( + 
Path(self.datamodule.train_config.filelist_path).with_suffix("").name + ) + + name = f"WavVAE_kl{kl_factor}_framerate{frame_rate}hz_latentdim{latent_dim}_dataset{dataset_name}" + + if self.trainer.logger: + logger = WandbLogger( + log_model=False, + project="codec", + name=name, + ) + model_checkpoint_cb = ModelCheckpoint( + monitor="generator/mel_loss", + dirpath="checkpoints/wavvae", + filename=name + "_epoch{epoch:02d}", + save_last=True, + ) + self.trainer.callbacks.append(model_checkpoint_cb) + + WavVAECLI( + save_config_kwargs={"overwrite": True}, + parser_kwargs={"parser_mode": "omegaconf"}, + ) diff --git a/pardi_speech.py b/pardi_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..66488ee926f4990844b6cbfde7bf34942cff4673 --- /dev/null +++ b/pardi_speech.py @@ -0,0 +1,208 @@ +import json +import string +from dataclasses import asdict, dataclass +from pathlib import Path + +import torch +from safetensors.torch import load_file +from torch.nn.utils.rnn import pad_sequence + +from codec.models import PatchVAE +from tts.model.cache_utils import FLACache +from tts.text_processor import BasicTextProcessor +from tts.tools import sequence_mask +from tts.tts import ARTTSModel + + +@dataclass +class VelocityHeadSamplingParams: + """ + Velocity head sampling parameters + + Attributes: + + cfg (float): CFG factor against unconditional prediction. + cfg_ref (float): CFG factor against a reference (to be used with a cache of size 2*batch_size and unfold). + temperature (float): scale factor of z0 ~ 𝒩(0,1) + num_steps (int): number of ODE steps + solver (str): parameter passed to NeuralODE + sensitivity (str): parameter passed to NeuralODE + """ + cfg: float = 1.3 + cfg_ref: float = 1.5 + temperature: float = 0.9 + num_steps: int = 13 + solver: str = "euler" + sensitivity: str = "adjoint" + + +@dataclass +class PatchVAESamplingParams: + """ + PatchVAE sampling parameters + + Attributes: + + cfg (float): CFG factor against unconditional prediction. 
+ temperature (float): scale factor of z0 ~ 𝒩(0,1) + num_steps (int): number of ODE steps + solver (str): parameter passed to NeuralODE + sensitivity (str): parameter passed to NeuralODE + """ + + cfg: float = 2.0 + temperature: float = 1.0 + num_steps: int = 10 + solver: str = "euler" + sensitivity: str = "adjoint" + + +class PardiSpeech: + tts: ARTTSModel + patchvae: PatchVAE + text_processor: BasicTextProcessor + + def __init__( + self, + tts: ARTTSModel, + patchvae: PatchVAE, + text_processor: BasicTextProcessor, + ): + self.tts = tts + self.patchvae = patchvae + self.text_processor = text_processor + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + map_location: str = "cpu", + ): + if Path(pretrained_model_name_or_path).exists(): + path = pretrained_model_name_or_path + else: + from huggingface_hub import snapshot_download + + path = snapshot_download(pretrained_model_name_or_path) + + with open(Path(path) / "config.json", "r") as f: + config = json.load(f) + + artts_model, artts_config = ARTTSModel.instantiate_from_config(config) + state_dict = load_file( + Path(path) / "model.st", + device=map_location, + ) + artts_model.load_state_dict(state_dict, assign=True) + patchvae = PatchVAE.from_pretrained( + artts_config.patchvae_path, + map_location=map_location, + ) + text_processor = BasicTextProcessor( + str(Path(path) / "pretrained_tokenizer.json") + ) + + return cls(artts_model, patchvae, text_processor) + + def encode_reference(self, wav: torch.Tensor, sr: int): + import torchaudio + + new_freq = self.patchvae.wavvae.sampling_rate + wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq) + return self.patchvae.encode(wav) + + @property + def sampling_rate(self): + return self.patchvae.wavvae.sampling_rate + + def text_to_speech( + self, + text: str, + prefix: tuple[str, torch.Tensor] | None = None, + patchvae_sampling_params: PatchVAESamplingParams | None = None, + velocity_head_sampling_params: VelocityHeadSamplingParams | None = None, + prefix_separator: str = ". ", + max_seq_len: int = 600, + stop_threshold: float = 0.5, + cache: FLACache | None = None, + **kwargs, + ): + """ + Parameters + ---------- + text: str + The text to synthesize. + + prefix: tuple[str, torch.Tensor] | None + A pair (text, speech) consisting of a reference speech excerpt encoded (see encode_reference) and its corresponding text transcription. Synthesis is performed by continuing the prefix. If no prefix is given, the first frame is randomly sampled. + + patchvae_sampling_params: PatchVAESamplingParams + PatchVAE sampling parameters + + velocity_head_sampling_params: VelocityHeadSamplingParams + VelocityHead sampling parameters (AR sampling) + + prefix_separator: str + The separator that joins the prefix text to the target text. + + max_seq_len: int + The maximum number of latent to generate. + + stop_threshold: float + Threshold value at which AR prediction stops. 
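+
+        Example
+        -------
+        A minimal usage sketch (the checkpoint path below is a placeholder and
+        output waveform shapes depend on the PatchVAE configuration)::
+
+            model = PardiSpeech.from_pretrained("path/to/checkpoint")
+            wavs, latents = model.text_to_speech(
+                "Hello world.",
+                patchvae_sampling_params=PatchVAESamplingParams(cfg=2.0),
+            )
+            # wavs is a list of decoded waveforms at model.sampling_rate,
+            # latents the generated latent frame sequences.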
+ """ + + device = next(self.tts.parameters()).device + + if type(text) is str: + text = [text] + + if prefix is not None: + prefix_text, prefix_speech = prefix + prefix_text = prefix_text.strip().rstrip(string.punctuation) + if prefix_text != "": + text = [prefix_text + prefix_separator + t for t in text] + + prefix_speech = prefix_speech.repeat(len(text), 1, 1) + else: + _, audio_latent_sz = self.tts.audio_embd.weight.shape + prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device) + + # if self.bos: + # text = "[BOS]" + text + + # if self.eos: + # text = text + "[EOS]" + text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text] + text_pre_mask = sequence_mask(torch.tensor([x.shape[0] for x in text_ids])).to(device) + text_mask = text_pre_mask[:, None] * text_pre_mask[..., None] + crossatt_mask = text_pre_mask[:, None,None] + + text_ids = pad_sequence(text_ids, batch_first=True) + + if velocity_head_sampling_params is None: + velocity_head_sampling_params = VelocityHeadSamplingParams() + + if patchvae_sampling_params is None: + patchvae_sampling_params = PatchVAESamplingParams() + + + with torch.inference_mode(): + _, predictions = self.tts.generate( + text_ids.to(device), + text_mask=text_mask, + crossatt_mask=crossatt_mask, + prefix=prefix_speech.to(device), + max_seq_len=max_seq_len, + sampling_params=asdict(velocity_head_sampling_params), + stop_threshold=stop_threshold, + cache=cache, + device=device, + **kwargs, + ) + wavs = [self.patchvae.decode( + p, + **asdict(patchvae_sampling_params), + ) for p in predictions] + + return wavs, predictions diff --git a/tts/__pycache__/groupdataset.cpython-312.pyc b/tts/__pycache__/groupdataset.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2c40da29cddbd1329fc33fa0051854b863577ea Binary files /dev/null and b/tts/__pycache__/groupdataset.cpython-312.pyc differ diff --git a/tts/__pycache__/text_processor.cpython-312.pyc b/tts/__pycache__/text_processor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15abe599201f2508fe8648fbbaf78ed67b80d501 Binary files /dev/null and b/tts/__pycache__/text_processor.cpython-312.pyc differ diff --git a/tts/__pycache__/tools.cpython-312.pyc b/tts/__pycache__/tools.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..953486f643e9b0b17bc7711ea8cdf9d8d5b7ea16 Binary files /dev/null and b/tts/__pycache__/tools.cpython-312.pyc differ diff --git a/tts/__pycache__/train_tts.cpython-312.pyc b/tts/__pycache__/train_tts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1200571bbdc82d9a0246dbaa6dbadb1178ce38e7 Binary files /dev/null and b/tts/__pycache__/train_tts.cpython-312.pyc differ diff --git a/tts/__pycache__/tts.cpython-312.pyc b/tts/__pycache__/tts.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2a5d73b089ca43e468ed60c696d8ffca427fa4e Binary files /dev/null and b/tts/__pycache__/tts.cpython-312.pyc differ diff --git a/tts/groupdataset.py b/tts/groupdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c03108f382fb2f732ad473407e91db3e8bf646 --- /dev/null +++ b/tts/groupdataset.py @@ -0,0 +1,647 @@ +import math +import random +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from functools import reduce +from itertools import accumulate +from random import choices +from typing import List, Optional, Sequence, Tuple + +import pytorch_lightning as ptl 
+import torch +from datasets import (DatasetDict, concatenate_datasets, load_dataset, + load_from_disk) +from einops import rearrange +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import (BatchSampler, DataLoader, Sampler, + SubsetRandomSampler) +from transformers import PreTrainedTokenizerFast + +from tts.tools import (audio_to_text_partial_neighbor_mask, packmask_2d, + pad_2d_sequence, sequence_mask) + + +class BucketSampler(Sampler[List[int]]): + def __init__( + self, + buckets: List[List[int]], + batch_sizes: List[int] | int, + bucket_sampling_weights: List[Tuple[float]] = None, + drop_last: bool = True, + distributed: bool = True, # TODO - implement not distributed as well + sample_bucket: Optional[int] = None, + seed: int = 123, + epoch_seed: bool = True, + ): + if type(batch_sizes) is int: + batch_sizes = [batch_sizes] * len(buckets) + else: + assert len(buckets) == len(batch_sizes) + + if bucket_sampling_weights is not None: + assert len(bucket_sampling_weights) == len(batch_sizes) + self.bucket_sampling_weights = bucket_sampling_weights + self.num_replicas = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + self.buckets = [ + b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas] + for b in buckets + ] + self.num_samples = [len(b) // self.num_replicas for b in buckets] + self.batch_sizes = batch_sizes + self.total_sizes = [ + ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes) + ] + self.drop_last = drop_last + self.seed = seed + self.epoch = 0 + self.sample_bucket = sample_bucket + self.epoch_seed = epoch_seed + self.batch_size = batch_sizes[0] + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def __len__(self): + return sum(self.total_sizes) + + def __iter__(self): + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank) + pool = [ + BatchSampler( + SubsetRandomSampler(b, generator=generator), + bs, + drop_last=self.drop_last, + ) + for b, bs in zip(self.buckets, self.batch_sizes) + ] + pool = [iter(b) for b in pool] + weights = ( + [w for w in self.bucket_sampling_weights] + if self.bucket_sampling_weights is not None + else None + ) + while pool: # sample until all buckets are done + idx, bucket = choices(list(enumerate(pool)), weights=weights)[0] + try: + batch = next(bucket) + yield batch + except StopIteration: + pool.pop(idx) # if bucket is done, throw it + if weights is not None: + weights.pop(idx) + + +class DatasetFactory(ABC): + @abstractmethod + def build(self): + pass + + +class HiFiTTS2_AudioLatent(DatasetFactory): + def __init__( + self, + path: str | list[str] = "hifitts2_vae8_dataset", + duration_column: str = "audio_duration", + duration_path: str | None = None, + expresso_path: str | None = None, + min_dur: float = 3.0, + max_dur: float = 20.1, + framerate: float = 25.0, + ): + self.min_dur = min_dur + self.max_dur = max_dur + self.path = path + self.duration_column = duration_column + self.duration_path = duration_path + self.framerate = framerate + self.expresso_path = expresso_path + + def build(self): + if type(self.path) is str: + self.path = [self.path] + datasets = [load_from_disk(x) for x in self.path] + dataset = concatenate_datasets(datasets).with_format("torch") + if self.duration_path is not None: + duration_dataset = load_from_disk(self.duration_path) + dataset = concatenate_datasets( + [dataset, duration_dataset], axis=1 + ).with_format("torch") + dataset = dataset.filter( + lambda dur: dur > self.min_dur 
and dur < self.max_dur, + input_columns=self.duration_column, + ) + + dataset = dataset.rename_column(self.duration_column, "audio_duration") + # dataset = dataset.map( + # lambda x: {"audio_duration": x.shape[1] / self.framerate}, + # input_columns="audio_latent", + # ).filter( + # lambda dur: dur > self.min_dur and dur < self.max_dur, + # input_columns="audio_duration", + # ) + if self.expresso_path is not None: + expresso_dataset = load_from_disk(self.expresso_path).with_format("torch") + + dataset = dataset.sort("audio_duration") + return DatasetDict({"train": dataset}) + + +@dataclass +class SegmentsCollateArgs: + abs_style_intensity: bool = False + merge_endpoints: bool = True + block_crossatt_mask: bool = True + alternate_crossatt_pos: bool = False + block_crossatt_past_tokens: int = 0 + block_crossatt_future_tokens: int = 0 + eos: bool = True + bos: bool = True + + +@dataclass +class CollateArgs: + abs_style_intensity: bool = False + random_text_segment: bool = False + eos: bool = True + bos: bool = True + num_stop_tokens: int = 1 + +def random_log_breakpoints( + seq: Sequence, a: int, b: int, gap: bool = False +) -> List[int]: + """ + Generate random breakpoints in a sequence where the gap X between + successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)]. + Gaps are then rounded to the nearest integer in [a, b]. + + Parameters + ---------- + seq : Sequence + The input sequence in which to place breakpoints. + a : int + Minimum gap (>= 1). + b : int + Maximum gap (>= a). + + Returns + ------- + List[int] + Sorted list of breakpoint indices (0 < idx < len(seq)). + """ + if a < 1 or b < a: + raise ValueError("Require 1 <= a <= b") + n = len(seq) + breakpoints: List[int] = [] + pos = 0 + + while True: + # sample U ~ Uniform(log2(a), log2(b)) + u = random.uniform(math.log2(a), math.log2(b)) + # map back to X = 2^U, then round to nearest integer + x = 2**u + gap = int(math.floor(x + 0.5)) + + # enforce integer bounds exactly + gap = max(a, min(b, gap)) + + pos += gap + if pos >= n: + if gap: + breakpoints.append(n - sum(breakpoints)) + break + if gap: + breakpoints.append(gap) + else: + breakpoints.append(pos) + + return breakpoints + + + +def standalone_collate_latent( + batch, + tokenizer, + abs_style_intensity: bool = False, + random_text_segment: bool = False, + bos: bool = True, + eos: bool = True, + num_stop_tokens: int = 1, +): + audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch]) + audio_latent = [x.squeeze() for x in audio_latent] + text_pp = [] + for t in text: + if bos: + t = "[BOS]" + t + if eos: + t = t + "[EOS]" + text_pp.append(t) + text_token = [torch.LongTensor(tokenizer.encode(x)) for x in text_pp] + xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent)) + + stop_token = [] + text_stop_token = [] + for x, y in zip(xlen, ylen): + tst = torch.zeros(x) + st = torch.zeros(y) + st_idx = random.randint(1, num_stop_tokens) + st[-1] = st_idx + tst[-1] = st_idx + stop_token.append(st) + text_stop_token.append(tst) + stop_token = pad_sequence(stop_token, batch_first=True).long() + text_stop_token = pad_sequence(text_stop_token, batch_first=True).long() + + x_mask, y_mask = map( + lambda x: sequence_mask(x, device="cpu"), + (torch.tensor(xlen), torch.tensor(ylen)), + ) + + text_rel_pos = None + + if random_text_segment: + breakpoints = [random_log_breakpoints(t, 32, 256, gap=True) for t in text_token] + encoder_mask = pad_2d_sequence([packmask_2d(b, b) for b in breakpoints]) + text_rel_pos = 
[torch.cat([torch.arange(bb) for bb in b]) for b in breakpoints] + text_rel_pos = pad_sequence(text_rel_pos, batch_first=True) + else: + encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2) + crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2) + + audio_latent, text_token = map( + lambda x: pad_sequence(x, batch_first=True, padding_value=0.0), + (audio_latent, text_token), + ) + + if abs_style_intensity: + abs_style_intensity = [x["abs_style_intensity"] for x in batch] + abs_style_intensity = [ + torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity + ] + abs_style_intensity = torch.stack(abs_style_intensity) + else: + abs_style_intensity = None + + return { + "text_token": text_token, + "audio_token": audio_latent, + "crossatt_mask": crossatt_mask, + "encoder_mask": encoder_mask, + "y_mask": y_mask, + "stop_token": stop_token, + "text_stop_token": text_stop_token, + "x_len": xlen, + "y_len": ylen, + "abs_style_intensity": abs_style_intensity, + "text_rel_pos": text_rel_pos, + } + + +def standalone_collate_latent_segments( + batch, + tokenizer, + abs_style_intensity: bool = False, + merge_endpoints: bool = True, + block_crossatt_mask: bool = True, + block_crossatt_past_tokens: int = 0, + block_crossatt_future_tokens: int = 0, + alternate_crossatt_pos: bool = False, + alternate_crossatt_shift: int = 1000, + eos: bool = True, + bos: bool = True, +): + audio_latent, text, token_duration = zip( + *[(x["audio_latent"], x["text"], x["token_duration"]) for x in batch] + ) + text_pp = [] + for t in text: + if bos: + t = "[BOS]" + t + if eos: + t = t + "[EOS]" + text_pp.append(t) + + if merge_endpoints: + tokens = [tokenizer.encode(x) for x in text] + new_td = [] + for td in token_duration: + begin, end = td[0], td[-1] + tdd = td[1:-1] + tdd[0] += begin + tdd[-1] += end + new_td.append(tdd) + token_duration = new_td + else: + tokens = [tokenizer.encode(x) for x in text_pp] + segments = [ + random_segments_from_text_and_durations(t, td.tolist()) + for t, td in zip(tokens, token_duration) + ] + bos, eos = map(tokenizer.encode, ("[BOS]", "[EOS]")) + audio_segments = [] + text_segments = [] + audio_segments_len = [] + text_segments_len = [] + + for aud, seg in zip(audio_latent, segments): + tt, at, tt_l, at_l = [], [], [], [] + + for i, s in enumerate(seg): + ttoken = s["text_token"] + if bos: + ttoken = bos + ttoken + if eos: + ttoken = ttoken + eos + tt.append(ttoken) + a_s = aud[:, s["start"] : s["end"]] + at.append(a_s) + at_l.append(a_s.shape[1]) + tt_l.append(len(ttoken)) + + audio_segments.append(at) + text_segments.append(tt) + audio_segments_len.append(at_l) + text_segments_len.append(tt_l) + + text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments] + audio_latent = [torch.cat(a_ss, dim=1).squeeze(0) for a_ss in audio_segments] + + xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent)) + x_mask, y_mask = map( + lambda x: sequence_mask(x, device="cpu"), + (torch.tensor(xlen), torch.tensor(ylen)), + ) + + audio_latent, text_token = map( + lambda x: pad_sequence(x, batch_first=True, padding_value=0), + (audio_latent, text_token), + ) + + encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2) + if block_crossatt_mask: + crossatt_mask = [ + audio_to_text_partial_neighbor_mask( + x, + y, + past_tokens=block_crossatt_past_tokens, + future_tokens=block_crossatt_future_tokens, + ) + for x, y in zip(text_segments_len, audio_segments_len) + ] + crossatt_mask = pad_2d_sequence(crossatt_mask) + pad_mask = 
rearrange(torch.arange(max(ylen)), "n -> 1 n 1") >= rearrange( + torch.tensor(ylen), "n -> n 1 1" + ) + else: + crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2) + + text_rel_pos = pad_sequence( + [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len], + batch_first=True, + ) + + crossatt_rel_pos = None + if alternate_crossatt_pos: + crossatt_rel_pos = [] + for tsl in text_segments_len: + rel_pos = [] + random_shift = int(random.random() < 0.5) + for i, x in enumerate(tsl): + rel_pos.append( + torch.arange(x) + + ((random_shift + i) % 2) * alternate_crossatt_shift + ) + crossatt_rel_pos.append(torch.cat(rel_pos)) + crossatt_rel_pos = pad_sequence(crossatt_rel_pos, batch_first=True) + + audio_rel_pos = pad_sequence( + [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len], + batch_first=True, + ) + + stop_token = [] + for asl in audio_segments_len: + sts = [] + for x in asl: + st = torch.zeros(x) + st[-1] = 1 + sts.append(st) + stop_token.append(torch.cat(sts)) + stop_token = pad_sequence(stop_token, batch_first=True).int() + + text_stop_token = [] + for asl in text_segments_len: + sts = [] + for x in asl: + st = torch.zeros(x) + st[-1] = 1 + sts.append(st) + text_stop_token.append(torch.cat(sts)) + text_stop_token = pad_sequence(text_stop_token, batch_first=True).int() + + if abs_style_intensity: + abs_style_intensity = [x["abs_style_intensity"] for x in batch] + abs_style_intensity = [ + torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity + ] + abs_style_intensity = torch.stack(abs_style_intensity) + else: + abs_style_intensity = None + + return { + "text_token": text_token, + "audio_token": audio_latent, + "crossatt_mask": crossatt_mask, + "encoder_mask": encoder_mask, + "y_mask": y_mask, + "stop_token": stop_token, + "text_stop_token": text_stop_token, + "x_mask": x_mask, + "x_len": xlen, + "y_len": ylen, + "abs_style_intensity": abs_style_intensity, + "text_rel_pos": text_rel_pos, + "crossatt_rel_pos": crossatt_rel_pos, + "audio_rel_pos": audio_rel_pos, + "segments": segments, + } + + + +def random_segments_from_text_and_durations( + text, + dur, + low_bnd: int = 8, + up_bnd: int = 384, +): + b = random_log_breakpoints(text, low_bnd, up_bnd) + bounds = [0] + b + [len(text)] + segs, durs = [], [] + for a, b in zip(bounds[:-1], bounds[1:]): + segs.append(text[a:b]) + durs.append(sum(dur[a:b])) + bounds = [0] + list(accumulate(durs, int.__add__)) + segs_dicts = [] + for t, s, e in zip(segs, bounds[:-1], bounds[1:]): + segs_dicts.append( + { + "start": s, + "end": e, + "text_token": t, + } + ) + segs_dicts[-1]["end"] += 1 + return segs_dicts + + +class LinaDataModule(ptl.LightningDataModule): + def __init__( + self, + path: str | DatasetFactory, + quant_layer: list[int], + train_batch_size: int = 8, + token_by_batch: int | None = None, + n_buckets=5, + codec_rate_hz: int = 75, + num_workers: int = 8, + test_size: int = 2000, + val_batch_size: int = 8, + seed: int = 123, + train_test_seed: int = 123, + segments: bool = False, + segments_args: SegmentsCollateArgs = field( + default_factory=lambda: SegmentsCollateArgs() + ), + collate_args: CollateArgs = field(default_factory=lambda: CollateArgs()), + block_mask_segments: bool = False, + tokenizer_file=None, + trail_end_frame: int | None = None, + split="train", + add_columns: str | list[str] | None = None, + add_text_tokens: list[str] | None = None, + type: str = "latent", + ): + super().__init__() + + self.path = path + self.codec_rate_hz = codec_rate_hz + self.num_workers = 
num_workers + self.quant_layer = quant_layer + self.seed = seed + self.segments = segments + self.segments_args = segments_args + self.collate_args = collate_args + self.train_test_seed = train_test_seed + self.test_size = test_size + self.val_batch_size = val_batch_size + self.train_batch_size = train_batch_size + self.split = split + self.trail_end_frame = trail_end_frame + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) + if add_text_tokens: + self.tokenizer.add_tokens(add_text_tokens) + self.add_columns = add_columns + self.n_buckets = n_buckets + self.token_by_batch = token_by_batch + self.type = type + + def setup(self, stage): + if isinstance(self.path, DatasetFactory): + self.dataset = self.path.build() + else: + self.dataset = load_dataset(self.path) + + split = self.split + columns = [ + "audio_latent" if self.type == "latent" else "audio_token", + "text", + "audio_duration", + ] + if self.add_columns is not None: + if type(self.add_columns) is str: + self.add_columns = [self.add_columns] + columns += self.add_columns + + if self.segments: + columns += ["token_duration"] + self.collate_fn = lambda x: segments_collate(x, self.tokenizer) + else: + self.collate_fn = lambda x: standalone_collate( + x, self.tokenizer, abs_style_intensity="abs_style_intensity" in columns + ) + self.dataset = ( + self.dataset[split] + .train_test_split(test_size=self.test_size, seed=self.train_test_seed) + .select_columns(columns) + ) + + if self.type == "latent": + if self.segments: + self.collate_fn = lambda x: standalone_collate_latent_segments( + x, + self.tokenizer, + **self.segments_args, + ) + else: + self.collate_fn = lambda x: standalone_collate_latent( + x, + self.tokenizer, + **self.collate_args, + ) + + def get_buckets_by_quantile(duration, n_quantile, is_sorted=False): + if is_sorted: + size = len(duration) + bucket_size = size // n_quantile + buckets = [ + list(range(i, min(i + bucket_size, size))) + for i in range(0, size, bucket_size) + ] + + else: + idxdur = list(enumerate(duration)) + idxdur.sort(key=lambda x: x[1]) + idx, dur = zip(*idxdur) + bucket_size = len(idx) // n_quantile + buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)] + return buckets + + if self.token_by_batch is not None: + train_buckets = get_buckets_by_quantile( + self.dataset["train"]["audio_duration"], self.n_buckets + ) + max_audio_durations = [ + self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets + ] + + batch_sizes = [ + int(self.token_by_batch // (self.codec_rate_hz * ad)) + for ad in max_audio_durations + ] + self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes) + + def train_dataloader(self): + if self.token_by_batch is not None: + return DataLoader( + self.dataset["train"].with_format("torch"), + num_workers=self.num_workers, + collate_fn=self.collate_fn, + batch_sampler=self.train_batch_sampler, + ) + else: + return DataLoader( + self.dataset["train"].with_format("torch"), + num_workers=self.num_workers, + batch_size=self.train_batch_size, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.dataset["test"].with_format("torch"), + batch_size=self.val_batch_size, + num_workers=0, + collate_fn=self.collate_fn, + ) diff --git a/tts/layers/__init__.py b/tts/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tts/layers/__pycache__/__init__.cpython-312.pyc b/tts/layers/__pycache__/__init__.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..f6997e2139ae796bce914b3e9c64e2f8e5e1c853 Binary files /dev/null and b/tts/layers/__pycache__/__init__.cpython-312.pyc differ diff --git a/tts/layers/__pycache__/attention.cpython-312.pyc b/tts/layers/__pycache__/attention.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f67df057340e57c0a3b11bdd012c5d92b7d284e4 Binary files /dev/null and b/tts/layers/__pycache__/attention.cpython-312.pyc differ diff --git a/tts/layers/__pycache__/conv.cpython-312.pyc b/tts/layers/__pycache__/conv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9176fc8967ee994b3e71cb4cc16f505a4db21d37 Binary files /dev/null and b/tts/layers/__pycache__/conv.cpython-312.pyc differ diff --git a/tts/layers/__pycache__/ffn.cpython-312.pyc b/tts/layers/__pycache__/ffn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e122609782808acc10f04294cf8fd069fb9ef186 Binary files /dev/null and b/tts/layers/__pycache__/ffn.cpython-312.pyc differ diff --git a/tts/layers/attention.py b/tts/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ac758ba75c62bdf923d2450988ee39083ed2699a --- /dev/null +++ b/tts/layers/attention.py @@ -0,0 +1,477 @@ +import math +import os +import time +from typing import Literal + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from fla.models.utils import Cache +from torch import nn +from transformers.cache_utils import Cache + + +def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor: + B, H, Q, KV = mask.shape + device = mask.device + + q_idx = torch.arange(Q, device=device).unsqueeze(1) # (Q, 1) + k_idx = torch.arange(KV, device=device).unsqueeze(0) # (1, KV) + + lower_bound = q_idx - (window_size - 1) # (Q, 1), may be negative + allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound) # (Q, KV), dtype=torch.bool + + allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV) + + orig_dtype = mask.dtype + if mask.dtype != torch.bool: + mask_bool = mask.to(torch.bool) + else: + mask_bool = mask + + new_mask = mask_bool & allowed_4d + + if orig_dtype != torch.bool: + return new_mask.to(orig_dtype) + else: + return new_mask + + +def precompute_freqs_cis_( + t: torch.Tensor, + n_elem: int, + base: float = 10000, +) -> torch.Tensor: + freqs = 1.0 / ( + base + ** ( + torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float() + / n_elem + ) + ) + freqs = torch.outer(t, freqs) + cache = repeat(freqs, "... d -> ... (d 2)") + return cache + + +import torch +from einops import repeat + + +def precompute_freqs_cis( + t: torch.Tensor, # shape: (B, T) or (T,) + n_elem: int, + base: float = 10000, +) -> torch.Tensor: + """ + Batched version of precompute_freqs_cis. + + Args: + t: torch.Tensor, shape (B, T) or (T,) + Timesteps to compute frequencies for. + n_elem: int + Embedding dimension (must be even). + base: float + Base for frequency computation (default: 10000). + + Returns: + cache: torch.Tensor, shape (B, T, n_elem) if batched, + (T, n_elem) if unbatched. 
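+
+    Example (shape sketch only; values depend on `base`):
+        >>> t = torch.arange(16, dtype=torch.float32).unsqueeze(0)  # (B=1, T=16)
+        >>> precompute_freqs_cis(t, n_elem=64).shape
+        torch.Size([1, 16, 64])
+
+    Each of the n_elem // 2 base frequencies is repeated twice along the last
+    dimension, matching the interleaved pair layout expected by rotate_half /
+    apply_rotary_emb.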
+ """ + if t.dim() == 1: # unbatched + t = t.unsqueeze(0) # (1, T) + + B, T = t.shape + device = t.device + + # frequencies (half dimension, then expand back) + freqs = 1.0 / ( + base + ** (torch.arange(0, n_elem, 2, device=device)[: (n_elem // 2)].float() / n_elem) + ) # shape: (n_elem // 2,) + + # outer product for each batch + # (B, T, n_elem//2) + freqs = torch.einsum("bt,d->btd", t, freqs) + + # duplicate last dim to interleave sin/cos pairs + # (B, T, n_elem) + cache = repeat(freqs, "... d -> ... (d 2)") + + # if cache.shape[0] == 1: # if originally unbatched + # cache = cache.squeeze(0) # (T, n_elem) + + return cache + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin() + return out + + +def scaled_dot_product_attention(query, key, value, mask=None): + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_weight = query @ key.transpose(-2, -1) * scale_factor + if mask is not None: + attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max) + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value, attn_weight + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + is_causal: bool = False, + sliding_window: int | None = None, + ): + super().__init__() + self.qkv = nn.Linear(dim, 3 * dim) + assert dim % num_heads == 0 + self.heads = num_heads + self.is_causal = is_causal + self.layer_idx = layer_idx + self.output_proj = nn.Linear(dim, dim) + self.sliding_window = sliding_window + if self.sliding_window is not None: + self.is_causal = False + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + cache: Cache | None = None, + ): + B, T, D = x.shape + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map( + lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) + ) + + if freqs is not None: + q = apply_rotary_emb(q, freqs) + k = apply_rotary_emb(k, freqs) + if cache is not None: + cache.update(attn_state=(k, v), layer_idx=self.layer_idx, offset=T) + k, v = cache[self.layer_idx]["attn_state"] + + if self.sliding_window is not None: + mask = torch.ones(B, 1, T, T, device=x.device) + mask = apply_causal_sliding_window(mask, self.sliding_window) + + + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1 + ) + y = rearrange(y, "b h n d -> b n (h d)") + y = self.output_proj(y) + return y + + +class CrossAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int | None = None, + dropout: float = 0.1, + ): + super().__init__() + assert dim % num_heads == 0 + self.pre_norm_q = nn.LayerNorm(dim) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.layer_idx = layer_idx + self.heads = num_heads + self.dropout_att = dropout + + def _prepare_kv(self, text_hidden_states: torch.Tensor): + v = self.ln_v(self.v(text_hidden_states)) + k = self.ln_k(self.k(text_hidden_states)) + + def _query(self, x): + return self.q(self.pre_norm_q(q)) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor | None = None, + v: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + output_attention: bool = False, + cache: Cache | None = None, + **kwargs, + ): 
+ if v is None: + v = k + q = self.q(self.pre_norm_q(q)) + if cache is not None: + if cache[self.layer_idx] is not None: + ca_state = cache[self.layer_idx]["crossatt_state"] + if ca_state is not None: + k, v = ca_state + else: + v = self.v(v) + k = self.k(k) + cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx) + else: + v = self.v(v) + k = self.k(k) + q, k, v = map( + lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) + ) + + if mask is not None: + if mask.ndim == 3: + mask = mask[:, None] + + # if not self.training: + if not self.training: + x, att = scaled_dot_product_attention(q, k, v, mask=mask) + else: + x = nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=self.dropout_att + ) + att = None + x = rearrange(x, "b h n d -> b n (h d)") + if att is not None: + if cache is not None: + cache.update(crossatt_weights=att, layer_idx=self.layer_idx) + else: + self.att = att + return x + + +class ConvPos(nn.Module): + def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2): + super().__init__() + self.embed = nn.Embedding(max_seq_len * n_parallel_codebook, dim) + self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same") + self.max_seq_len = max_seq_len + self.n_parallel_codebook = n_parallel_codebook + + def forward(self, x, left_shift=0, random_shift=False): + # left_pad = 31 if left_shift > 0 else 0 + # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0),x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0) + if random_shift: + bias = torch.randint( + 0, + self.n_parallel_codebook, + (x.shape[0],), + device=x.device, + ) + x = x + bias * self.max_seq_len + y = self.embed(x) + y = rearrange(y, "b n c -> b c n") + y = self.dw_conv(y) + y = rearrange(y, "b c n -> b n c") # [:,left_pad:-31] + return y + + +class SinPos(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + exp = torch.arange(self.dim // 2, device=x.device) + exp = 2 * exp / (self.dim) + exp = rearrange(exp, "e -> 1 1 e") + x = rearrange(x, "b p -> b p 1") + pos = x * torch.pow(10000, -exp) + pos = torch.cat((pos, pos + math.pi / 2), dim=2) + pos = torch.sin(pos) + return pos + + +class BlindCrossAttention(nn.Module): + def __init__( + self, + q_dim, + k_dim, + att_dim, + pos_net, + dropout=0.1, + pos_dim=64, + pos_type="sinusoidal", + layer_idx: int | None = None, + ): + super().__init__() + self.q = nn.Linear(q_dim, att_dim) + self.k = nn.Linear(k_dim, att_dim) + self.v = nn.Linear(k_dim, att_dim) + self.pos_net = pos_net + if pos_type == "sinusoidal": + self.pos_embed = SinPos(pos_dim) + elif pos_type == "convolutional": + self.pos_embed = ConvPos(pos_dim) + self.ln_q = nn.LayerNorm(att_dim) + self.ln_k = nn.LayerNorm(att_dim) + self.ln_v = nn.LayerNorm(att_dim) + self.dropout_att = nn.Dropout(dropout) + self.layer_idx = layer_idx + + def _prepare_kv(self, text_hidden_states: torch.Tensor): + v = self.ln_v(self.v(text_hidden_states)) + k = self.ln_k(self.k(text_hidden_states)) + b, h, j, d = k.shape + pos = torch.arange(j, device=k.device).unsqueeze(0) + pos_emb = self.pos_embed(pos) + return {"k": k, "v": v, "pos_emb": pos_emb} + + def _query(self, x): + return self.ln_q(self.q(x)) + + def forward( + self, + q, + k, + kv_cached=None, + mask=None, + time_step=None, + pos=None, + left_shift=0, + past_key_values=None, + cache=None, + **kwargs, + ): + q = self.ln_q(self.q(q)) + # if kv_cached is None: + # v = self.ln_v(self.v(k)) + # k = 
self.ln_k(self.k(k)) + # else: + # k, v = kv_cached + + if mask is not None: + mask = mask.unsqueeze(1) + + if cache is not None: + if cache[self.layer_idx] is not None: + ca_state = cache[self.layer_idx]["crossatt_state"] + if ca_state is not None: + k, v, pos_emb = ca_state + else: + # v = self.v(v) + # k = self.k(k) + v = self.ln_v(self.v(k)) + k = self.ln_k(self.k(k)) + pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0) + pos_emb = self.pos_embed(pos, left_shift=left_shift) + cache.update( + crossatt_state=(k, v, pos_emb), layer_idx=self.layer_idx + ) + else: + v = self.ln_v(self.v(k)) + k = self.ln_k(self.k(k)) + if pos is None: + pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0) + pos_emb = self.pos_embed(pos, left_shift=left_shift) + + q, k, v = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k, v)) + b, h, j, d = k.shape + if self.training: + sdpa = lambda q, k, pos: ( + nn.functional.scaled_dot_product_attention( + q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p + ), + None, + ) + else: + sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask) + + x, att1 = sdpa(q, k, pos_emb.unsqueeze(1)) + x = rearrange(x, "b 1 n d -> b n d") + x = self.pos_net(x, cache=cache) + x = rearrange(x, "b n d -> b 1 n d") + pos_emb = rearrange(pos_emb, "b n d -> b 1 n d") + x, att2 = sdpa(x, pos_emb, v) + x = rearrange(x, "b 1 n d -> b n d") + + self.att1 = att1 + self.att2 = att2 + if att2 is not None: + if cache is not None: + cache.update( + crossatt_weights=torch.cat((att1, att2), dim=1), + layer_idx=self.layer_idx, + ) + return x + + +class ListenReadCrossAttention(nn.Module): + def __init__( + self, + q_dim: int, + k_dim: int, + att_dim: int, + crossatt_type: Literal["listen", "read"], + num_heads: int = 1, + dropout: float = 0.1, + layer_idx: int | None = None, + ): + super().__init__() + self.q = nn.Linear(q_dim, att_dim) + self.k = nn.Linear(k_dim, att_dim) + self.ln_q = nn.LayerNorm(att_dim) + self.ln_k = nn.LayerNorm(att_dim) + self.dropout_att = nn.Dropout(dropout) + self.crossatt_type = crossatt_type + self.layer_idx = layer_idx + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + text_freqs: torch.Tensor, + mask: torch.Tensor | None = None, + past_key_values=None, + cache=None, + **kwargs, + ): + q = self.ln_q(self.q(q)) + k = self.ln_k(self.k(k)) + + if mask is not None: + mask = mask.unsqueeze(1) + + q, k = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k)) + if self.training: + sdpa = lambda q, k, pos: ( + nn.functional.scaled_dot_product_attention( + q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p + ), + None, + ) + else: + sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask) + + + text_freqs = rearrange(text_freqs, "b n d -> b 1 n d") + if self.crossatt_type == "listen": + x, att = sdpa(q, k, text_freqs) + elif self.crossatt_type == "read": + x, att = sdpa(q, text_freqs, k) + else: + raise ValueError + + x = rearrange(x, "b 1 n d -> b n d") + if att is not None: + if cache is not None: + cache.update( + crossatt_weights=att, + layer_idx=self.layer_idx, + ) + + self.att = att + return x diff --git a/tts/layers/conv.py b/tts/layers/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..86ae2783a7b404dd56d2722a862a3ed549bdddc4 --- /dev/null +++ b/tts/layers/conv.py @@ -0,0 +1,56 @@ +import torch +from torch import nn + + +class ConvNeXtBlock(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int | None = None, + layer_scale_init_value: float = 
0.0, + elementwise_affine_ln: bool = True, + kernel_size: int = 5, + ): + super().__init__() + intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim + ) # depthwise conv + self.norm = nn.LayerNorm( + dim, eps=1e-6, elementwise_affine=elementwise_affine_ln + ) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, + x: torch.Tensor, + scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None, + gate: torch.Tensor | None = None, + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + x = self.norm(x) + if scale_shift is not None: + scale, shift = scale_shift + x = x * scale[:, None] + shift[:, None] + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + if gate is not None: + x = gate[:, None] * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x diff --git a/tts/layers/ffn.py b/tts/layers/ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5133855ba99e3b0be9257c3383edbceef49514fd --- /dev/null +++ b/tts/layers/ffn.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class SwiGLU(nn.Module): + def __init__(self, d_model: int, ffn_expansion_factor: int = 4): + super().__init__() + self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2) + self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model) + + def forward(self, x): + gate, x = self.p_in(x).chunk(2, dim=-1) + return self.p_out(nn.functional.silu(gate) * x) + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class GaussianFourierTimeEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim), requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x[:, None] * self.weight[None, :] * 2 * torch.pi + x = torch.cat((torch.sin(x), torch.cos(x)), dim=1) + return x + + +class AdaLNFinalLayer(nn.Module): + def __init__(self, hidden_dim, feature_dim): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_dim, feature_dim, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class AdaLNMLP(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + + self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False) + self.mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim, bias=True), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = 
self.mlp(h) + return x + gate_mlp * h diff --git a/tts/model/__init__.py b/tts/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29c6330d7c704d41d9874637cd36ca123a8e51db --- /dev/null +++ b/tts/model/__init__.py @@ -0,0 +1,2 @@ +from .simple_gla import SimpleGLADecoder +from .transformer import TransformerDecoder, TransformerEncoder diff --git a/tts/model/__pycache__/__init__.cpython-312.pyc b/tts/model/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a861e40669415a0110543fbf29f020337f36e4b2 Binary files /dev/null and b/tts/model/__pycache__/__init__.cpython-312.pyc differ diff --git a/tts/model/__pycache__/cache_utils.cpython-312.pyc b/tts/model/__pycache__/cache_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..679e727bc5bc04b093d39d1dc1d7bd7e83de130c Binary files /dev/null and b/tts/model/__pycache__/cache_utils.cpython-312.pyc differ diff --git a/tts/model/__pycache__/config.cpython-312.pyc b/tts/model/__pycache__/config.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58a395d83c439f80acfb01e4894c7cd9ca0b2218 Binary files /dev/null and b/tts/model/__pycache__/config.cpython-312.pyc differ diff --git a/tts/model/__pycache__/prediction_head.cpython-312.pyc b/tts/model/__pycache__/prediction_head.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48399463364d0eb4310728ce97d40a4e7d53e442 Binary files /dev/null and b/tts/model/__pycache__/prediction_head.cpython-312.pyc differ diff --git a/tts/model/__pycache__/registry.cpython-312.pyc b/tts/model/__pycache__/registry.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15f02dc37ee8dc6466c66e0624bd59b1fce61f3f Binary files /dev/null and b/tts/model/__pycache__/registry.cpython-312.pyc differ diff --git a/tts/model/__pycache__/shortconv.cpython-312.pyc b/tts/model/__pycache__/shortconv.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cb803d4c913d7e2f556eefd907b7bd3cee22d1d Binary files /dev/null and b/tts/model/__pycache__/shortconv.cpython-312.pyc differ diff --git a/tts/model/__pycache__/simple_gla.cpython-312.pyc b/tts/model/__pycache__/simple_gla.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ffa4c6694d8dbeea14d22c5051e914de51f2453 Binary files /dev/null and b/tts/model/__pycache__/simple_gla.cpython-312.pyc differ diff --git a/tts/model/__pycache__/transformer.cpython-312.pyc b/tts/model/__pycache__/transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05056791a82cda49b2c8a660ff452a3340b30430 Binary files /dev/null and b/tts/model/__pycache__/transformer.cpython-312.pyc differ diff --git a/tts/model/cache.py b/tts/model/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd84372fab1fe59bf02c0a75c55482084558351 --- /dev/null +++ b/tts/model/cache.py @@ -0,0 +1,308 @@ +from typing import Any + +import torch +from transformers.cache_utils import Cache, _static_cache_update + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. 
Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the + number of beams if you are running beam search + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. + + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + is_compileable = True + + def __init__( + self, + max_batch_size: int, + head_dim: int, + num_key_value_heads: int, + num_hidden_layers: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = max_cache_len + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = head_dim + self._dtype = dtype + self.num_key_value_heads = num_key_value_heads + self.num_hidden_layers = num_hidden_layers + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + device = torch.device(device) if device is not None else None + for idx in range(self.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[idx] + else: + layer_device = device + new_layer_key_cache = torch.zeros( + cache_shape, dtype=self._dtype, device=layer_device + ) + new_layer_value_cache = torch.zeros( + cache_shape, dtype=self._dtype, device=layer_device + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. 
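+            # Marking the buffers as static keeps their data pointers stable across
+            # compiled iterations; `update()` then delegates to `_static_cache_update`,
+            # which fills these preallocated buffers at the positions given by
+            # `cache_position`.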
+ torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. + """ + if cache_kwargs is None: + cache_kwargs = {} + + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + return _static_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_kwargs.get("cache_position"), + ) + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + def get_mask_sizes( + self, cache_position: torch.Tensor, layer_idx: int + ) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length = self.get_max_cache_shape() + return kv_length, 0 + + +class Cache: + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. 
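+
+    On the first `update()` call for a given layer, a per-layer state dict is
+    appended; on later calls, attention key/value states are concatenated along
+    the sequence dimension, or, when a `window_size` is passed via `cache_kwargs`
+    and the window is already full, rolled left by the number of new tokens so
+    the oldest positions are dropped and the new states written at the end.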
+ """ + + is_compileable = True + + def __init__(self, seen_tokens: int = 0) -> Cache: + super().__init__() + + self.states: list[dict[str, Any]] = [] + + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> dict[str, Any]: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + recurrent_state: torch.Tensor | None = None, + attn_state: tuple[torch.Tensor, torch.Tensor] | None = None, + conv_state: tuple[torch.Tensor] | None = None, + ffn_state: torch.Tensor | None = None, + layer_idx: int = 0, + offset: int | None = 1, + cache_kwargs: dict | None = None, + ): + """ + Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`. + + Args: + recurrent_state (`torch.Tensor`, `optional`): + The new recurrent state to cache. + attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`): + The new attention key/value states to cache. + conv_state (`Tuple[torch.Tensor]`, `optional`): + The new convolution state to cache. + layer_idx (`int`, defaults to 0): + The index of the layer to cache the states for. + offset (`int`, `optional`, defaults to 1): + The number of new tokens being processed. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + Dictionary of the updated state. + """ + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += offset + + if attn_state is not None: + input_size = attn_state[0].shape[-2] + window_size = cache_kwargs.get("window_size", None) + if not isinstance(attn_state, Tuple) or len(attn_state) != 2: + raise ValueError( + "`attn_state` must be a tuple of two tensors for key/value states" + ) + if len(self.states) <= layer_idx: + if attn_state is not None: + if window_size is not None and input_size > window_size: + attn_state = ( + attn_state[0][..., -window_size:, :].contiguous(), + attn_state[1][..., -window_size:, :].contiguous(), + ) + state = dict( + recurrent_state=recurrent_state, + attn_state=attn_state, + conv_state=conv_state, + ffn_state=ffn_state, + ) + self.states.append(state) + else: + state = self.states[layer_idx] + if recurrent_state is not None: + state["recurrent_state"] = recurrent_state + if attn_state is not None: + key_state, value_state = state["attn_state"] + if window_size is not None and key_state.shape[-2] == window_size: + # DO NOT allocate new memory if the cache is full + # roll the key/value states to the left by `input_size` + key_state = key_state.roll(-input_size, -2) + value_state = value_state.roll(-input_size, -2) + # replace the last `input_size` tokens with the new key/value states + key_state[..., -input_size:, :] = attn_state[0] + value_state[..., -input_size:, :] = attn_state[1] + attn_state = (key_state, value_state) + else: + attn_state = ( + torch.cat([key_state, attn_state[0]], -2), + torch.cat([value_state, attn_state[1]], -2), + ) + state["attn_state"] = attn_state + if conv_state is not None: + state["conv_state"] = conv_state + if ffn_state is not None: + state["ffn_state"] = ffn_state + + return state + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. 
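+ The value is read from `_seen_tokens`, which `update` advances by `offset` whenever layer 0 is updated.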
A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> int | None: + """Returns the maximum sequence length of the cached states. Cache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> tuple: + return tuple(self.states) + + @classmethod + @torch.compiler.disable + def from_legacy_cache( + cls, past_key_values: tuple | None = None, seen_tokens: int = 0 + ) -> Cache: + """Converts a cache in the legacy cache format into an equivalent `Cache`.""" + + cache = cls(seen_tokens) + if isinstance(past_key_values, list): + for layer_idx in range(len(past_key_values)): + cache.states.append(past_key_values[layer_idx]) + return cache diff --git a/tts/model/cache_utils.py b/tts/model/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0313c655c1fe839a49b4a1edd6ed0a3c3922c518 --- /dev/null +++ b/tts/model/cache_utils.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import transformers +from transformers.cache_utils import DynamicCache + +from tts.layers.attention import precompute_freqs_cis + + +class FLACache(transformers.cache_utils.Cache): + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. + """ + + is_compileable = True + + def __init__( + self, + head_dim: int | None = None, + max_seq_len: int | None = None, + num_states: int | None = None, + device: str = "cuda", + seen_tokens: int = 0, + ): + super().__init__() + if head_dim is not None and max_seq_len is not None: + self.freqs = precompute_freqs_cis( + torch.arange(max_seq_len, device=device), head_dim + ) + self.states: List[Dict[str, Any]] = [] + if num_states is not None: + self.states = [ + dict( + recurrent_state=None, + attn_state=None, + conv_state=None, + short_conv_state=None, + ffn_state=None, + crossatt_state=None, + crossatt_weights=None, + ) + for _ in range(num_states) + ] + + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> Dict[str, Any]: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + recurrent_state: torch.Tensor = None, + attn_state: Tuple[torch.Tensor, torch.Tensor] = None, + conv_state: Tuple[torch.Tensor] = None, + short_conv_state: Tuple[torch.Tensor] = None, + crossatt_state: Tuple[torch.Tensor] = None, + crossatt_weights: Tuple[torch.Tensor] = None, + ffn_state: torch.Tensor = None, + layer_idx: int = 0, + offset: Optional[int] = 1, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`. + + Args: + recurrent_state (`torch.Tensor`, `optional`): + The new recurrent state to cache. + attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`): + The new attention key/value states to cache. + conv_state (`Tuple[torch.Tensor]`, `optional`): + The new convolution state to cache. + layer_idx (`int`, defaults to 0): + The index of the layer to cache the states for. 
+ offset (`int`, `optional`, defaults to 1): + The number of new tokens being processed. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + Dictionary of the updated state. + """ + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += offset + + if attn_state is not None: + input_size = attn_state[0].shape[-2] + window_size = ( + cache_kwargs.get("window_size", None) + if cache_kwargs is not None + else None + ) + if not isinstance(attn_state, Tuple) or len(attn_state) != 2: + raise ValueError( + "`attn_state` must be a tuple of two tensors for key/value states" + ) + if len(self.states) <= layer_idx: + if attn_state is not None: + if window_size is not None and input_size > window_size: + attn_state = ( + attn_state[0][..., -window_size:, :].contiguous(), + attn_state[1][..., -window_size:, :].contiguous(), + ) + state = dict( + recurrent_state=recurrent_state, + attn_state=attn_state, + conv_state=conv_state, + short_conv_state=short_conv_state, + ffn_state=ffn_state, + crossatt_state=crossatt_state, + crossatt_weights=crossatt_weights, + ) + self.states.append(state) + else: + state = self.states[layer_idx] + if recurrent_state is not None: + state["recurrent_state"] = recurrent_state + if crossatt_state is not None: + state["crossatt_state"] = crossatt_state + if crossatt_weights is not None: + if state["crossatt_weights"] is not None: + state["crossatt_weights"] = torch.cat( + (state["crossatt_weights"], crossatt_weights), dim=-2 + ) + else: + state["crossatt_weights"] = crossatt_weights + if attn_state is not None: + key_state, value_state = state["attn_state"] + if window_size is not None and key_state.shape[-2] == window_size: + # DO NOT allocate new memory if the cache is full + # roll the key/value states to the left by `input_size` + key_state = key_state.roll(-input_size, -2) + value_state = value_state.roll(-input_size, -2) + # replace the last `input_size` tokens with the new key/value states + key_state[..., -input_size:, :] = attn_state[0] + value_state[..., -input_size:, :] = attn_state[1] + attn_state = (key_state, value_state) + else: + attn_state = ( + torch.cat([key_state, attn_state[0]], -2), + torch.cat([value_state, attn_state[1]], -2), + ) + state["attn_state"] = attn_state + if conv_state is not None: + state["conv_state"] = conv_state + if short_conv_state is not None: + state["short_conv_state"] = short_conv_state + if ffn_state is not None: + state["ffn_state"] = ffn_state + + return state + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. 
Cache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple: + return tuple(self.states) + + @classmethod + @torch.compiler.disable + def from_legacy_cache( + cls, past_key_values: Optional[Tuple] = None, seen_tokens: int = 0 + ) -> Cache: + """Converts a cache in the legacy cache format into an equivalent `Cache`.""" + + cache = cls(seen_tokens) + if isinstance(past_key_values, list): + for layer_idx in range(len(past_key_values)): + cache.states.append(past_key_values[layer_idx]) + return cache + + +class TransformerDecoderCache(FLACache): + def __init__( + self, + head_dim: int, + max_seq_len: int, + device: str, + ): + super().__init__() + self.freqs = precompute_freqs_cis( + torch.arange(max_seq_len, device=device), head_dim + ) diff --git a/tts/model/config.py b/tts/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..34270fccb445f78d94e3fe6a6e1583babbcb3981 --- /dev/null +++ b/tts/model/config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass +class SimpleGLADecoderConfig: + name: str = "simple_gla" + dim: int = 512 + use_short_conv = True + expand_k: float = 0.5 + expand_v: float = 1.0 + num_heads: int = 4 + num_layers: int = 12 + ffn_expansion_factor: int = 4 + conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10]) + blind_crossatt: bool = True + listen_read_crossatt: dict[int, list[Literal["listen", "read"]]] | None = None + crossatt_num_heads: int = 1 + crossatt_dropout: float = 0.1 + crossatt_layer_idx: list[int] = field(default_factory=lambda: [5]) + + +@dataclass +class TransformerEncoderConfig: + name: str = "sa_transformer" + dim: int = 512 + num_heads: int = 8 + num_layers: int = 6 + ffn_expansion_factor: int = 4 + + +@dataclass +class ConvFormerEncoderConfig: + name: str = "convformer_encoder" + dim: int = 512 + num_heads: int = 8 + num_transformer_layers: int = 6 + num_conv_layers: int = 3 + ffn_expansion_factor: int = 4 + + +@dataclass +class TransformerDecoderConfig: + name: str = "sa_transformer" + dim: int = 512 + num_heads: int = 8 + num_layers: int = 12 + conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10]) + ffn_expansion_factor: int = 4 + crossatt_dropout: float = 0.1 + crossatt_num_heads: int = 1 + crossatt_layer_idx: list[int] = field(default_factory=lambda: [5, 6]) + + +DecoderConfig = TransformerDecoderConfig | SimpleGLADecoderConfig +EncoderConfig = TransformerEncoderConfig | ConvFormerEncoderConfig + + +@dataclass +class TTSConfig: + dim: int = 512 + text_vocab_size: int = 256 + 3 + audio_vocab_size: int = 4096 + 3 + audio_embed_size: int = 16 + audio_input_type: Literal["continuous", "discrete"] = "continuous" + diffusion_head_num_layers: int = 3 + encoder_cfg: EncoderConfig = field(default_factory=TransformerEncoderConfig) + decoder_cfg: DecoderConfig = field(default_factory=TransformerDecoderConfig) + stop_prediction_head: bool = True + multi_stop_prediction_head: bool = False + multi_stop_num_tokens: int = 8 + multi_stop_padding_idx: int | None = None + multi_stop_tie_embd: bool = True + stop_token_embd: bool = False + text_stop_token_embd: bool = False + continuous_diffusion: bool = True + num_sink_tokens: int = 0 + disabled_crossatt_head_idx: list[tuple[int, int]] | None = None + patchvae_path: str | None = None + text_tokenizer_path: str | None = None + eos: bool = False + bos: bool = False + + +@dataclass +class PlayHeadConfig: + selected_cross_attention_heads: 
list[tuple[int, int]] + dim: int = 256 + num_layers: int = 6 + num_frame_lag: int = 2 + num_sink_tokens: int = 4 + cycle_len: int = 8 + logits_head: bool = True + target_lag: int = 0 + avg_pool_stride: int = 3 + circular_head: bool = False + + +@dataclass +class QueryVCConfig: + dim: int = 512 + semantic_dim: int = 512 + num_layers: int = 12 + lag: int = 4 + audio_embed_size: int = 8 + diffusion_head_num_layers: int = 3 diff --git a/tts/model/prediction_head.py b/tts/model/prediction_head.py new file mode 100644 index 0000000000000000000000000000000000000000..f90a59fb4db633659210f8d34483f35410ac6d9d --- /dev/null +++ b/tts/model/prediction_head.py @@ -0,0 +1,283 @@ +import os +import sys +from contextlib import contextmanager + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from torchdyn.core import NeuralODE + +from tts.layers.ffn import (AdaLNFinalLayer, AdaLNMLP, + GaussianFourierTimeEmbedding) + + +@contextmanager +def suppress_stdout(): + original_stdout = sys.stdout + try: + sys.stdout = open(os.devnull, "w") + yield + finally: + sys.stdout.close() + sys.stdout = original_stdout + + +def sample_from_logits( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, +) -> torch.Tensor: + B, N, C = logits.shape + logits = logits / temperature + + # Apply top-k + if top_k > 0: + top_k = min(top_k, C) + topk_values, _ = torch.topk(logits, top_k, dim=-1) + kth_value = topk_values[..., -1, None] + logits = torch.where( + logits < kth_value, torch.full_like(logits, float("-inf")), logits + ) + + # Apply top-p (nucleus) sampling + if top_p > 0.0 and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(probs, dim=-1) + + # Create mask for tokens to remove + cutoff_mask = cumulative_probs > top_p + cutoff_mask[..., 0] = 0 # Always keep at least one token + sorted_logits[cutoff_mask] = float("-inf") + + # Map back to original logits shape + logits = torch.full_like(logits, float("-inf")).scatter( + -1, sorted_indices, sorted_logits + ) + + # Convert logits to probabilities + probs = F.softmax(logits, dim=-1) + + # Sample + samples = torch.multinomial(probs.view(-1, C), num_samples=1).view(B, N) + return samples + + +class LogitsHead(nn.Module): + def __init__(self, hidden_dim: int, vocab_size: int): + super().__init__() + self.logits_proj = nn.Linear(hidden_dim, vocab_size) + + def forward(self, pre_logits): + return self.logits_proj(pre_logits) + + def compute_loss( + self, + pre_logits: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + ): + logits = self(pre_logits) + if mask is not None: + flat_logits = logits[mask] + flat_target = target[mask] + else: + flat_logits = rearrange(logits, "b n l -> (b n) l") + flat_target = rearrange(target, "b n -> (b n)") + + loss = nn.functional.cross_entropy( + flat_logits, + flat_target, + ) + + return {"cross_entropy": loss} + + def predict(self, x: torch.Tensor, *args, **kwargs): + return sample_from_logits(self(x), *args, **kwargs) + + +class ContinuousHead(nn.Module): + def __init__(self, hidden_dim: int, feature_dim: int): + super().__init__() + self.continuous_head = nn.Linear(hidden_dim, feature_dim) + + def forward(self, x: torch.Tensor): + return self.continuous_head(x) + + def predict(self, x: torch.Tensor): + return self(x) + + def compute_loss( + self, pre_logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + 
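+ # `mask` is effectively optional (the body checks for None); when given, it selects the valid
+ # positions and the MSE is averaged over those entries only.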
): + if mask is not None: + pre_logits = pre_logits[mask] + target = target[mask] + return {"mse": nn.functional.mse_loss(self(pre_logits), target)} + + +class VelocityHead(nn.Module): + def __init__( + self, + hidden_dim: int, + feature_dim: int, + num_layers: int, + cond_dim: int | None = None, + ): + super().__init__() + cond_dim = cond_dim if cond_dim is not None else hidden_dim + self.feature_embed = nn.Linear(feature_dim, hidden_dim) + self.cond_embed = nn.Linear(cond_dim, hidden_dim) + self.time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2) + self.adaln_mlp = nn.ModuleList( + [AdaLNMLP(hidden_dim) for _ in range(num_layers)] + ) + self.adaln_final_layer = AdaLNFinalLayer(hidden_dim, feature_dim) + self.feature_dim = feature_dim + + def forward( + self, + cond: torch.Tensor, + x: torch.Tensor, + t: torch.Tensor | None = None, + cond_drop_mask: torch.BoolTensor | None = None, + ): + cond = self.cond_embed(cond) + if cond_drop_mask is not None: + cond[cond_drop_mask] = 0.0 + cond += self.time_embed(t)[:, None] + x = self.feature_embed(x) + + for l in self.adaln_mlp: + x = l(x, cond) + y = self.adaln_final_layer(x, cond) + + return y + + def compute_loss( + self, + cond: torch.Tensor, + x1: torch.Tensor, + mask: torch.Tensor | None, + sigma: float = 1e-5, + t: torch.Tensor | None = None, + x0: torch.Tensor | None = None, + cfg_drop_rate: float = 0.1, + ): + """ + CFM Loss + """ + if t is None: + t = torch.rand(cond.shape[0], device=cond.device) + + if x0 is None: + x0 = torch.randn_like(x1, device=x1.device) + + flow_target = x1 - (1 - sigma) * x0 + alpha = (1 - (1 - sigma) * t).view(-1, 1, 1) + xt = alpha * x0 + t.view(-1, 1, 1) * x1 + + if self.training and cfg_drop_rate > 0.0: + cond_drop_mask = torch.rand(cond.shape[:2]) < cfg_drop_rate + else: + cond_drop_mask = None + + flow_pred = self(cond, xt, t, cond_drop_mask=cond_drop_mask) + + if mask is not None: + flow_pred = flow_pred[mask] + flow_target = flow_target[mask] + + loss = nn.functional.mse_loss(flow_pred, flow_target) + + return {"diffusion": loss} + + def predict( + self, + pre_prediction: torch.Tensor, + pre_prediction_ref: torch.Tensor | None = None, + solver: str = "euler", + sensitivity: str = "adjoint", + num_steps: int = 10, + cfg: float = 1.0, + cfg_ref: float = 1.5, + temperature: float = 1.0, + **kwargs, + ): + if cfg == 1.0: + + if pre_prediction_ref is None: + def solver_fn(t, Xt, *args, **kwargs): + return self(pre_prediction, Xt, t.unsqueeze(0)) + else: + raise NotImplementedError + else: + if pre_prediction_ref is None: + + def solver_fn(t, Xt, *args, **kwargs): + cond_uncond = torch.cat( + (pre_prediction, torch.zeros_like(pre_prediction)), + dim=0, + ) + cond_uncond = self(cond_uncond, Xt.repeat(2, 1, 1), t.unsqueeze(0)) + cond, uncond = cond_uncond.chunk(2, dim=0) + cond_uncond_cfg = uncond + cfg * (cond - uncond) + + return cond_uncond_cfg + else: + + def solver_fn(t, Xt, *args, **kwargs): + cond_uncond_ref = torch.cat((pre_prediction, pre_prediction_ref, torch.zeros_like(pre_prediction))) + #cond_uncond_ref, = torch.cat( + # (pre_prediction, torch.zeros_like(pre_prediction), pre_prediction_ref), + # dim=0, + #) + cond_uncond = self(cond_uncond_ref, Xt.repeat(3, 1, 1), t.unsqueeze(0)) + cond, ref, uncond = cond_uncond.chunk(3, dim=0) + #cond_uncond_cfg = uncond + cfg * (cond - uncond) + #cond_uncond_cfg_ref_cfg = ref + cfg_ref * (cond_uncond_cfg - ref) + cond_uncond_cfg = ref + cfg_ref * (cond - ref) + cond_uncond_cfg_ref_cfg = uncond + cfg * (cond_uncond_cfg - uncond) + + return 
cond_uncond_cfg_ref_cfg + + + # get rid of torchdyn warning + with suppress_stdout(): + node_ = NeuralODE(solver_fn, solver=solver, sensitivity=sensitivity) + t_span = torch.linspace(0, 1, num_steps + 1, device=pre_prediction.device) + traj = node_.trajectory( + torch.randn(pre_prediction.shape[0], 1, self.feature_dim, device=pre_prediction.device) + * temperature, + t_span=t_span, + ) + prediction = traj[-1] + return prediction + + +class StopPredictionHead(nn.Module): + def __init__(self, dim: int, weight_loss: float = 1.0): + super().__init__() + self.proj = nn.Linear(dim, 1) + self.weight_loss = weight_loss + + def forward(self, pre_prediction: torch.Tensor): + return torch.sigmoid(self.proj(pre_prediction)) + + def predict(self, pre_prediction: torch.Tensor): + return torch.sigmoid(self.proj(pre_prediction)) + + def compute_loss( + self, + pre_prediction: torch.Tensor, + target: torch.Tensor, + ): + logits = self.proj(pre_prediction) + bce = nn.functional.binary_cross_entropy_with_logits( + logits.squeeze(-1), + target.to(logits.dtype), + weight=torch.ones(logits.shape[0], device=logits.device) * self.weight_loss, + ) + return {"stop_bce": bce} diff --git a/tts/model/registry.py b/tts/model/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0a00c29717c2a7888a38cef182b1240068810762 --- /dev/null +++ b/tts/model/registry.py @@ -0,0 +1,18 @@ +ENCODER_REGISTRY = {} +DECODER_REGISTRY = {} + + +def register_encoder(name): + def wrapper(cls): + ENCODER_REGISTRY[name] = cls + return cls + + return wrapper + + +def register_decoder(name): + def wrapper(cls): + DECODER_REGISTRY[name] = cls + return cls + + return wrapper diff --git a/tts/model/shortconv.py b/tts/model/shortconv.py new file mode 100644 index 0000000000000000000000000000000000000000..4677c1bb58934a2f27aaccb01b55a1a645a09998 --- /dev/null +++ b/tts/model/shortconv.py @@ -0,0 +1,39 @@ +import torch +from fla.modules import ShortConvolution +from torch import nn + +from tts.layers.ffn import SwiGLU +from tts.model.cache_utils import FLACache + + +class ShortConvBlock(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int, + ffn_expansion_factor: int, + layer_idx: int, + use_fast_conv1d: bool = True, + ): + super().__init__() + self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.layer_idx = layer_idx + + def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs): + last_state = None + if cache is not None and len(cache) > self.layer_idx: + last_state = cache[self.layer_idx]["short_conv_state"] + x, short_conv_state = self.tmix( + self.norm1(x), cache=last_state, output_final_state=cache is not None + ) + if cache is not None: + cache.update( + short_conv_state=short_conv_state, + layer_idx=self.layer_idx, + offset=x.shape[1], + ) + x = self.cmix(self.norm2(x)) + return x diff --git a/tts/model/simple_gla.py b/tts/model/simple_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..9425b7dd12ec50385ea99f086250b749f746a57b --- /dev/null +++ b/tts/model/simple_gla.py @@ -0,0 +1,291 @@ +import os + +import torch +import torch.nn.functional as F +from einops import rearrange +from fla.layers.simple_gla import SimpleGatedLinearAttention +from fla.models.utils import Cache +from sympy import num_digits +from torch import nn + +from tts.layers.attention import CrossAttention +from tts.layers.ffn import SwiGLU + +from 
.cache_utils import FLACache +from .config import SimpleGLADecoderConfig +from .registry import register_decoder +from .shortconv import ShortConvBlock + +if "GRAD_CKPT" in os.environ: + + def maybe_grad_ckpt(f): + def grad_ckpt_f(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + f, *args, **kwargs, use_reentrant=False + ) + + return grad_ckpt_f +else: + + def maybe_grad_ckpt(f): + return f + + +class SimpleGLABlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + expand_k: float, + expand_v: float, + use_short_conv: bool, + ffn_expansion_factor: int, + ): + super().__init__() + self.tmix = SimpleGatedLinearAttention( + hidden_size=dim, + num_heads=num_heads, + layer_idx=layer_idx, + ) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + text_freqs: torch.Tensor | None = None, + cache: Cache | None = None, + ): + x = ( + self.tmix( + self.norm1(x), + past_key_values=cache, + use_cache=cache is not None, + )[0] + + x + ) + x = self.cmix(self.norm2(x)) + x + return x + + +class DecoderBlockWithOptionalCrossAttention(nn.Module): + def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None): + super().__init__() + + self.decoder_block = decoder_block + self.crossatt = crossatt + + def forward( + self, + x: torch.Tensor, + encoder_output: torch.Tensor | None = None, + freqs: torch.Tensor | None = None, + text_freqs: torch.Tensor | None = None, + cache: Cache | None = None, + selfatt_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None, + ) -> torch.Tensor: + x = self.decoder_block( + x, + freqs=freqs, + cache=cache, + ) + if type(crossatt_mask) is list: + crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx] + if self.crossatt is not None: + x = x + self.crossatt( + x, + k=encoder_output, + text_freqs=text_freqs, + mask=crossatt_mask, + cache=cache, + ) + + return x + + +@register_decoder("simple_gla") +class SimpleGLADecoder(nn.Module): + config = SimpleGLADecoderConfig + + def __init__(self, cfg: SimpleGLADecoderConfig): + super().__init__() + + assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim" + assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, ( + "at most one specialized cross-attention" + ) + + self.head_dim = cfg.dim // cfg.num_heads + self.num_heads = cfg.num_heads + + def simple_gla_block(i): + conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers + if i in conv_layers: + return ShortConvBlock( + dim=cfg.dim, + kernel_size=4, + ffn_expansion_factor=cfg.ffn_expansion_factor, + layer_idx=i, + use_fast_conv1d=True, + ) + + else: + return SimpleGLABlock( + dim=cfg.dim, + num_heads=cfg.num_heads, + layer_idx=i, + expand_k=cfg.expand_k, + expand_v=cfg.expand_v, + use_short_conv=cfg.use_short_conv, + ffn_expansion_factor=cfg.ffn_expansion_factor, + ) + + def crossatt_block(i): + if i in cfg.crossatt_layer_idx: + return CrossAttention( + dim=cfg.dim, + num_heads=cfg.crossatt_num_heads, + dropout=cfg.crossatt_dropout, + layer_idx=i, + ) + else: + return None + + self.decoder_layers = nn.ModuleList( + [ + DecoderBlockWithOptionalCrossAttention( + simple_gla_block(i), + crossatt_block(i), + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None, + 
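+ # `crossatt_mask` may also be a per-layer list; each DecoderBlockWithOptionalCrossAttention
+ # then picks its own entry using the wrapped block's layer_idx.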
text_ids: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + x = decoder_input + text_freqs = None + + for layer in self.decoder_layers: + x = maybe_grad_ckpt(layer)( + x, + encoder_output, + text_freqs=text_freqs, + cache=cache, + crossatt_mask=crossatt_mask, + ) + return x + + def init_cache(self, max_seq_len, device): + return FLACache(num_states=len(self.decoder_layers) + 1) + + def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"): + return tuple( + nn.Parameter( + torch.randn( + batch_size, + self.num_heads, + self.head_dim, + self.head_dim, + device=device, + ) + * scale + ) + for _ in range(len(self.decoder_layers)) + ) + def init_initial_state_lora(self, lora:int=1, batch_size: int = 1, scale: float=1e-2, device: str="cpu"): + return tuple( + ( + nn.Parameter( + torch.randn( + batch_size, + self.num_heads, + self.head_dim, + lora, + device=device, + ) + * scale + ), + nn.Parameter( + torch.randn( + batch_size, + self.num_heads, + lora, + self.head_dim, + device=device, + ) + * scale + ) + ) + for _ in range(len(self.decoder_layers)) + ) + + def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int): + assert self.decoder_layers[layer_idx].crossatt is not None + x = audio_inputs + for _, layer in zip(range(layer_idx - 1), self.decoder_layers): + x = layer(x, None) + return self.decoder_layers[layer_idx].crossatt._query(x) + + def forward_first_n_layers( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + n_first_layers: int, + crossatt_mask: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + x = decoder_input + if self.text_freqs_embd is not None: + text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :] + text_freqs = self.text_freqs_embd(text_freqs) + else: + text_freqs = None + + for layer in self.decoder_layers[:n_first_layers]: + x = maybe_grad_ckpt(layer)( + x, + encoder_output, + text_freqs=text_freqs, + cache=cache, + crossatt_mask=crossatt_mask, + ) + return x + + def prefill( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + crossatt_mask: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + return self(encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask) + + def decode_one( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: Cache, + text_freqs: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + ): + x = decoder_input + for layer in self.decoder_layers: + x = layer( + x, + encoder_output, + text_freqs=text_freqs, + cache=cache, + crossatt_mask=crossatt_mask, + ) + return x diff --git a/tts/model/transformer.py b/tts/model/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ec1bc0e603a4845c89e2225a68ef264ce9b204 --- /dev/null +++ b/tts/model/transformer.py @@ -0,0 +1,282 @@ +import os + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.cache_utils import DynamicCache + +from tts.layers.attention import (CrossAttention, SelfAttention, + precompute_freqs_cis) +from tts.layers.conv import ConvNeXtBlock +from tts.layers.ffn import SwiGLU +from tts.model.cache_utils import FLACache, TransformerDecoderCache +from tts.model.config import (ConvFormerEncoderConfig, + TransformerDecoderConfig, + TransformerEncoderConfig) +from tts.model.registry import register_decoder, register_encoder +from tts.model.shortconv import ShortConvBlock + +if "GRAD_CKPT" in os.environ: + + def 
maybe_grad_ckpt(f): + def grad_ckpt_f(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + f, *args, **kwargs, use_reentrant=False + ) + + return grad_ckpt_f +else: + + def maybe_grad_ckpt(f): + return f + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + ffn_expansion_factor: int, + is_causal: bool, + ): + super().__init__() + self.tmix = SelfAttention(dim, num_heads, layer_idx, is_causal=is_causal) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + cache: DynamicCache | None = None, + mask: torch.Tensor | None = None, + ): + x = self.tmix(self.norm1(x), freqs=freqs, cache=cache, mask=mask) + x + x = self.cmix(self.norm2(x)) + x + return x + + +class DecoderBlockWithOptionalCrossAttention(nn.Module): + def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None): + super().__init__() + + self.decoder_block = decoder_block + self.crossatt = crossatt + + def forward( + self, + x: torch.Tensor, + encoder_output: torch.Tensor | None = None, + freqs: torch.Tensor | None = None, + cache: FLACache | None = None, + selfatt_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = self.decoder_block( + x, + freqs=freqs, + cache=cache, + ) + + if self.crossatt is not None: + x = x + self.crossatt( + x, + k=encoder_output, + mask=crossatt_mask, + cache=cache, + ) + + return x + + +@register_decoder("sa_transformer") +class TransformerDecoder(nn.Module): + config = TransformerDecoderConfig + def __init__(self, cfg: TransformerDecoderConfig): + super().__init__() + + assert cfg.dim % cfg.num_heads == 0 + self.head_dim = cfg.dim // cfg.num_heads + + def transformer_block(i): + conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers + if i in conv_layers: + return ShortConvBlock( + dim=cfg.dim, + kernel_size=4, + ffn_expansion_factor=cfg.ffn_expansion_factor, + layer_idx=i, + use_fast_conv1d=True, + ) + else: + return TransformerBlock( + dim=cfg.dim, + num_heads=cfg.num_heads, + layer_idx=i, + ffn_expansion_factor=cfg.ffn_expansion_factor, + is_causal=True, + ) + + def crossatt_block(i): + return CrossAttention( + dim=cfg.dim, + num_heads=cfg.crossatt_num_heads, + dropout=cfg.crossatt_dropout, + layer_idx=i, + ) + + self.decoder_layers = nn.ModuleList( + [ + DecoderBlockWithOptionalCrossAttention( + transformer_block(i), + crossatt_block(i) if i in cfg.crossatt_layer_idx else None, + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + crossatt_mask: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + x = decoder_input + positions = torch.arange(x.shape[1], device=x.device) + freqs = precompute_freqs_cis(positions, self.head_dim) + for layer in self.decoder_layers: + x = maybe_grad_ckpt(layer)( + x, + encoder_output, + freqs=freqs, + crossatt_mask=crossatt_mask, + cache=cache, + ) + return x + + def init_cache(self, max_seq_len, device): + return FLACache(head_dim=self.head_dim, max_seq_len=max_seq_len, device=device) + + def prefill( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: FLACache | None = None, + ): + return self(encoder_output, decoder_input, cache=cache) + + def decode_one( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: FLACache, + crossatt_mask: 
torch.Tensor | None = None, + ): + x = decoder_input + pos = cache._seen_tokens + freq = cache.freqs[:, [pos]] + for layer in self.decoder_layers: + x = layer( + x, + encoder_output, + freqs=freq, + cache=cache, + crossatt_mask=crossatt_mask, + ) + return x + + +@register_encoder("sa_transformer") +class TransformerEncoder(nn.Module): + config = TransformerEncoderConfig + + def __init__(self, cfg: TransformerEncoderConfig): + super().__init__() + assert cfg.dim % cfg.num_heads == 0 + self.head_dim = cfg.dim // cfg.num_heads + self.encoder_layers = nn.ModuleList( + [ + TransformerBlock( + cfg.dim, + cfg.num_heads, + i, + cfg.ffn_expansion_factor, + is_causal=False, + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + ): + positions = torch.arange(x.shape[1], device=x.device) + freqs = precompute_freqs_cis(positions, self.head_dim) + if mask is not None: + mask = rearrange(mask, "b n m -> b 1 n m") + mask = torch.logical_or( + mask, + rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"), + ) + + for layer in self.encoder_layers: + x = layer(x, freqs=freqs, mask=mask) + return x + + +@register_encoder("convformer_encoder") +class ConvFormerEncoder(nn.Module): + config = ConvFormerEncoderConfig + + def __init__(self, cfg: ConvFormerEncoderConfig): + super().__init__() + assert cfg.dim % cfg.num_heads == 0 + self.head_dim = cfg.dim // cfg.num_heads + self.conv_layers = nn.ModuleList( + [ConvNeXtBlock(cfg.dim) for _ in range(cfg.num_conv_layers)] + ) + self.encoder_layers = nn.ModuleList( + [ + TransformerBlock( + cfg.dim, + cfg.num_heads, + i, + cfg.ffn_expansion_factor, + is_causal=False, + ) + for i in range(cfg.num_transformer_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + text_rel_pos: torch.Tensor | None = None, + ): + if text_rel_pos is None: + text_rel_pos = torch.arange(x.shape[1], device=x.device) + freqs = precompute_freqs_cis(text_rel_pos, self.head_dim) + else: + freqs = precompute_freqs_cis(text_rel_pos, self.head_dim).unsqueeze(1) + + x = x.transpose(1, 2) + for layer in self.conv_layers: + x = layer(x) + x = x.transpose(1, 2) + if mask is not None: + mask = rearrange(mask, "b n m -> b 1 n m") + mask = torch.logical_or( + mask, + rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"), + ) + + for layer in self.encoder_layers: + x = layer(x, freqs=freqs, mask=mask) + return x diff --git a/tts/playhead.py b/tts/playhead.py new file mode 100644 index 0000000000000000000000000000000000000000..f716be24f61cd7215685c3fbc81a4b92e859723c --- /dev/null +++ b/tts/playhead.py @@ -0,0 +1,103 @@ +import torch +from torch import nn + +from model.cache_utils import FLACache +from model.config import PlayHeadConfig +from model.prediction_head import CircularHead, LogitsHead +from model.simple_gla import SimpleGLABlock + + +class PlayHead(nn.Module): + def __init__(self, cfg: PlayHeadConfig): + super().__init__() + self.cycle_len = cfg.cycle_len + self.num_sink_tokens = cfg.num_sink_tokens + self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim) + self.avg_pool_stride = cfg.avg_pool_stride + self.net = nn.ModuleList( + [ + SimpleGLABlock( + dim=cfg.dim, + num_heads=cfg.dim // 128, + layer_idx=i, + expand_k=0.5, + expand_v=1.0, + use_short_conv=True, + ffn_expansion_factor=4, + ) + for i in range(cfg.num_layers) + ] + ) + + self.logits_head = ( + LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None + ) + 
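+ # `forward` accumulates losses from whichever of the two position heads is enabled,
+ # while `predict` only uses the logits head to produce a position.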
self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None + + def forward( + self, + cross_attention_weights: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + ): + B, T, A = cross_attention_weights.shape + # if self.cross_attention_reduction == "sum": + # cross_attention_weights = cross_attention_weights.sum(1) + + device = cross_attention_weights.device + pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len + sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len + sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None]) + x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd + for block in self.net: + x = block(x) + + losses = dict() + if self.logits_head is not None: + losses |= self.logits_head.compute_loss(x, target.long(), mask=mask) + if self.circular_head is not None: + losses |= self.circular_head.compute_loss(x, target, mask=mask) + + return losses + + def init_cache(self): + return FLACache(num_states=len(self.net)) + + def predict( + self, + cross_attention_weights: torch.Tensor, + previous_position: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + avg_pool_ca = torch.nn.functional.avg_pool1d( + cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2), + self.avg_pool_stride, + stride=self.avg_pool_stride, + ceil_mode=True, + ).transpose(-1, -2) + + sink_ca = cross_attention_weights[:, : self.num_sink_tokens] + cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1) + + B, T, A = cross_attention_weights.shape + device = cross_attention_weights.device + pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len + sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len + sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None]) + x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd + for block in self.net: + x = block(x, cache=cache) + if self.logits_head is not None: + logits = self.logits_head(x) + pred_position = torch.argmax(logits, -1) + + if previous_position is not None: + current_angle, previous_angle = map( + lambda x: torch.exp(1j * 2 * torch.pi * x / self.cycle_len), + (pred_position, previous_position), + ) + diff = current_angle / previous_angle + step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long() + return pred_position, step + + return pred_position diff --git a/tts/text_processor.py b/tts/text_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..322852d11ec901a84e659b8e46c3bed81b4bd9cb --- /dev/null +++ b/tts/text_processor.py @@ -0,0 +1,79 @@ +from transformers import PreTrainedTokenizerFast +import re +import unicodedata + +REPLACEMENTS = { + # whitespace + "\n": " ", + "\r": " ", + "\t": " ", + "\xa0": " ", + "\u2009": " ", + "\u202f": " ", + "\u200b": "", + # quotes + "‘": "'", + "’": "'", + "‚": "'", + "‛": "'", + "“": '"', + "”": '"', + "„": '"', + "«": '"', + "»": '"', + # dashes + "–": "-", + "—": "-", + "−": "-", + "-": "-", + # ellipsis + "…": "...", + # bullets & symbols + "•": ".", + "∙": ".", + "·": ".", + # currencies + "€": " euros", + "$": " dollars", + "£": " pounds", + "¥": " yen", + # misc + "°": " degrees", + "©": "", + "®": "", + "™": "", +} + + +def clean_text(text: str) -> str: + text = unicodedata.normalize("NFKC", text) + for src, tgt in REPLACEMENTS.items(): + text = text.replace(src, tgt) + text = re.sub(r"\s+", " ", text.strip()) # collapse spaces + return text + + +class BasicTextProcessor: + """ 
+ Basic text processor on top of a character level BPE model. + """ + + def __init__(self, tokenizer_file: str): + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) + + def normalize(self, text: str) -> str: + """Basic pre-normalization: whitespace cleanup, punctuation spacing, etc.""" + text = clean_text(text) + text = text.strip() + return text + + def __call__(self, text: str, **tokenizer_kwargs): + """Normalize then tokenize.""" + text = self.normalize(text) + return self.tokenizer.encode(text, **tokenizer_kwargs) + + def detokenize(self, token_ids): + """Optional: convert back to string.""" + out = self.tokenizer.decode(token_ids, skip_special_tokens=False) + whitespace = "##[WHITESPACE]" + return out.replace(" ", whitespace).replace(" ", "").replace(whitespace, " ") diff --git a/tts/tools.py b/tts/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e9862e5456386902f0a64abe191483a282f27e --- /dev/null +++ b/tts/tools.py @@ -0,0 +1,250 @@ +from itertools import accumulate +from typing import Callable, List, Optional + +import torch +import torch.nn.functional as F + +default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def widen_alignment( + alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S" +) -> torch.Tensor: + """ + Widen 1-bands along one axis of an alignment matrix. + + Args: + alignment: (B, T, S) binary/bool/int tensor + width: int or (left, right) expansion + e.g. 2 -> expand ±2 + (1,3) -> expand -1 on the left, +3 on the right + axis: "S" to widen horizontally (across S), + "T" to widen vertically (across T) + + Returns: + (B, T, S) tensor with widened 1-bands along the chosen axis + """ + assert axis in ("S", "T") + orig_dtype = alignment.dtype + dev = alignment.device + + # normalize widths + if isinstance(width, int): + left, right = width, width + else: + left, right = width + ksize = left + right + 1 + kernel = torch.ones(1, 1, ksize, device=dev) + + if axis == "S": + # (B*T, 1, S) + x = alignment.view(-1, 1, alignment.size(-1)).float() + x = F.pad(x, (left, right)) # explicit asymmetric padding + y = F.conv1d(x, kernel) + y = (y > 0).view_as(alignment) + + else: # axis == "T" + # (B*S, 1, T) + x = ( + alignment.permute(0, 2, 1) + .contiguous() + .view(-1, 1, alignment.size(1)) + .float() + ) + x = F.pad(x, (left, right)) + y = F.conv1d(x, kernel) + # Back to (B, T, S) + y = ( + (y > 0) + .view(alignment.size(0), alignment.size(2), alignment.size(1)) + .permute(0, 2, 1) + ) + + # Cast back to original dtype + if orig_dtype == torch.bool: + return y + elif orig_dtype.is_floating_point: + return y.to(orig_dtype) + else: + return y.to(orig_dtype) + + +def collect_heads(cache, selected_heads): + return torch.stack( + [ + cache[layer]["crossatt_weights"][:, [head], [-1]] + for layer, head in selected_heads + ], + dim=1, + ) + + +def expand(x, r): + b, n, d = x.shape + x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d) + return x + + +def path_matrix(positions: torch.Tensor, num_positions: int = None) -> torch.Tensor: + if num_positions is None: + num_positions = positions.max().item() + 1 + return F.one_hot(positions, num_classes=num_positions).to(torch.int) + + +def pad_2d_sequence(seq, padding_value=0): + max_x, max_y = map(max, zip(*map(lambda x: x.shape, seq))) + pad = lambda x: torch.nn.functional.pad( + x, + (0, max_y - x.shape[1], 0, max_x - x.shape[0]), + value=padding_value, + ) + return torch.stack([pad(x) for x in seq]) + + +def 
audio_to_text_partial_neighbor_mask( + xlen, + ylen, + *, + past_tokens: int = 0, + future_tokens: int = 0, + device=None, + dtype=torch.bool, +): + """ + Build an (audio_len, text_len) boolean mask where True = allowed to attend. + Each audio frame (group g) can attend: + - all tokens of text group g (aligned word), + - last `past_tokens` tokens of text group g-1 (previous word), + - first `future_tokens` tokens of text group g+1 (next word). + + Args: + xlen (list[int]): token counts per text word (groups), e.g. [2,1,3] + ylen (list[int]): frame counts per audio word (aligned groups), e.g. [4,2,5] + past_tokens (int): allow up to this many tokens from end of previous word + future_tokens (int): allow up to this many tokens from start of next word + device: torch device + dtype: output dtype (bool by default) + + Returns: + mask: (A, T) boolean tensor (A = sum(ylen), T = sum(xlen)) + """ + if len(xlen) != len(ylen): + raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}") + if any(l <= 0 for l in xlen) or any(l <= 0 for l in ylen): + raise ValueError("All lengths must be positive.") + if past_tokens < 0 or future_tokens < 0: + raise ValueError("past_tokens and future_tokens must be >= 0.") + + n = len(xlen) + + # Text-side: group id per token and position within its group + x_groups = torch.arange(n, device=device).repeat_interleave( + torch.tensor(xlen, device=device) + ) # (T,) + pos_in_group = torch.cat([torch.arange(L, device=device) for L in xlen]) # (T,) + # tokens from the end (0 for last token, 1 for second-to-last, ...) + pos_from_end = torch.cat( + [torch.arange(L - 1, -1, -1, device=device) for L in xlen] + ) # (T,) + + T = x_groups.numel() + + # Audio-side: group id per frame + y_groups = torch.arange(n, device=device).repeat_interleave( + torch.tensor(ylen, device=device) + ) # (A,) + A = y_groups.numel() + + # Broadcast to (A, T) + G_audio = y_groups[:, None] # (A, 1) + G_text = x_groups[None, :] # (1, T) + + # Conditions: + # 1) aligned word: all tokens + aligned = G_text == G_audio + + # 2) previous word: last `past_tokens` tokens only + if past_tokens > 0: + prev_group = G_text == (G_audio - 1) + prev_tail = pos_from_end[None, :] < past_tokens + prev_ok = prev_group & prev_tail + else: + prev_ok = torch.zeros((A, T), dtype=torch.bool, device=device) + + # 3) next word: first `future_tokens` tokens only + if future_tokens > 0: + next_group = G_text == (G_audio + 1) + next_head = pos_in_group[None, :] < future_tokens + next_ok = next_group & next_head + else: + next_ok = torch.zeros((A, T), dtype=torch.bool, device=device) + + mask = (aligned | prev_ok | next_ok).to(dtype=dtype) + return mask + + +def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor: + _, ybound = map(lambda x: [0] + list(accumulate(x, int.__add__)), (xlen, ylen)) + lb, hb = [], [] + + for n, l, h in zip(xlen, ybound[:-1], ybound[1:]): + lb += [l] * n + hb += [h] * n + + lb, hb = map(torch.tensor, (lb, hb)) + if offset: + lb -= offset + hb += offset + + rge = torch.arange(ybound[-1]) + + lm = rge.unsqueeze(0) >= lb.unsqueeze(1) + hm = rge.unsqueeze(0) < hb.unsqueeze(1) + + return lm * hm + + +def topk_sampling(seq, k=1, temp=1.0): + topk = torch.topk(seq, k, dim=-1) + logits = seq / temp + mask = logits < topk.values[:, [-1]] + logits[mask] = -float("Inf") + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1) + + +def delay_rvq( + code, + head_token: int = -2, + tail_token: int = -3, +): + q, _ = code.shape + extension = 
torch.ones((q, q + 1)).tril() * head_token + extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token + extension = torch.flip(extension, (1,)) + extended_code = torch.cat((code, extension), axis=1) + for i in range(q): + extended_code[i, :] = torch.roll(extended_code[i, :], i + 1) + + return extended_code.long() + + +def undelay_rvq(extended_code): + q, _, n = extended_code.shape + out = [] + for i in range(q): + out.append(torch.roll(extended_code[i], -(i + 1), dims=1)) + out = torch.stack(out, dim=0) + return out[:, :, : -(q + 1)] + + +def sequence_mask(lengths, max_len=None, **kwargs): + batch_size = lengths.shape[0] + device = lengths.device + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) + mask = ids < lengths.unsqueeze(1).expand(-1, max_len) + + return mask diff --git a/tts/train_playhead.py b/tts/train_playhead.py new file mode 100644 index 0000000000000000000000000000000000000000..b89558b0558487583b8b048cf3f7b49a557a6665 --- /dev/null +++ b/tts/train_playhead.py @@ -0,0 +1,203 @@ +import math + +import hydra +import pytorch_lightning as ptl +import torch +from omegaconf import DictConfig +from super_monotonic_align import maximum_path +from torch.optim.lr_scheduler import LambdaLR + +from model.config import PlayHeadConfig +from playhead import PlayHead +from train_tts import TrainARTTS + + +def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr): + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr + + return lr_lambda + + +def expand(x, r): + b, n, d = x.shape + return x.unsqueeze(2).repeat(1, 1, r, 1).reshape(b, r * n, d) + + +class TrainPlayHead(ptl.LightningModule): + def __init__( + self, + tts_checkpoint_path: str, + playhead_config: PlayHeadConfig, + learning_rate: float = 5e-4, + end_learning_rate: float | None = None, + weight_decay: float = 0.1, + betas: tuple[float, float] = (0.9, 0.999), + n_warmup_steps: int = 500, + n_training_steps: int = 300000, + ): + super(TrainPlayHead, self).__init__() + + cfg = playhead_config + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.betas = betas + self.n_warmup_steps = n_warmup_steps + self.n_training_steps = n_training_steps + self.selected_cross_attention_heads = cfg.selected_cross_attention_heads + self.avg_pool_stride = cfg.avg_pool_stride + self.target_lag = cfg.target_lag + + self.save_hyperparameters() + + self.model = PlayHead(playhead_config) + tts_lightning_module = TrainARTTS.load_from_checkpoint(tts_checkpoint_path) + self.tts_model = tts_lightning_module.model.eval() + for p in self.tts_model.parameters(): + p.requires_grad = False + + def on_train_epoch_start(self): + if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"): + self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch) + + def save_model_weights_and_config( + self, + dir: str | None, + model_filename: str = "model.st", + config_filename: str = "config.json", + ): + # cfg = self.hparams.config + # Path(dir).mkdir(exist_ok=True) + # model_path = Path(dir) / model_filename + # save_file(self.model.state_dict(), model_path) + # with open(Path(dir) / config_filename, "w") as f: + # json.dump(asdict(cfg), f, indent=2) + pass + + def step(self, 
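+ # `batch` is assumed to come from the same datamodule as TrainARTTS: keys accessed with []
+ # below are required, while those fetched with .get() may be absent.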
batch, batch_idx: int, validation: bool = False): + text_token = batch["text_token"] + audio_token = batch["audio_token"].squeeze(2) + crossatt_mask = batch["crossatt_mask"] + text_rel_pos = batch["text_rel_pos"] + encoder_mask = batch["encoder_mask"] + stop_token = batch.get("stop_token") + text_stop_token = batch.get("text_stop_token") + crossatt_rel_pos = batch.get("crossatt_rel_pos") + logits_mask = batch["y_mask"] + + with torch.inference_mode(): + _ = self.tts_model( + text_ids=text_token, + audio_inputs=audio_token, + text_mask=encoder_mask, + audio_mask=logits_mask, + crossatt_mask=crossatt_mask, + crossatt_rel_pos=crossatt_rel_pos, + stop_tokens=stop_token, + text_rel_pos=text_rel_pos, + text_stop_tokens=text_stop_token, + ) + + atts = [] + + for l in self.tts_model.audio_decoder.decoder_layers: + if l.crossatt is not None: + atts.append(l.crossatt.att) + + num_sinks = self.tts_model.num_sink_tokens + selected_ca_heads = torch.stack( + [ + atts[i][:, j].transpose(-1, -2) + for i, j in self.selected_cross_attention_heads + ] + ) + + summed_ca = selected_ca_heads.sum(0) + + avg_pool_ca = torch.nn.functional.avg_pool1d( + summed_ca[:, num_sinks:].transpose(-1, -2), + self.avg_pool_stride, + stride=self.avg_pool_stride, + ceil_mode=True, + ).transpose(-1, -2) + + mas_from_avg_pool = maximum_path( + avg_pool_ca.clone(), + mask=crossatt_mask[:, :-1, :: self.avg_pool_stride].transpose(-1, -2), + ) + target = torch.arange(mas_from_avg_pool.shape[1]).to(mas_from_avg_pool.device) + if self.target_lag > 0: + lag = self.target_lag + mas_from_avg_pool = torch.roll(mas_from_avg_pool, lag, dims=2) + mas_from_avg_pool[:, 0, :lag] = 1.0 + mas_from_avg_pool[:, 1:, :lag] = 0.0 + # logits_mask[:, :lag] = False + target = (mas_from_avg_pool * target[:, None]).max(dim=1).values + + sink_ca = summed_ca[:, :num_sinks] + + input_ca = torch.cat((sink_ca, avg_pool_ca), dim=1) + target = target % self.model.cycle_len + + return self.model(input_ca, target, logits_mask[:, :-1]), input_ca, target + + def training_step(self, batch, idx): + losses, _, _ = self.step(batch, idx) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True) + total_loss += loss + self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def validation_step(self, batch, idx): + losses, _, _ = self.step(batch, idx) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True) + total_loss += loss + self.log("val_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def configure_optimizers(self): + params = [ + { + "params": self.model.parameters(), + "weight_decay": self.weight_decay, + } + ] + opt = torch.optim.AdamW( + params, + lr=self.learning_rate, + betas=self.betas, + ) + scheduler = LambdaLR( + opt, + lr_lambda=cosine_schedule_with_warmup( + warmup_steps=self.hparams.n_warmup_steps, + total_steps=self.hparams.n_training_steps, + start_lr=self.hparams.learning_rate, + end_lr=self.hparams.learning_rate * 0.1, + ), + ) + return [opt], [{"scheduler": scheduler, "interval": "step"}] + + +@hydra.main(config_path="playhead_configs/", config_name="config", version_base="1.3") +def main(cfg: DictConfig): + ptl.seed_everything(cfg.seed_everything) + + model = hydra.utils.instantiate(cfg.model) + cfg.experiment_name = f"PlayHead" + datamodule = hydra.utils.instantiate(cfg.data) + trainer = hydra.utils.instantiate(cfg.trainer) + + trainer.fit(model, datamodule, 
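+ # ckpt_path is optional: when set in the Hydra config, Lightning resumes training from that checkpoint.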
ckpt_path=cfg.get("ckpt_path")) + + +if __name__ == "__main__": + main() diff --git a/tts/train_tts.py b/tts/train_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..76fa749b0c257d52ee5521adaa84cfbf0a92bad0 --- /dev/null +++ b/tts/train_tts.py @@ -0,0 +1,227 @@ +import json +import math +from dataclasses import asdict +from pathlib import Path + +import hydra +import numpy as np +import pytorch_lightning as ptl +import torch +from omegaconf import DictConfig, ListConfig, OmegaConf +from safetensors.torch import save_file +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from transformers import get_cosine_schedule_with_warmup + +from .model.config import TTSConfig +from .model.prediction_head import VelocityHead +from .tts import ARTTSModel + + +def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr): + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr + + return lr_lambda + + +class TrainARTTS(ptl.LightningModule): + def __init__( + self, + config: TTSConfig, + quant_layer: list[int], + tie_embed: bool = False, + learning_rate: float = 5e-4, + end_learning_rate: float | None = None, + weight_decay: float = 0.1, + betas: tuple[float, float] = (0.9, 0.999), + n_warmup_steps: int = 500, + n_training_steps: int = 300000, + mask_text_p: float = 0.0, + load_weights: str | None = None, + stop_token_weight: float | None = None, + stop_loss_factor: float = 0.1, + stop_loss_warmup: tuple[int, int] | None = None, + ): + super(TrainARTTS, self).__init__() + + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.betas = betas + self.n_warmup_steps = n_warmup_steps + self.n_training_steps = n_training_steps + self.stop_token_weight = stop_token_weight + self.stop_loss_factor = stop_loss_factor + + self.save_hyperparameters() + + self.model = ARTTSModel(config) + + if load_weights is not None: + model = torch.load(load_weights) + self.load_state_dict(model["state_dict"], strict=False) + + def on_train_epoch_start(self): + if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"): + self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch) + + def save_model_weights_and_config( + self, + dir: str | None, + model_filename: str = "model.st", + config_filename: str = "config.json", + ): + def to_builtin(obj): + if isinstance(obj, dict): + return {k: to_builtin(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [to_builtin(v) for v in obj] + elif isinstance(obj, ListConfig): + return [to_builtin(v) for v in obj] + elif isinstance(obj, DictConfig): + return {k: to_builtin(v) for k, v in obj.items()} + else: + return obj + + cfg = asdict(self.hparams.config) + cfg = to_builtin(cfg) + for k, v in cfg.items(): + if v is ListConfig: + print("here") + cfg[k] = OmegaConf.to_container(v, resolve=True) + Path(dir).mkdir(exist_ok=True) + model_path = Path(dir) / model_filename + save_file(self.model.state_dict(), model_path) + with open(Path(dir) / config_filename, "w") as f: + json.dump(cfg, f, indent=2) + + def step(self, batch, batch_idx: int, validation: bool = False): + text_token = batch["text_token"] + audio_token = batch["audio_token"].squeeze(2) + crossatt_mask = batch.get("crossatt_mask") + text_rel_pos = batch.get("text_rel_pos") + 
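+ # Entries fetched with .get() below may be None; the loss computation further down
+ # guards against missing masks with explicit None checks.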
encoder_mask = batch.get("encoder_mask") + stop_token = batch.get("stop_token") + text_stop_token = batch.get("text_stop_token") + crossatt_rel_pos = batch.get("crossatt_rel_pos") + logits_mask = batch.get("y_mask") + + pre_logits = self.model( + text_ids=text_token, + audio_inputs=audio_token, + text_mask=encoder_mask, + audio_mask=logits_mask, + crossatt_mask=crossatt_mask, + crossatt_rel_pos=crossatt_rel_pos, + stop_tokens=stop_token, + text_rel_pos=text_rel_pos, + text_stop_tokens=text_stop_token, + ) + losses = {} + if validation and type(self.model.prediction_head) is DiffusionHead: + # deterministic time conditioning during validation + t = ( + torch.ones(pre_logits.shape[0], device=pre_logits.device) + * batch_idx + / self.trainer.num_val_batches[0] + ) + losses |= self.model.prediction_head.compute_loss( + pre_logits, + audio_token[:, 1:], + mask=logits_mask[:, 1:] if logits_mask is not None else None, + t=t, + ) + else: + losses |= self.model.prediction_head.compute_loss( + pre_logits, + audio_token[:, 1:], + mask=logits_mask[:, 1:] if logits_mask is not None else None, + ) + + if self.model.stop_prediction_head is not None and logits_mask is not None: + if stop_token is None: + stop_token = nn.functional.pad( + (~logits_mask)[:, 2:].to(pre_logits), (0, 1) + ) + else: + stop_token = stop_token[:, 1:] + mask = logits_mask[:, 1:] + losses |= self.model.stop_prediction_head.compute_loss( + pre_logits[mask], + stop_token[mask], + ) + + return losses + + def training_step(self, batch, idx): + losses = self.step(batch, idx) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True) + if "stop" in name: + if self.hparams.stop_loss_warmup is not None: + alpha, beta = self.hparams.stop_loss_warmup + warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0) + else: + warmup = 1.0 + loss *= self.stop_loss_factor * warmup + total_loss += loss + self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def validation_step(self, batch, idx): + losses = self.step(batch, idx, validation=True) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True) + total_loss += loss + self.log("val_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def configure_optimizers(self): + params = [ + { + "params": self.model.parameters(), + "weight_decay": self.weight_decay, + } + ] + opt = torch.optim.AdamW( + params, + lr=self.learning_rate, + betas=self.betas, + ) + # scheduler = get_cosine_schedule_with_warmup( + # opt, + # num_warmup_steps=self.n_warmup_steps, + # num_training_steps=self.n_training_steps, + # ) + scheduler = LambdaLR( + opt, + lr_lambda=cosine_schedule_with_warmup( + warmup_steps=self.hparams.n_warmup_steps, + total_steps=self.hparams.n_training_steps, + start_lr=self.hparams.learning_rate, + end_lr=self.hparams.learning_rate * 0.1, + ), + ) + return [opt], [{"scheduler": scheduler, "interval": "step"}] + + +@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3") +def main(cfg: DictConfig): + ptl.seed_everything(cfg.seed_everything) + + model = hydra.utils.instantiate(cfg.model) + cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}" + datamodule = hydra.utils.instantiate(cfg.data) + trainer = hydra.utils.instantiate(cfg.trainer) + + trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path")) + + +if __name__ == "__main__": + main() diff --git a/tts/tts.py b/tts/tts.py 
new file mode 100644 index 0000000000000000000000000000000000000000..b082b523ad655f465f227d6b8d4162ccfcca7366 --- /dev/null +++ b/tts/tts.py @@ -0,0 +1,497 @@ +import json +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from safetensors.torch import load_file +from torch import nn +from tqdm import tqdm + +from tts.model.config import TTSConfig +from tts.model.prediction_head import (ContinuousHead, LogitsHead, + StopPredictionHead, VelocityHead) +from tts.model.registry import DECODER_REGISTRY, ENCODER_REGISTRY +from tts.tools import path_matrix, widen_alignment + + +def collect_heads(cache, selected_heads, last=True): + if last: + return torch.cat( + [ + cache[layer]["crossatt_weights"][:, [head], -1] + for layer, head in selected_heads + ], + dim=1, + )[:, :, None] + else: + return torch.cat( + [ + cache[layer]["crossatt_weights"][:, [head]] + for layer, head in selected_heads + ], + dim=1, + ) + + +def mask_from_abs_pos(abs_pos, text_len, expand_factor, width=(5, 1)): + exp_ca_mask = path_matrix(abs_pos, text_len) + exp_ca_mask = widen_alignment(exp_ca_mask, width=width, axis="S") + exp_ca_mask = expand(exp_ca_mask, expand_factor) + return exp_ca_mask + + +def expand(x, r): + b, n, d = x.shape + x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d) + return x + + +class ARTTSModel(nn.Module): + def __init__(self, cfg: TTSConfig): + super().__init__() + self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim) + if cfg.audio_input_type == "discrete": + self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim) + self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size) + elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion: + self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim) + self.prediction_head = VelocityHead( + cfg.decoder_cfg.dim, + cfg.audio_embed_size, + cfg.diffusion_head_num_layers, + ) + elif cfg.audio_input_type == "continuous": + self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim) + self.prediction_head = ContinuousHead( + cfg.decoder_cfg.dim, + cfg.audio_embed_size, + ) + + self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg) + self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg) + + self.stop_token_embd = None + self.stop_prediction_head = None + + if cfg.stop_prediction_head: + if cfg.stop_token_embd: + self.stop_token_embd = nn.Embedding(2, cfg.dim, padding_idx=0) + self.stop_prediction_head = StopPredictionHead(cfg.dim) + + if cfg.num_sink_tokens > 0: + self.sink_tokens = nn.Parameter( + torch.randn(cfg.num_sink_tokens, cfg.dim) * 0.02, requires_grad=True + ) + else: + self.sink_tokens = None + + self.disabled_crossatt_head_idx = cfg.disabled_crossatt_head_idx + + @property + def num_sink_tokens(self): + if self.sink_tokens is None: + return 0 + else: + n_sink, _ = self.sink_tokens.shape + return n_sink + + @classmethod + def instantiate_from_config(cls, config): + for k in config.keys(): + if k == "decoder_cfg": + config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k]) + if k == "encoder_cfg": + config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k]) + config = TTSConfig(**config) + return ARTTSModel(config), config + + @classmethod + def from_pretrained_local( + cls, + path: str, + config_filename: str = "config.json", + model_filename: str = "model.st", + device: str = "cpu", + ): + with open(Path(path) / config_filename, "r") as f: + config = json.load(f) + model, config = 
cls.instantiate_from_config(config) + state_dict = load_file(Path(path) / model_filename, device=device) + model.load_state_dict(state_dict) + return model + + def _get_query(self, x: torch.Tensor, *args): + input_audio_embd = self.audio_embd(x) + return self.audio_decoder._get_query(input_audio_embd, *args) + + def forward( + self, + text_ids: torch.LongTensor, + audio_inputs: torch.Tensor, + text_mask: torch.Tensor | None = None, + audio_mask: torch.Tensor | None = None, + stop_tokens: torch.Tensor | None = None, + text_stop_tokens: torch.Tensor | None = None, + text_rel_pos: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + crossatt_rel_pos: torch.Tensor | None = None, + n_first_layers: int | None = None, + cache: Any | None = None, + ): + input_text_embd = self.text_embd(text_ids) + input_audio_embd = self.audio_embd(audio_inputs[:, :-1]) + if self.stop_token_embd is not None: + if stop_tokens is not None: + stop_tokens_embd = self.stop_token_embd(stop_tokens) + input_audio_embd += stop_tokens_embd[:, :-1] + + text_hidden_states = self.text_encoder( + input_text_embd, + mask=text_mask, + text_rel_pos=text_rel_pos, + ) + if self.disabled_crossatt_head_idx is not None and crossatt_mask is not None: + crossatt_mask_list = [] + n_sink, _ = self.sink_tokens.shape + for layer in self.audio_decoder.decoder_layers: + if layer.crossatt is not None: + h = layer.crossatt.heads + crossatt_layer_mask = ( + crossatt_mask.unsqueeze(1).repeat(1, h, 1, 1).clone() + ) + crossatt_layer_mask = torch.nn.functional.pad( + crossatt_layer_mask, + (n_sink, 0), + value=True, + ) + crossatt_mask_list.append(crossatt_layer_mask[:, :, :-1]) + + else: + crossatt_mask_list.append(None) + + for layer, head in self.disabled_crossatt_head_idx: + crossatt_mask_list[layer][:, head, :, n_sink:] = False + + crossatt_mask = crossatt_mask_list + else: + if self.sink_tokens is not None: + n_sink, _ = self.sink_tokens.shape + if crossatt_mask is not None: + crossatt_mask = torch.nn.functional.pad( + crossatt_mask, + (n_sink, 0), + value=True, + ) + crossatt_mask = crossatt_mask[:, :-1] + if self.sink_tokens is not None: + sink_tokens = self.sink_tokens[None, :].repeat( + text_hidden_states.shape[0], 1, 1 + ) + text_hidden_states = torch.cat( + (sink_tokens, text_hidden_states), + dim=1, + ) + + if n_first_layers is not None: + pre_logits = self.audio_decoder.forward_first_n_layers( + text_hidden_states, + input_audio_embd, + n_first_layers, + crossatt_mask=crossatt_mask, + ) + else: + pre_logits = self.audio_decoder( + text_hidden_states, + input_audio_embd, + crossatt_mask=crossatt_mask, + cache=cache, + ) + return pre_logits + + def generate( + self, + text_ids: torch.LongTensor, + prefix: torch.Tensor, + text_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + text_rel_pos: torch.LongTensor | None = None, + teacher_force: torch.Tensor | None = None, + unfold_ref: bool = False, + max_seq_len: int = 200, + device: str = "cuda", + sampling_params: dict | None = None, + stop_threshold: float = 0.5, + cache: Any | None = None, + do_not_stop: bool = False, + ): + if sampling_params is None: + sampling_params = {} + if text_ids.ndim == 1: + text_ids = text_ids.unsqueeze(0) + + batch_size = text_ids.shape[0] + + input_text_embd = self.text_embd(text_ids) + text_hidden_states = self.text_encoder( + input_text_embd, + text_rel_pos=text_rel_pos, + mask=text_mask, + ) + prefix_embd = self.audio_embd(prefix) + + if self.sink_tokens is not None: + sink_tokens = self.sink_tokens[None, 
:].repeat( + text_hidden_states.shape[0], 1, 1 + ) + text_hidden_states = torch.cat( + (sink_tokens, text_hidden_states), + dim=1, + ) + if crossatt_mask is not None: + n_sink, _ = self.sink_tokens.shape + crossatt_mask = torch.nn.functional.pad( + crossatt_mask, + (n_sink, 0), + value=True, + ) + + if cache is None: + cache = self.audio_decoder.init_cache( + max_seq_len + prefix_embd.shape[1], device + ) + + stop_status = torch.zeros(batch_size, device=device).bool() + stop_idx = torch.ones(batch_size, device=device).long()*max_seq_len + + preds = [] + pre_prediction = self.audio_decoder.prefill( + text_hidden_states, + prefix_embd, + cache=cache, + ) + prediction = self.prediction_head.predict( + pre_prediction[:, [-1]], **sampling_params + ) + prediction_embd = self.audio_embd(prediction) + + for i in tqdm(range(max_seq_len)): + pre_prediction = self.audio_decoder.decode_one( + text_hidden_states, + prediction_embd, + cache, + crossatt_mask=crossatt_mask, + ) + if unfold_ref: + pre_prediction, pre_prediction_ref = pre_prediction.chunk(2) + else: + pre_prediction_ref = None + prediction = self.prediction_head.predict(pre_prediction, + pre_prediction_ref=pre_prediction_ref, + **sampling_params,) + prediction_embd = self.audio_embd(prediction) + if unfold_ref: + prediction_embd = prediction_embd.repeat(2, 1, 1) + if teacher_force is not None: + b, n, d = teacher_force.shape + if i < n: + prediction_embd = self.audio_embd(teacher_force[:, [i]]) + + preds.append(prediction) + if self.stop_prediction_head is not None: + stop_pred = self.stop_prediction_head(pre_prediction).squeeze(1,2) + stop_signal = stop_pred > stop_threshold + stop_status += stop_signal + stop_idx[stop_signal * stop_idx > i] = i + if stop_status.prod(): + if self.stop_token_embd is not None: + st_embd = self.stop_token_embd( + torch.ones(1, 1, device=device).int() + ) + prediction_embd += st_embd + if not do_not_stop: + break + else: + print(f"STOP: {i}") + + full_prediction = torch.cat(preds, dim=1) + full_prediction = [x[:stop_idx[i]][None] for i, x in enumerate(full_prediction.unbind())] + + return cache, full_prediction + + +""" + def generate_with_playhead( + self, + text_ids: torch.LongTensor, + prefix: torch.Tensor, + playhead_model: PlayHead, + selected_heads_idx: list[tuple[int, int]], + text_stop_tokens: torch.LongTensor | None = None, + text_mask: torch.Tensor | None = None, + text_rel_pos: torch.LongTensor | None = None, + teacher_force: torch.Tensor | None = None, + max_seq_len: int = 200, + device: str = "cuda", + sampling_params: dict | None = None, + stop_threshold: float = 0.5, + do_not_stop: bool = False, + width: tuple[int, int] = (5, 1), + abs_pos_start: int = 0, + stop_end_distance_threshold: int = 5, + ): + if sampling_params is None: + sampling_params = {} + if text_ids.ndim == 1: + text_ids = text_ids.unsqueeze(0) + + input_text_embd = self.text_embd(text_ids) + if self.text_stop_token_embd is not None: + if text_stop_tokens is not None: + text_stop_tokens_embd = self.text_stop_token_embd(text_stop_tokens) + input_text_embd += text_stop_tokens_embd + text_hidden_states = self.text_encoder( + input_text_embd, + text_rel_pos=text_rel_pos, + mask=text_mask, + ) + prefix_embd = self.audio_embd(prefix) + + text_len = text_hidden_states.shape[1] + + if self.sink_tokens is not None: + sink_tokens = self.sink_tokens[None, :].repeat( + text_hidden_states.shape[0], 1, 1 + ) + text_hidden_states = torch.cat( + (sink_tokens, text_hidden_states), + dim=1, + ) + + cache = 
self.audio_decoder.init_cache(max_seq_len, device) + preds = [] + pre_prediction = self.audio_decoder.prefill( + text_hidden_states, + prefix_embd, + cache=cache, + ) + text_freqs = None + prediction = self.prediction_head.predict( + pre_prediction[:, [-1]], **sampling_params + ) + prediction_embd = self.audio_embd(prediction) + preds.append(prediction) + + playhead_cache = playhead_model.init_cache() + previous_position = torch.zeros(1, device=device) + abs_pos = torch.ones(1, 1, device=device).long() * abs_pos_start + + selected_heads_frame = collect_heads(cache, selected_heads_idx, last=False) + selected_heads_frame = selected_heads_frame.sum(1).transpose(-1, -2) + + pos_preds = [] + steps = [] + expand_crossatt_mask = [] + + for i in tqdm(range(selected_heads_frame.shape[2])): + pred, step = playhead_model.predict( + selected_heads_frame[..., [i]], + cache=playhead_cache, + previous_position=previous_position, + ) + previous_position = pred + abs_pos += step + + pos_preds.append(pred) + steps.append(step) + exp_ca_mask = mask_from_abs_pos( + abs_pos, + (text_len // playhead_model.avg_pool_stride) + 1, + playhead_model.avg_pool_stride, + width=width, + ) + exp_ca_mask = torch.nn.functional.pad( + exp_ca_mask, (self.num_sink_tokens, 0), value=True + ).bool()[..., : text_len + self.num_sink_tokens] + expand_crossatt_mask.append(exp_ca_mask) + + print("starting at: ", abs_pos.item()) + + # pos_pred, step = playhead_model.predict( + # selected_heads_frame, + # cache=playhead_cache, + # previous_position=previous_position, + # ) + # previous_position = pos_pred[:, [-1]] + # abs_pos += step + # exp_ca_mask = mask_from_abs_pos( + # abs_pos, + # (text_len // playhead_model.avg_pool_stride) + 1, + # playhead_model.avg_pool_stride, + # width=width, + # ) + # expand_crossatt_mask.append(exp_ca_mask) + # steps.append(step) + # pos_preds.append(pos_pred) + + progress_bar = tqdm(range(max_seq_len)) + for i in progress_bar: + pre_prediction = self.audio_decoder.decode_one( + text_hidden_states, + prediction_embd, + cache, + # text_freqs=text_freqs, + crossatt_mask=exp_ca_mask, + ) + prediction = self.prediction_head.predict(pre_prediction, **sampling_params) + prediction_embd = self.audio_embd(prediction) + if teacher_force is not None: + b, n, d = teacher_force.shape + if i < n: + prediction_embd = self.audio_embd(teacher_force[:, [i]]) + + ### PLAYHEAD ======================== + selected_heads_frame = ( + collect_heads(cache, selected_heads_idx).sum(1).transpose(-1, -2) + ) + pos_pred, step = playhead_model.predict( + selected_heads_frame, + cache=playhead_cache, + previous_position=previous_position, + ) + previous_position = pos_pred + abs_pos += step + exp_ca_mask = mask_from_abs_pos( + abs_pos, + (text_len // playhead_model.avg_pool_stride) + 1, + playhead_model.avg_pool_stride, + width=width, + ) + exp_ca_mask = torch.nn.functional.pad( + exp_ca_mask, (self.num_sink_tokens, 0), value=True + ).bool()[..., : text_len + self.num_sink_tokens] + expand_crossatt_mask.append(exp_ca_mask) + steps.append(step) + pos_preds.append(pos_pred) + # ================================= + preds.append(prediction) + if self.stop_prediction_head is not None: + stop_pred = self.stop_prediction_head(pre_prediction) + if stop_pred > stop_threshold: + dist = np.abs( + abs_pos.cpu().item() * playhead_model.avg_pool_stride - text_len + ) + progress_bar.set_postfix( + {"stop": f"pos: {abs_pos.cpu().item()}; dist{dist}"} + ) + if dist < stop_end_distance_threshold and not do_not_stop: + break + + 
progress_bar.set_postfix({"position": abs_pos.cpu().item()}) + + full_prediction = torch.cat(preds, dim=1) + expand_crossatt_mask = torch.cat(expand_crossatt_mask, dim=1) + print(expand_crossatt_mask.shape) + + return cache, full_prediction, expand_crossatt_mask, steps, pos_preds +""" diff --git a/tts/tts/groupdataset.py b/tts/tts/groupdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a7633636e9c0fc02dfaf501d694315990554bd --- /dev/null +++ b/tts/tts/groupdataset.py @@ -0,0 +1,835 @@ +import math +import random +from abc import ABC, abstractmethod +from bisect import bisect_left +from functools import reduce +from itertools import accumulate +from os import wait +from random import choice, choices, randint, sample +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import pytorch_lightning as ptl +import torch +from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk +from einops import rearrange, repeat +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import BatchSampler, DataLoader, Sampler, SubsetRandomSampler +from transformers import PreTrainedTokenizerFast + +from tools import packmask_2d, pad_2d_sequence, sequence_mask + + +def delay_rvq( + code, + head_token: int = -2, + tail_token: int = -3, +): + q, _ = code.shape + extension = torch.ones((q, q + 1)).tril() * head_token + extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token + extension = torch.flip(extension, (1,)) + extended_code = torch.cat((code, extension), axis=1) + for i in range(q): + extended_code[i, :] = torch.roll(extended_code[i, :], i + 1) + + return extended_code.long() + + +class BucketSampler(Sampler[List[int]]): + def __init__( + self, + buckets: List[List[int]], + batch_sizes: List[int] | int, + bucket_sampling_weights: List[Tuple[float]] = None, + drop_last: bool = True, + distributed: bool = True, # TODO - implement not distributed as well + sample_bucket: Optional[int] = None, + seed: int = 123, + epoch_seed: bool = True, + ): + if type(batch_sizes) is int: + batch_sizes = [batch_sizes] * len(buckets) + else: + assert len(buckets) == len(batch_sizes) + + if bucket_sampling_weights is not None: + assert len(bucket_sampling_weights) == len(batch_sizes) + self.bucket_sampling_weights = bucket_sampling_weights + self.num_replicas = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + self.buckets = [ + b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas] + for b in buckets + ] + self.num_samples = [len(b) // self.num_replicas for b in buckets] + self.batch_sizes = batch_sizes + self.total_sizes = [ + ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes) + ] + self.drop_last = drop_last + self.seed = seed + self.epoch = 0 + self.sample_bucket = sample_bucket + self.epoch_seed = epoch_seed + self.batch_size = batch_sizes[0] + + def set_epoch(self, epoch: int): + self.epoch = epoch + + def __len__(self): + return sum(self.total_sizes) + + def __iter__(self): + generator = torch.Generator() + generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank) + pool = [ + BatchSampler( + SubsetRandomSampler(b, generator=generator), + bs, + drop_last=self.drop_last, + ) + for b, bs in zip(self.buckets, self.batch_sizes) + ] + pool = [iter(b) for b in pool] + weights = ( + [w for w in self.bucket_sampling_weights] + if self.bucket_sampling_weights is not None + else None + ) + while pool: # sample until all buckets are done + idx, 
bucket = choices(list(enumerate(pool)), weights=weights)[0] + try: + batch = next(bucket) + yield batch + except StopIteration: + pool.pop(idx) # if bucket is done, throw it + if weights is not None: + weights.pop(idx) + + +class DatasetFactory(ABC): + @abstractmethod + def build(self): + pass + + +def to_segments_dataset(dataset): + dataset = dataset.map( + lambda x, y: {"segments": [{"start": 0.0, "end": y.shape[-1] / 75, "text": x}]}, + input_columns=["text", "audio_token"], + ) + return dataset + + +class LibriTTS_SyntheticContinuation(DatasetFactory): + def __init__( + self, path, min_dur=0.0, max_dur=30.0, max_cer=0.1, llm_long=True, repeat=1 + ): + self.path = path + self.min_dur = min_dur + self.max_dur = max_dur + self.max_cer = max_cer + self.llm_long = llm_long + + def build(self): + ds = load_from_disk(self.path) + ds = ds.filter(lambda x: x < self.max_cer, input_columns="cer") + ds = ds.map( + lambda x: {"audio_duration": len(x[0][0]) / 75}, input_columns="audio_token" + ) + ds = ds.filter( + lambda x: x < self.max_dur and x > self.min_dur, + input_columns="audio_duration", + ).select_columns(["segments", "audio_token", "audio_duration"]) + if self.llm_long: + llm_long = ( + LibriLight_Medium_Long_WavTokenizer() + .build()["train"] + .shuffle() + .take(100000) + ) + llm_long = llm_long.map( + lambda x: { + "segments": [ + {k: v[0] if type(v) is list else v for k, v in xx.items()} + for xx in x + ] + }, + input_columns="segments", + ) + # splits = ["train.clean.100", "train.clean.360", "train.other.500"] + # ltts = load_dataset("theodorr/lttsr_wavtokenizer", split=splits) + load_dataset("theodorr/ltts_wavtokenizer", split=splits) + # ltts = [x.map(lambda x: {"audio_duration": len(x[0][0])/75}, input_columns="audio_token").select_columns(["text_normalized", "audio_token", "audio_duration"]).rename_column("text_normalized", "text").filter(lambda dur: dur > self.min_dur and dur < self.max_dur, input_columns="audio_duration") for x in ltts] + # mls10k_nemo = load_dataset("theodorr/mls10k_nemo", split="train").select_columns(["text", "audio_duration"]) + # mls10k_wavtokenizer = load_dataset("theodorr/mls10k_wavtokenizer", split="train").select_columns(["audio_token"]) + # mls = concatenate_datasets([mls10k_nemo, mls10k_wavtokenizer], axis=1).shuffle().take(100000) + ds = concatenate_datasets([ds, llm_long], axis=0).select_columns( + ["segments", "audio_token", "audio_duration"] + ) + + # side_datasets = concatenate_datasets(ltts, axis=0) + # side_datasets = to_segments_dataset(side_datasets) + + ds = DatasetDict({"train": ds}) + + return ds + + +class LibriLight_Medium_Long_WavTokenizer(DatasetFactory): + def __init__(self, min_dur=40.0, max_dur=60.0, min_segments=2): + self.min_dur = min_dur + self.max_dur = max_dur + self.min_segments = min_segments + + def build(self): + ds = load_dataset("theodorr/llm_nemo_wavtokenizer_long") + ds = ds.filter( + lambda x: x < self.max_dur and x > self.min_dur, + input_columns="audio_duration", + ) + ds = ds.filter(lambda x: len(x) >= self.min_segments, input_columns="segments") + return ds + + +class MLS_LTTS_SelfCheckpoint_WavTokenizer(DatasetFactory): + def __init__(self, path, min_dur=3.0, max_dur=26.0): + self.min_dur = min_dur + self.max_dur = max_dur + self.text_dur_path = path + + def build(self): + mls10k_nemo = load_dataset( + "theodorr/mls10k_nemo", split="train" + ).select_columns(["text", "audio_duration"]) + mls10k_wavtokenizer = load_dataset( + "theodorr/mls10k_wavtokenizer", split="train" + ).select_columns(["audio_token"]) + 
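        # The three sources are concatenated column-wise below
        # (concatenate_datasets(..., axis=1)), which assumes they have the same
        # number of rows in matching order.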
mls10k_text_dur = load_from_disk(self.text_dur_path + "/mls10knemo_align_ds") + mls = concatenate_datasets( + [mls10k_nemo, mls10k_wavtokenizer, mls10k_text_dur], axis=1 + ) + + splits = ["train.clean.100", "train.clean.360", "train.other.500"] + + ltts = load_dataset("theodorr/ltts_wavtokenizer") + lttsr = load_dataset("theodorr/lttsr_wavtokenizer") + + ltts_text_dur = load_from_disk(self.text_dur_path + "/ltts_align_ds") + lttsr_text_dur = load_from_disk(self.text_dur_path + "/lttsr_align_ds") + + ltts = [ + concatenate_datasets([ltts[split], ltts_text_dur[split]], axis=1) + for split in splits + ] + lttsr = [ + concatenate_datasets([lttsr[split], lttsr_text_dur[split]], axis=1) + for split in splits + ] + + ltts = ltts + lttsr + + ltts = [ + x.map( + lambda x: {"audio_duration": len(x[0][0]) / 75}, + input_columns="audio_token", + ) + .select_columns( + [ + "text_normalized", + "audio_token", + "audio_duration", + "text_token_duration", + ] + ) + .rename_column("text_normalized", "text") + .filter( + lambda dur: dur > self.min_dur and dur < self.max_dur, + input_columns="audio_duration", + ) + for x in ltts + ] + + mls = concatenate_datasets([mls] + ltts, axis=0) + + return DatasetDict({"train": mls}) + + +class MLS_SelfCheckpoint_WavTokenizer(DatasetFactory): + def __init__(self, path, min_dur=3.0, max_dur=26.0): + self.min_dur = min_dur + self.max_dur = max_dur + self.text_dur_path = path + + def build(self): + mls10k_nemo = load_dataset( + "theodorr/mls10k_nemo", split="train" + ).select_columns(["text", "audio_duration"]) + mls10k_wavtokenizer = load_dataset( + "theodorr/mls10k_wavtokenizer", split="train" + ).select_columns(["audio_token"]) + mls10k_text_dur = load_from_disk(self.text_dur_path) + mls = concatenate_datasets( + [mls10k_nemo, mls10k_wavtokenizer, mls10k_text_dur], axis=1 + ) + return DatasetDict({"train": mls}) + + +class Synthetic_Expresso(DatasetFactory): + def __init__( + self, + path, + min_dur=0.0, + max_dur=23.0, + keep_ltts=False, + cfg_scale: list[float] = None, + ): + self.min_dur = min_dur + self.max_dur = max_dur + self.path = path + self.keep_ltts = keep_ltts + self.cfg_scale = cfg_scale + + def build(self): + ds = load_from_disk(self.path) + if self.cfg_scale is not None: + self.cfg_scale.sort() + print(self.cfg_scale) + ds = ds.map( + lambda x: {"abs_style_intensity": bisect_left(self.cfg_scale, x)}, + input_columns="cfg", + ) + ds = ds.filter( + lambda x: x < self.max_dur and x > self.min_dur, + input_columns="audio_duration", + ) + if self.keep_ltts: + splits = ["train.clean.360"] + ltts = load_dataset( + "theodorr/lttsr_wavtokenizer", split=splits + ) + load_dataset("theodorr/ltts_wavtokenizer", split=splits) + ltts = [ + x.map( + lambda x: {"audio_duration": len(x[0][0]) / 75}, + input_columns="audio_token", + ) + .select_columns(["text_normalized", "audio_token", "audio_duration"]) + .rename_column("text_normalized", "text") + .filter( + lambda dur: dur > self.min_dur and dur < self.max_dur, + input_columns="audio_duration", + ) + for x in ltts + ] + ds = concatenate_datasets([ds, *ltts], axis=0) + + return {"train": ds} + + +class HiFiTTS2_AudioLatent(DatasetFactory): + def __init__( + self, + path: str = "hifitts2_vae8_dataset", + min_dur: float = 3.0, + max_dur: float = 20.1, + framerate: float = 25.0, + ): + self.min_dur = min_dur + self.max_dur = max_dur + self.path = path + self.framerate = framerate + + def build(self): + dataset = load_from_disk(self.path).with_format("torch") + dataset = dataset.map( + lambda x: {"audio_duration": 
x.shape[1] / self.framerate}, + input_columns="audio_latent", + ).filter( + lambda dur: dur > self.min_dur and dur < self.max_dur, + input_columns="audio_duration", + ) + return DatasetDict({"train": dataset}) + + +class MLS_LibriTTS_WavTokenizer(DatasetFactory): + def __init__(self, min_dur=3.0, max_dur=20.1): + self.min_dur = min_dur + self.max_dur = max_dur + + def build(self): + splits = ["train.clean.100", "train.clean.360", "train.other.500"] + ltts = load_dataset("theodorr/lttsr_wavtokenizer", split=splits) + lttsr = load_dataset("theodorr/ltts_wavtokenizer", split=splits) + ltts += lttsr + ltts = [ + x.map( + lambda x: {"audio_duration": len(x[0][0]) / 75}, + input_columns="audio_token", + ) + .select_columns(["text_normalized", "audio_token", "audio_duration"]) + .rename_column("text_normalized", "text") + .filter( + lambda dur: dur > self.min_dur and dur < self.max_dur, + input_columns="audio_duration", + ) + for x in ltts + ] + mls10k_nemo = load_dataset( + "theodorr/mls10k_nemo", split="train" + ).select_columns(["text", "audio_duration"]) + mls10k_wavtokenizer = load_dataset( + "theodorr/mls10k_wavtokenizer", split="train" + ).select_columns(["audio_token"]) + mls = concatenate_datasets([mls10k_nemo, mls10k_wavtokenizer], axis=1) + return DatasetDict({"train": concatenate_datasets([mls] + ltts, axis=0)}) + + +# def standalone_collate(batch, tokenizer): +# audio_token, text = zip(*[(x["audio_token"], x["text"]) for x in batch]) +# #audio_token = [ +# # delay_rvq( +# # x.squeeze() + 3, +# # head_token=1, +# # tail_token=2, +# # ).T +# # for x in audio_token +# # ] +# text_token = [torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text] +# xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token)) +# x_mask, y_mask = map(lambda x: sequence_mask(x, device="cpu"), (torch.tensor(xlen), torch.tensor(ylen))) +# +# audio_token, text_token = map(lambda x: pad_sequence(x, batch_first=True, padding_value=0), (audio_token, text_token)) +# +# encoder_mask = (x_mask.unsqueeze(1) * x_mask.unsqueeze(2)) +# crossatt_mask = (x_mask.unsqueeze(1) * y_mask.unsqueeze(2)) +# +# +# +# return { +# "text_token": text_token, +# "audio_token": audio_token, +# "crossatt_mask": crossatt_mask, +# "encoder_mask": encoder_mask, +# "y_mask": y_mask, +# "x_len": xlen, +# "y_len": ylen, +# } +def standalone_collate_latent(batch, tokenizer, abs_style_intensity=False): + audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch]) + audio_latent = [x.squeeze() for x in audio_latent] + text_token = [ + torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text + ] + xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent)) + x_mask, y_mask = map( + lambda x: sequence_mask(x, device="cpu"), + (torch.tensor(xlen), torch.tensor(ylen)), + ) + + audio_latent, text_token = map( + lambda x: pad_sequence(x, batch_first=True, padding_value=0.0), + (audio_latent, text_token), + ) + + encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2) + crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2) + + if abs_style_intensity: + abs_style_intensity = [x["abs_style_intensity"] for x in batch] + abs_style_intensity = [ + torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity + ] + abs_style_intensity = torch.stack(abs_style_intensity) + else: + abs_style_intensity = None + + return { + "text_token": text_token, + "audio_token": audio_latent, + "crossatt_mask": crossatt_mask, + "encoder_mask": encoder_mask, + "y_mask": 
y_mask, + "x_len": xlen, + "y_len": ylen, + "abs_style_intensity": abs_style_intensity, + "text_rel_pos": None, + "audio_rel_pos": None, + "crossatt_pos": None, + } + + +def standalone_collate(batch, tokenizer, abs_style_intensity=False): + audio_token, text = zip(*[(x["audio_token"], x["text"]) for x in batch]) + + audio_token = [ + torch.cat( + ( + torch.ones(*a_t.shape[:-1], 1), + a_t + 3, + torch.ones(*a_t.shape[:-1], 1) * 2, + ), + dim=-1, + ) + .squeeze(0) + .transpose(-1, -2) + .long() + for a_t in audio_token + ] + + text_token = [ + torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text + ] + xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token)) + x_mask, y_mask = map( + lambda x: sequence_mask(x, device="cpu"), + (torch.tensor(xlen), torch.tensor(ylen)), + ) + + audio_token, text_token = map( + lambda x: pad_sequence(x, batch_first=True, padding_value=0), + (audio_token, text_token), + ) + + encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2) + crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2) + + if abs_style_intensity: + abs_style_intensity = [x["abs_style_intensity"] for x in batch] + abs_style_intensity = [ + torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity + ] + abs_style_intensity = torch.stack(abs_style_intensity) + else: + abs_style_intensity = None + + return { + "text_token": text_token, + "audio_token": audio_token, + "crossatt_mask": crossatt_mask, + "encoder_mask": encoder_mask, + "y_mask": y_mask, + "x_len": xlen, + "y_len": ylen, + "abs_style_intensity": abs_style_intensity, + "text_rel_pos": None, + "audio_rel_pos": None, + "crossatt_pos": None, + } + + +def random_log_breakpoints(seq: Sequence, a: int, b: int) -> List[int]: + """ + Generate random breakpoints in a sequence where the gap X between + successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)]. + Gaps are then rounded to the nearest integer in [a, b]. + + Parameters + ---------- + seq : Sequence + The input sequence in which to place breakpoints. + a : int + Minimum gap (>= 1). + b : int + Maximum gap (>= a). + + Returns + ------- + List[int] + Sorted list of breakpoint indices (0 < idx < len(seq)). 
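
    Notes
    -----
    With a = 4 and b = 384 (the values used by
    random_segments_from_text_and_durations below), each gap is drawn as
    round(2 ** Uniform(log2(4), log2(384))), so gap lengths are spread roughly
    uniformly on a log2 scale rather than uniformly over [4, 384].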
+ """ + if a < 1 or b < a: + raise ValueError("Require 1 <= a <= b") + n = len(seq) + breakpoints: List[int] = [] + pos = 0 + + while True: + # sample U ~ Uniform(log2(a), log2(b)) + u = random.uniform(math.log2(a), math.log2(b)) + # map back to X = 2^U, then round to nearest integer + x = 2**u + gap = int(math.floor(x + 0.5)) + + # enforce integer bounds exactly + gap = max(a, min(b, gap)) + + pos += gap + if pos >= n: + break + breakpoints.append(pos) + + return breakpoints + + +def random_segments_from_text_and_durations(text_dur, max_pivots, codec_rate_hz): + text, dur = map(list, zip(*text_dur)) + # k = choice(range(1, max_pivots)) + # b = sample(range(2, len(text) - 2), k=max(min(k, len(text) - 5), 1)) + # b.sort() + b = random_log_breakpoints(text, 4, 384) + bounds = [0] + b + [len(text)] + cum_dur = list(accumulate([0] + dur, int.__add__)) + segs, durs = [], [] + for a, b in zip(bounds[:-1], bounds[1:]): + segs.append(text[a:b]) + durs.append(sum(dur[a:b])) + bounds = [0] + list(accumulate(durs, int.__add__)) + segs_dicts = [] + for t, s, e in zip(segs, bounds[:-1], bounds[1:]): + segs_dicts.append( + { + "start": s, + "end": e, + "text_token": t, + } + ) + segs_dicts[-1]["end"] += 1 + return segs_dicts + + +def segments_collate( + batch, + tokenizer, + codec_rate_hz=75, + block_crossatt_mask=True, + text_rel_pos=True, + stop_token=2, + from_token_durations=True, + max_segements=10, +): + # for x in batch: + # if not x.has_key("segments"): + # if x.has_key("alignments"): + + # pass + # else: + # x["segments"] = [{"start": 0., "end": x["audio_token"].shape[-1]/codec_rate_hz, "text": x["text"]}] + + segments = "text_token_duration" if from_token_durations else "segments" + audio_token, segments = zip(*[(x["audio_token"], x[segments]) for x in batch]) + if from_token_durations: + segments = [ + random_segments_from_text_and_durations(x.tolist(), 8, codec_rate_hz) + for x in segments + ] + + bos, eos = map(tokenizer.encode, ("[BOS]", "[EOS]")) + audio_segments = [] + text_segments = [] + audio_segments_len = [] + text_segments_len = [] + + for aud, seg in zip(audio_token, segments): + tt, at, tt_l, at_l = [], [], [], [] + + for i, s in enumerate(seg): + if from_token_durations: + ttoken = bos + s["text_token"] + eos + else: + ttoken = tokenizer.encode("[BOS]" + s["text"] + "[EOS]") + tt.append(ttoken) + sr = codec_rate_hz + a_s = aud[..., int(s["start"]) : int(s["end"])] + 3 + # if i < len(seg) - 1: + a_s = torch.cat((a_s, torch.ones(*a_s.shape[:-1], 1) * stop_token), dim=-1) + at.append(a_s) + at_l.append(a_s.shape[-1]) + tt_l.append(len(ttoken)) + + # at_l[0], at_l[-1] = at_l[0] + 1, at_l[-1] + 1 + at_l[0] = at_l[0] + 1 + + audio_segments.append(at) + text_segments.append(tt) + audio_segments_len.append(at_l) + text_segments_len.append(tt_l) + + text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments] + audio_token = [ + torch.cat( + ( + torch.ones(*a_s.shape[:-1], 1), + *a_ss, + # torch.ones(*a_s.shape[:-1], 1) * stop_token, + ), + dim=-1, + ) + .squeeze(0) + .transpose(-1, -2) + .long() + for a_ss in audio_segments + ] + + xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token)) + x_mask, y_mask = map( + lambda x: sequence_mask(x, device="cpu"), + (torch.tensor(xlen), torch.tensor(ylen)), + ) + + audio_token, text_token = map( + lambda x: pad_sequence(x, batch_first=True, padding_value=0), + (audio_token, text_token), + ) + + encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2) + if block_crossatt_mask: + crossatt_mask = [ + 
packmask_2d(x, y) for x, y in zip(audio_segments_len, text_segments_len) + ] + crossatt_mask = pad_2d_sequence(crossatt_mask) + pad_mask = rearrange(torch.arange(max(ylen)), "n -> 1 n 1") >= rearrange( + torch.tensor(ylen), "n -> n 1 1" + ) + crossatt_mask += pad_mask + else: + crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2) + + text_rel_pos = pad_sequence( + [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len], + batch_first=True, + ) + audio_rel_pos = pad_sequence( + [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len], + batch_first=True, + ) + + return { + "text_token": text_token, + "segments": segments, + "audio_token": audio_token, + "crossatt_mask": crossatt_mask, + "text_rel_pos": text_rel_pos, + "abs_style_intensity": None, + "audio_rel_pos": audio_rel_pos, + "crossatt_pos": None, + "encoder_mask": encoder_mask, + "y_mask": y_mask, + "x_len": xlen, + "y_len": ylen, + } + + +class LinaDataModule(ptl.LightningDataModule): + def __init__( + self, + path: str | DatasetFactory, + quant_layer: list[int], + train_batch_size: int = 8, + token_by_batch: int | None = None, + n_buckets=5, + codec_rate_hz: int = 75, + num_workers: int = 8, + test_size: int = 2000, + val_batch_size: int = 8, + seed: int = 123, + train_test_seed: int = 123, + segments: bool = False, + tokenizer_file=None, + trail_end_frame: int | None = None, + split="train", + add_columns: str | list[str] | None = None, + add_text_tokens: list[str] | None = None, + type: str = "latent", + ): + super().__init__() + + self.path = path + self.codec_rate_hz = codec_rate_hz + self.num_workers = num_workers + self.quant_layer = quant_layer + self.seed = seed + self.segments = segments + self.train_test_seed = train_test_seed + self.test_size = test_size + self.val_batch_size = val_batch_size + self.train_batch_size = train_batch_size + self.split = split + self.trail_end_frame = trail_end_frame + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) + if add_text_tokens: + self.tokenizer.add_tokens(add_text_tokens) + self.add_columns = add_columns + self.n_buckets = n_buckets + self.token_by_batch = token_by_batch + self.type = type + + def setup(self, stage): + if isinstance(self.path, DatasetFactory): + self.dataset = self.path.build() + else: + self.dataset = load_dataset(self.path) + + split = self.split + columns = [ + "audio_latent" if self.type == "latent" else "audio_token", + "text", + "audio_duration", + ] + if self.add_columns is not None: + if type(self.add_columns) is str: + self.add_columns = [self.add_columns] + columns += self.add_columns + + if self.segments: + # columns += ["segments"] + columns += ["text_token_duration"] + columns.remove("text") + self.collate_fn = lambda x: segments_collate(x, self.tokenizer) + else: + self.collate_fn = lambda x: standalone_collate( + x, self.tokenizer, abs_style_intensity="abs_style_intensity" in columns + ) + self.dataset = ( + self.dataset[split] + .train_test_split(test_size=self.test_size, seed=self.train_test_seed) + .select_columns(columns) + ) + + if self.type == "latent": + self.collate_fn = lambda x: standalone_collate_latent(x, self.tokenizer) + + def get_buckets_by_quantile(duration, n_quantile): + idxdur = list(enumerate(duration)) + idxdur.sort(key=lambda x: x[1]) + idx, dur = zip(*idxdur) + bucket_size = len(idx) // n_quantile + buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)] + return buckets + + if self.token_by_batch is not None: + train_buckets = get_buckets_by_quantile( + 
self.dataset["train"]["audio_duration"], self.n_buckets + ) + max_audio_durations = [ + self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets + ] + + batch_sizes = [ + int(self.token_by_batch // (self.codec_rate_hz * ad)) + for ad in max_audio_durations + ] + self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes) + + def train_dataloader(self): + if self.token_by_batch is not None: + return DataLoader( + self.dataset["train"].with_format("torch"), + num_workers=self.num_workers, + collate_fn=self.collate_fn, + batch_sampler=self.train_batch_sampler, + ) + else: + return DataLoader( + self.dataset["train"].with_format("torch"), + num_workers=self.num_workers, + batch_size=self.train_batch_size, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.dataset["test"].with_format("torch"), + batch_size=self.val_batch_size, + num_workers=0, + collate_fn=self.collate_fn, + ) diff --git a/tts/tts/layers/__init__.py b/tts/tts/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tts/tts/layers/attention.py b/tts/tts/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdf095bafe0ca44ce825f2b2acfb5ef0596dcff --- /dev/null +++ b/tts/tts/layers/attention.py @@ -0,0 +1,230 @@ +import math +import os +import time +from typing import Literal + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from fla.models.utils import Cache +from torch import nn +from transformers.cache_utils import Cache + + +def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor: + B, H, Q, KV = mask.shape + device = mask.device + + q_idx = torch.arange(Q, device=device).unsqueeze(1) # (Q, 1) + k_idx = torch.arange(KV, device=device).unsqueeze(0) # (1, KV) + + lower_bound = q_idx - (window_size - 1) # (Q, 1), may be negative + allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound) # (Q, KV), dtype=torch.bool + + allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV) + + orig_dtype = mask.dtype + if mask.dtype != torch.bool: + mask_bool = mask.to(torch.bool) + else: + mask_bool = mask + + new_mask = mask_bool & allowed_4d + + if orig_dtype != torch.bool: + return new_mask.to(orig_dtype) + else: + return new_mask + + +def precompute_freqs_cis( + t: torch.Tensor, + n_elem: int, + base: float = 10000, +) -> torch.Tensor: + freqs = 1.0 / ( + base + ** ( + torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float() + / n_elem + ) + ) + freqs = torch.outer(t, freqs) + cache = repeat(freqs, "... d -> ... (d 2)") + return cache + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin() + return out + + +def scaled_dot_product_attention(query, key, value, mask=None): + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_weight = query @ key.transpose(-2, -1) * scale_factor + if mask is not None: + attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max) + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value, attn_weight + + +class SelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + is_causal: bool = False, + sliding_window: int | None = None, + ): + super().__init__() + self.qkv = nn.Linear(dim, 3 * dim) + assert dim % num_heads == 0 + self.heads = num_heads + self.is_causal = is_causal + self.layer_idx = layer_idx + self.output_proj = nn.Linear(dim, dim) + self.sliding_window = sliding_window + if self.sliding_window is not None: + self.is_causal = False + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + cache: Cache | None = None, + ): + B, T, D = x.shape + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map( + lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) + ) + + if freqs is not None: + q = apply_rotary_emb(q, freqs) + k = apply_rotary_emb(k, freqs) + if cache is not None: + cache.update(attn_state=(k, v), layer_idx=self.layer_idx) + k, v = cache[self.layer_idx]["attn_state"] + + if self.sliding_window is not None: + mask = torch.ones(B, 1, T, T, device=x.device) + mask = apply_causal_sliding_window(mask, self.sliding_window) + + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1 + ) + y = rearrange(y, "b h n d -> b n (h d)") + y = self.output_proj(y) + return y + + +class CrossAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int | None = None, + dropout: float = 0.1, + ): + super().__init__() + assert dim % num_heads == 0 + self.pre_norm_q = nn.LayerNorm(dim) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.layer_idx = layer_idx + self.heads = num_heads + self.dropout_att = dropout + + def _prepare_kv(self, text_hidden_states: torch.Tensor): + v = self.ln_v(self.v(text_hidden_states)) + k = self.ln_k(self.k(text_hidden_states)) + + def _query(self, x): + return self.q(self.pre_norm_q(q)) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor | None = None, + v: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + output_attention: bool = False, + cache: Cache | None = None, + **kwargs, + ): + if v is None: + v = k + q = self.q(self.pre_norm_q(q)) + if cache is not None: + if cache[self.layer_idx] is not None: + ca_state = cache[self.layer_idx]["crossatt_state"] + if ca_state is not None: + k, v = ca_state + else: + v = self.v(v) + k = self.k(k) + cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx) + else: + v = self.v(v) + k = self.k(k) + q, k, v = map( + lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v) + ) + if not self.training: + x, att = scaled_dot_product_attention( + q, k, v, mask=mask[:, None] if mask is not None else None + ) + else: + x = nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask[:, None], dropout_p=self.dropout_att + ) + att = None + x = rearrange(x, "b h n d -> b n (h d)") + if att is not None: + if cache is not None: + 
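                # Keep the attention map where downstream alignment code expects it:
                # collect_heads() reads cache[layer_idx]["crossatt_weights"] during
                # cached generation, while the PlayHead trainer reads
                # layer.crossatt.att when no cache is passed (else branch below).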
cache.update(crossatt_weights=att, layer_idx=self.layer_idx) + else: + self.att = att + return x + + +class ConvPos(nn.Module): + def __init__(self, dim, max_seq_len=2000, kernel_size=7): + super().__init__() + self.embed = nn.Embedding(max_seq_len, dim) + self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same") + + def forward(self, x, left_shift=0): + # left_pad = 31 if left_shift > 0 else 0 + # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0),x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0) + y = self.embed(x) + y = rearrange(y, "b n c -> b c n") + y = self.dw_conv(y) + y = rearrange(y, "b c n -> b n c") # [:,left_pad:-31] + return y + + +class SinPos(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + exp = torch.arange(self.dim // 2, device=x.device) + exp = 2 * exp / (self.dim) + exp = rearrange(exp, "e -> 1 1 e") + x = rearrange(x, "b p -> b p 1") + pos = x * torch.pow(10000, -exp) + pos = torch.cat((pos, pos + math.pi / 2), dim=2) + pos = torch.sin(pos) + return pos diff --git a/tts/tts/layers/ffn.py b/tts/tts/layers/ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5133855ba99e3b0be9257c3383edbceef49514fd --- /dev/null +++ b/tts/tts/layers/ffn.py @@ -0,0 +1,68 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class SwiGLU(nn.Module): + def __init__(self, d_model: int, ffn_expansion_factor: int = 4): + super().__init__() + self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2) + self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model) + + def forward(self, x): + gate, x = self.p_in(x).chunk(2, dim=-1) + return self.p_out(nn.functional.silu(gate) * x) + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class GaussianFourierTimeEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(dim), requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x[:, None] * self.weight[None, :] * 2 * torch.pi + x = torch.cat((torch.sin(x), torch.cos(x)), dim=1) + return x + + +class AdaLNFinalLayer(nn.Module): + def __init__(self, hidden_dim, feature_dim): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_dim, feature_dim, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class AdaLNMLP(nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.hidden_dim = hidden_dim + + self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False) + self.mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim, bias=True), + nn.SiLU(), + nn.Linear(hidden_dim, hidden_dim, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h diff --git a/tts/tts/model/__init__.py b/tts/tts/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29c6330d7c704d41d9874637cd36ca123a8e51db 
--- /dev/null +++ b/tts/tts/model/__init__.py @@ -0,0 +1,2 @@ +from .simple_gla import SimpleGLADecoder +from .transformer import TransformerDecoder, TransformerEncoder diff --git a/tts/tts/model/cache.py b/tts/tts/model/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd84372fab1fe59bf02c0a75c55482084558351 --- /dev/null +++ b/tts/tts/model/cache.py @@ -0,0 +1,308 @@ +from typing import Any + +import torch +from transformers.cache_utils import Cache, _static_cache_update + + +class StaticCache(Cache): + """ + Static Cache class to be used with `torch.compile(model)` and `torch.export()`. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. If you are manually setting the batch size, make sure to take into account the + number of beams if you are running beam search + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): + Mapping between the layers and its device. This is required when you are manually initializing the cache + and the model is split between different gpus. You can know which layers mapped to which device by + checking the associated device_map: `model.hf_device_map`. 
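        Note: unlike the upstream `transformers` StaticCache, this local copy is
        constructed from explicit shape arguments (`max_batch_size`, `head_dim`,
        `num_key_value_heads`, `num_hidden_layers`, `max_cache_len`) instead of a
        `PretrainedConfig`; see `__init__` below. The Example that follows still
        shows the upstream, config-based constructor.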
+ + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache + + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + StaticCache() + ``` + """ + + is_compileable = True + + def __init__( + self, + max_batch_size: int, + head_dim: int, + num_key_value_heads: int, + num_hidden_layers: int, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + layer_device_map: dict[int, str | torch.device | int] | None = None, + ) -> None: + super().__init__() + self.max_batch_size = max_batch_size + self.max_cache_len = max_cache_len + + # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads + self.head_dim = head_dim + self._dtype = dtype + self.num_key_value_heads = num_key_value_heads + self.num_hidden_layers = num_hidden_layers + self.key_cache: list[torch.Tensor] = [] + self.value_cache: list[torch.Tensor] = [] + # Note: There will be significant perf decrease if switching to use 5D tensors instead. + cache_shape = ( + self.max_batch_size, + self.num_key_value_heads, + self.max_cache_len, + self.head_dim, + ) + device = torch.device(device) if device is not None else None + for idx in range(self.num_hidden_layers): + if layer_device_map is not None: + layer_device = layer_device_map[idx] + else: + layer_device = device + new_layer_key_cache = torch.zeros( + cache_shape, dtype=self._dtype, device=layer_device + ) + new_layer_value_cache = torch.zeros( + cache_shape, dtype=self._dtype, device=layer_device + ) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + It is VERY important to index using a tensor, otherwise you introduce a copy to the device. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know how where to write in the cache. + + Return: + A tuple containing the updated key and value states. 
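        The write position is taken from cache_kwargs["cache_position"], e.g.
        cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos}),
        with pos typically a LongTensor (torch.arange) over the slots being written.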
+ """ + if cache_kwargs is None: + cache_kwargs = {} + + key_states = key_states.to(self.key_cache[layer_idx].dtype) + value_states = value_states.to(self.value_cache[layer_idx].dtype) + return _static_cache_update( + self.key_cache[layer_idx], + self.value_cache[layer_idx], + key_states, + value_states, + cache_kwargs.get("cache_position"), + ) + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def get_max_cache_shape(self) -> int | None: + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + + def get_mask_sizes( + self, cache_position: torch.Tensor, layer_idx: int + ) -> tuple[int, int]: + """ + Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for + the given layer at `layer_idx`. + The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size), + for each layer. + """ + kv_length = self.get_max_cache_shape() + return kv_length, 0 + + +class Cache: + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. + """ + + is_compileable = True + + def __init__(self, seen_tokens: int = 0) -> Cache: + super().__init__() + + self.states: list[dict[str, Any]] = [] + + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> dict[str, Any]: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + recurrent_state: torch.Tensor | None = None, + attn_state: tuple[torch.Tensor, torch.Tensor] | None = None, + conv_state: tuple[torch.Tensor] | None = None, + ffn_state: torch.Tensor | None = None, + layer_idx: int = 0, + offset: int | None = 1, + cache_kwargs: dict | None = None, + ): + """ + Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`. + + Args: + recurrent_state (`torch.Tensor`, `optional`): + The new recurrent state to cache. + attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`): + The new attention key/value states to cache. + conv_state (`Tuple[torch.Tensor]`, `optional`): + The new convolution state to cache. + layer_idx (`int`, defaults to 0): + The index of the layer to cache the states for. + offset (`int`, `optional`, defaults to 1): + The number of new tokens being processed. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + Dictionary of the updated state. 
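        If cache_kwargs provides a "window_size", only the most recent
        window_size key/value positions are kept: once the window is full the
        cached tensors are rolled in place rather than grown by concatenation.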
+ """ + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += offset + + if attn_state is not None: + input_size = attn_state[0].shape[-2] + window_size = cache_kwargs.get("window_size", None) + if not isinstance(attn_state, Tuple) or len(attn_state) != 2: + raise ValueError( + "`attn_state` must be a tuple of two tensors for key/value states" + ) + if len(self.states) <= layer_idx: + if attn_state is not None: + if window_size is not None and input_size > window_size: + attn_state = ( + attn_state[0][..., -window_size:, :].contiguous(), + attn_state[1][..., -window_size:, :].contiguous(), + ) + state = dict( + recurrent_state=recurrent_state, + attn_state=attn_state, + conv_state=conv_state, + ffn_state=ffn_state, + ) + self.states.append(state) + else: + state = self.states[layer_idx] + if recurrent_state is not None: + state["recurrent_state"] = recurrent_state + if attn_state is not None: + key_state, value_state = state["attn_state"] + if window_size is not None and key_state.shape[-2] == window_size: + # DO NOT allocate new memory if the cache is full + # roll the key/value states to the left by `input_size` + key_state = key_state.roll(-input_size, -2) + value_state = value_state.roll(-input_size, -2) + # replace the last `input_size` tokens with the new key/value states + key_state[..., -input_size:, :] = attn_state[0] + value_state[..., -input_size:, :] = attn_state[1] + attn_state = (key_state, value_state) + else: + attn_state = ( + torch.cat([key_state, attn_state[0]], -2), + torch.cat([value_state, attn_state[1]], -2), + ) + state["attn_state"] = attn_state + if conv_state is not None: + state["conv_state"] = conv_state + if ffn_state is not None: + state["ffn_state"] = ffn_state + + return state + + def get_seq_length(self, layer_idx: int | None = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> int | None: + """Returns the maximum sequence length of the cached states. Cache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> tuple: + return tuple(self.states) + + @classmethod + @torch.compiler.disable + def from_legacy_cache( + cls, past_key_values: tuple | None = None, seen_tokens: int = 0 + ) -> Cache: + """Converts a cache in the legacy cache format into an equivalent `Cache`.""" + + cache = cls(seen_tokens) + if isinstance(past_key_values, list): + for layer_idx in range(len(past_key_values)): + cache.states.append(past_key_values[layer_idx]) + return cache diff --git a/tts/tts/model/cache_utils.py b/tts/tts/model/cache_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1aa3d8f27013be13e55b5752673745f2b223c5e --- /dev/null +++ b/tts/tts/model/cache_utils.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import transformers +from transformers.cache_utils import DynamicCache + +from layers.attention import precompute_freqs_cis + + +class FLACache(transformers.cache_utils.Cache): + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. 
+ """ + + is_compileable = True + + def __init__( + self, + head_dim: int | None = None, + max_seq_len: int | None = None, + num_states: int | None = None, + device: str = "cuda", + seen_tokens: int = 0, + ): + super().__init__() + if head_dim is not None and max_seq_len is not None: + self.freqs = precompute_freqs_cis( + torch.arange(max_seq_len, device=device), head_dim + ) + self.states: List[Dict[str, Any]] = [] + if num_states is not None: + self.states = [ + dict( + recurrent_state=None, + attn_state=None, + conv_state=None, + short_conv_state=None, + ffn_state=None, + crossatt_state=None, + crossatt_weights=None, + ) + for _ in range(num_states) + ] + + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> Dict[str, Any]: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + recurrent_state: torch.Tensor = None, + attn_state: Tuple[torch.Tensor, torch.Tensor] = None, + conv_state: Tuple[torch.Tensor] = None, + short_conv_state: Tuple[torch.Tensor] = None, + crossatt_state: Tuple[torch.Tensor] = None, + crossatt_weights: Tuple[torch.Tensor] = None, + ffn_state: torch.Tensor = None, + layer_idx: int = 0, + offset: Optional[int] = 1, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`. + + Args: + recurrent_state (`torch.Tensor`, `optional`): + The new recurrent state to cache. + attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`): + The new attention key/value states to cache. + conv_state (`Tuple[torch.Tensor]`, `optional`): + The new convolution state to cache. + layer_idx (`int`, defaults to 0): + The index of the layer to cache the states for. + offset (`int`, `optional`, defaults to 1): + The number of new tokens being processed. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + Dictionary of the updated state. 
+ """ + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += offset + + if attn_state is not None: + input_size = attn_state[0].shape[-2] + window_size = ( + cache_kwargs.get("window_size", None) + if cache_kwargs is not None + else None + ) + if not isinstance(attn_state, Tuple) or len(attn_state) != 2: + raise ValueError( + "`attn_state` must be a tuple of two tensors for key/value states" + ) + if len(self.states) <= layer_idx: + if attn_state is not None: + if window_size is not None and input_size > window_size: + attn_state = ( + attn_state[0][..., -window_size:, :].contiguous(), + attn_state[1][..., -window_size:, :].contiguous(), + ) + state = dict( + recurrent_state=recurrent_state, + attn_state=attn_state, + conv_state=conv_state, + short_conv_state=short_conv_state, + ffn_state=ffn_state, + crossatt_state=crossatt_state, + crossatt_weights=crossatt_weights, + ) + self.states.append(state) + else: + state = self.states[layer_idx] + if recurrent_state is not None: + state["recurrent_state"] = recurrent_state + if crossatt_state is not None: + state["crossatt_state"] = crossatt_state + if crossatt_weights is not None: + if state["crossatt_weights"] is not None: + state["crossatt_weights"] = torch.cat( + (state["crossatt_weights"], crossatt_weights), dim=-2 + ) + else: + state["crossatt_weights"] = crossatt_weights + if attn_state is not None: + key_state, value_state = state["attn_state"] + if window_size is not None and key_state.shape[-2] == window_size: + # DO NOT allocate new memory if the cache is full + # roll the key/value states to the left by `input_size` + key_state = key_state.roll(-input_size, -2) + value_state = value_state.roll(-input_size, -2) + # replace the last `input_size` tokens with the new key/value states + key_state[..., -input_size:, :] = attn_state[0] + value_state[..., -input_size:, :] = attn_state[1] + attn_state = (key_state, value_state) + else: + attn_state = ( + torch.cat([key_state, attn_state[0]], -2), + torch.cat([value_state, attn_state[1]], -2), + ) + state["attn_state"] = attn_state + if conv_state is not None: + state["conv_state"] = conv_state + if short_conv_state is not None: + state["short_conv_state"] = short_conv_state + if ffn_state is not None: + state["ffn_state"] = ffn_state + + return state + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. 
Cache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple: + return tuple(self.states) + + @classmethod + @torch.compiler.disable + def from_legacy_cache( + cls, past_key_values: Optional[Tuple] = None, seen_tokens: int = 0 + ) -> Cache: + """Converts a cache in the legacy cache format into an equivalent `Cache`.""" + + cache = cls(seen_tokens) + if isinstance(past_key_values, list): + for layer_idx in range(len(past_key_values)): + cache.states.append(past_key_values[layer_idx]) + return cache + + +class TransformerDecoderCache(FLACache): + def __init__( + self, + head_dim: int, + max_seq_len: int, + device: str, + ): + super().__init__() + self.freqs = precompute_freqs_cis( + torch.arange(max_seq_len, device=device), head_dim + ) diff --git a/tts/tts/model/config.py b/tts/tts/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2d9da0840f5ffb37bc43cc05b0e7b5bd52282a --- /dev/null +++ b/tts/tts/model/config.py @@ -0,0 +1,70 @@ +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass +class SimpleGLADecoderConfig: + name: str = "simple_gla" + dim: int = 512 + use_short_conv = True + expand_k: float = 0.5 + expand_v: float = 1.0 + num_heads: int = 4 + num_layers: int = 12 + ffn_expansion_factor: int = 4 + conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10]) + blind_crossatt: bool = True + listen_read_crossatt: dict[int, list[Literal["listen", "read"]]] | None = None + crossatt_num_heads: int = 1 + crossatt_dropout: float = 0.1 + crossatt_layer_idx: list[int] = field(default_factory=lambda: [5]) + + +@dataclass +class TransformerEncoderConfig: + name: str = "sa_transformer" + dim: int = 512 + num_heads: int = 8 + num_layers: int = 6 + ffn_expansion_factor: int = 4 + + +@dataclass +class TransformerDecoderConfig: + name: str = "sa_transformer" + dim: int = 512 + num_heads: int = 8 + num_layers: int = 12 + conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10]) + ffn_expansion_factor: int = 4 + crossatt_dropout: float = 0.1 + crossatt_num_heads: int = 1 + crossatt_layer_idx: list[int] = field(default_factory=lambda: [5, 6]) + + +DecoderConfig = TransformerDecoderConfig | SimpleGLADecoderConfig +EncoderConfig = TransformerEncoderConfig + + +@dataclass +class TTSConfig: + dim: int = 512 + text_vocab_size: int = 256 + 3 + audio_vocab_size: int = 4096 + 3 + audio_embed_size: int = 16 + audio_input_type: Literal["continuous", "discrete"] = "continuous" + diffusion_head_num_layers: int = 3 + encoder_cfg: EncoderConfig = field(default_factory=TransformerEncoderConfig) + decoder_cfg: DecoderConfig = field(default_factory=TransformerDecoderConfig) + stop_prediction_head: bool = True + continuous_diffusion: bool = True + + +@dataclass +class QueryVCConfig: + dim: int = 512 + semantic_dim: int = 512 + num_layers: int = 12 + lag: int = 4 + audio_embed_size: int = 8 + diffusion_head_num_layers: int = 3 diff --git a/tts/tts/model/prediction_head.py b/tts/tts/model/prediction_head.py new file mode 100644 index 0000000000000000000000000000000000000000..273add83d6a5bd417d1acb93479218682cd56f36 --- /dev/null +++ b/tts/tts/model/prediction_head.py @@ -0,0 +1,259 @@ +import torch +from einops import rearrange +from numpy import zeros_like +from torch import nn +from torchdyn.core import NeuralODE + +from layers.ffn import AdaLNFinalLayer, AdaLNMLP, GaussianFourierTimeEmbedding + + +def net_factory(hidden_dim: int, num_layers: int) -> nn.Module: + pass + + 
+import os +import sys +from contextlib import contextmanager + +import torch +import torch.nn.functional as F + + +@contextmanager +def suppress_stdout(): + original_stdout = sys.stdout + try: + sys.stdout = open(os.devnull, "w") + yield + finally: + sys.stdout.close() + sys.stdout = original_stdout + + +def sample_from_logits( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, +) -> torch.Tensor: + B, N, C = logits.shape + logits = logits / temperature + + # Apply top-k + if top_k > 0: + top_k = min(top_k, C) + topk_values, _ = torch.topk(logits, top_k, dim=-1) + kth_value = topk_values[..., -1, None] + logits = torch.where( + logits < kth_value, torch.full_like(logits, float("-inf")), logits + ) + + # Apply top-p (nucleus) sampling + if top_p > 0.0 and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + probs = F.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(probs, dim=-1) + + # Create mask for tokens to remove + cutoff_mask = cumulative_probs > top_p + cutoff_mask[..., 0] = 0 # Always keep at least one token + sorted_logits[cutoff_mask] = float("-inf") + + # Map back to original logits shape + logits = torch.full_like(logits, float("-inf")).scatter( + -1, sorted_indices, sorted_logits + ) + + # Convert logits to probabilities + probs = F.softmax(logits, dim=-1) + + # Sample + samples = torch.multinomial(probs.view(-1, C), num_samples=1).view(B, N) + return samples + + +class LogitsHead(nn.Module): + def __init__(self, hidden_dim: int, vocab_size: int): + super().__init__() + self.logits_proj = nn.Linear(hidden_dim, vocab_size) + + def forward(self, pre_logits): + return self.logits_proj(pre_logits) + + def compute_loss( + self, + pre_logits: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + ): + logits = self(pre_logits) + if mask is not None: + flat_logits = logits[mask] + flat_target = target[mask] + else: + flat_logits = rearrange(logits, "b n l -> (b n) l") + flat_target = rearrange(target, "b n -> (b n)") + + loss = nn.functional.cross_entropy( + flat_logits, + flat_target, + ignore_index=1, + ) + + return {"cross_entropy": loss} + + def predict(self, x: torch.Tensor, *args, **kwargs): + return sample_from_logits(self(x), *args, **kwargs) + + +class ContinuousHead(nn.Module): + def __init__(self, hidden_dim: int, feature_dim: int): + super().__init__() + self.continuous_head = nn.Linear(hidden_dim, feature_dim) + + def forward(self, x: torch.Tensor): + return self.continuous_head(x) + + def predict(self, x: torch.Tensor): + return self(x) + + def compute_loss( + self, pre_logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + ): + if mask is not None: + pre_logits = pre_logits[mask] + target = target[mask] + return {"mse": nn.functional.mse_loss(self(pre_logits), target)} + + +class DiffusionHead(nn.Module): + def __init__( + self, + hidden_dim: int, + feature_dim: int, + num_layers: int, + cond_dim: int | None = None, + ): + super().__init__() + cond_dim = cond_dim if cond_dim is not None else hidden_dim + self.feature_embed = nn.Linear(feature_dim, hidden_dim) + self.cond_embed = nn.Linear(cond_dim, hidden_dim) + self.time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2) + self.adaln_mlp = nn.ModuleList( + [AdaLNMLP(hidden_dim) for _ in range(num_layers)] + ) + self.adaln_final_layer = AdaLNFinalLayer(hidden_dim, feature_dim) + self.feature_dim = feature_dim + + def forward( + self, + cond: torch.Tensor, + x: torch.Tensor, + t: 
torch.Tensor | None = None, + cond_drop_mask: torch.BoolTensor | None = None, + ): + cond = self.cond_embed(cond) + if cond_drop_mask is not None: + cond[cond_drop_mask] = 0.0 + cond += self.time_embed(t)[:, None] + x = self.feature_embed(x) + + for l in self.adaln_mlp: + x = l(x, cond) + y = self.adaln_final_layer(x, cond) + + return y + + def compute_loss( + self, + cond: torch.Tensor, + x1: torch.Tensor, + mask: torch.Tensor | None, + sigma: float = 1e-5, + t: torch.Tensor | None = None, + cfg_drop_rate: float = 0.1, + ): + if t is None: + t = torch.rand(cond.shape[0], device=cond.device) + + x0 = torch.randn_like(x1, device=x1.device) + + flow_target = x1 - (1 - sigma) * x0 + alpha = (1 - (1 - sigma) * t).view(-1, 1, 1) + xt = alpha * x0 + t.view(-1, 1, 1) * x1 + + if self.training and cfg_drop_rate > 0.0: + cond_drop_mask = torch.rand(cond.shape[:2]) < cfg_drop_rate + else: + cond_drop_mask = None + + flow_pred = self(cond, xt, t, cond_drop_mask=cond_drop_mask) + + if mask is not None: + flow_pred = flow_pred[mask] + flow_target = flow_target[mask] + + loss = nn.functional.mse_loss(flow_pred, flow_target) + + return {"diffusion": loss} + + def predict( + self, + pre_prediction: torch.Tensor, + solver: str = "euler", + sensitivity: str = "adjoint", + num_steps: int = 10, + cfg: float = 1.0, + **kwargs, + ): + if cfg == 1.0: + + def solver_fn(t, Xt, *args, **kwargs): + return self(pre_prediction, Xt, t.unsqueeze(0)) + else: + + def solver_fn(t, Xt, *args, **kwargs): + cond_uncond = torch.cat( + (pre_prediction, torch.zeros_like(pre_prediction)), + dim=0, + ) + cond_uncond = self(cond_uncond, Xt.repeat(2, 1, 1), t.unsqueeze(0)) + cond, uncond = cond_uncond.chunk(2, dim=0) + return uncond + cfg * (cond - uncond) + + # get rid of torchdyn warning + with suppress_stdout(): + node_ = NeuralODE(solver_fn, solver=solver, sensitivity=sensitivity) + t_span = torch.linspace(0, 1, num_steps + 1, device=pre_prediction.device) + traj = node_.trajectory( + torch.randn(1, 1, self.feature_dim, device=pre_prediction.device), + t_span=t_span, + ) + prediction = traj[-1] + return prediction + + +class StopPredictionHead(nn.Module): + def __init__(self, dim: int, weight_loss: float = 1.0): + super().__init__() + self.proj = nn.Linear(dim, 1) + self.weight_loss = weight_loss + + def forward(self, pre_prediction: torch.Tensor): + return torch.sigmoid(self.proj(pre_prediction)) + + def predict(self, pre_prediction: torch.Tensor): + return torch.sigmoid(self.proj(pre_prediction)) + + def compute_loss( + self, + pre_prediction: torch.Tensor, + target: torch.Tensor, + ): + logits = self.proj(pre_prediction) + bce = nn.functional.binary_cross_entropy_with_logits( + logits.squeeze(-1), + target, + weight=torch.ones(logits.shape[0], device=logits.device) * self.weight_loss, + ) + return {"stop_bce": bce} diff --git a/tts/tts/model/registry.py b/tts/tts/model/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0a00c29717c2a7888a38cef182b1240068810762 --- /dev/null +++ b/tts/tts/model/registry.py @@ -0,0 +1,18 @@ +ENCODER_REGISTRY = {} +DECODER_REGISTRY = {} + + +def register_encoder(name): + def wrapper(cls): + ENCODER_REGISTRY[name] = cls + return cls + + return wrapper + + +def register_decoder(name): + def wrapper(cls): + DECODER_REGISTRY[name] = cls + return cls + + return wrapper diff --git a/tts/tts/model/shortconv.py b/tts/tts/model/shortconv.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9ed8b14586fa441a27267ffff09717ede15ac7 --- /dev/null +++ 
b/tts/tts/model/shortconv.py @@ -0,0 +1,39 @@ +import torch +from fla.modules import ShortConvolution +from torch import nn + +from layers.ffn import SwiGLU +from model.cache_utils import FLACache + + +class ShortConvBlock(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int, + ffn_expansion_factor: int, + layer_idx: int, + use_fast_conv1d: bool = True, + ): + super().__init__() + self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.layer_idx = layer_idx + + def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs): + last_state = None + if cache is not None and len(cache) > self.layer_idx: + last_state = cache[self.layer_idx]["short_conv_state"] + x, short_conv_state = self.tmix( + self.norm1(x), cache=last_state, output_final_state=cache is not None + ) + if cache is not None: + cache.update( + short_conv_state=short_conv_state, + layer_idx=self.layer_idx, + offset=x.shape[1], + ) + x = self.cmix(self.norm2(x)) + return x diff --git a/tts/tts/model/simple_gla.py b/tts/tts/model/simple_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbdb8dab7a28acccee93bd348791834977bfd4b --- /dev/null +++ b/tts/tts/model/simple_gla.py @@ -0,0 +1,248 @@ +import os + +import torch +import torch.nn.functional as F +from einops import rearrange +from fla.layers.simple_gla import SimpleGatedLinearAttention +from fla.models.utils import Cache +from torch import nn + +from layers.attention import ( + BlindCrossAttention, + ConvPos, + CrossAttention, + ListenReadCrossAttention, + precompute_freqs_cis, +) +from layers.ffn import SwiGLU +from model.cache_utils import FLACache +from model.config import SimpleGLADecoderConfig +from model.registry import register_decoder +from model.shortconv import ShortConvBlock + +if "GRAD_CKPT" in os.environ: + + def maybe_grad_ckpt(f): + def grad_ckpt_f(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + f, *args, **kwargs, use_reentrant=False + ) + + return grad_ckpt_f +else: + + def maybe_grad_ckpt(f): + return f + + +class SimpleGLABlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + expand_k: float, + expand_v: float, + use_short_conv: bool, + ffn_expansion_factor: int, + ): + super().__init__() + self.tmix = SimpleGatedLinearAttention( + hidden_size=dim, + num_heads=num_heads, + layer_idx=layer_idx, + ) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + text_freqs: torch.Tensor | None = None, + cache: Cache | None = None, + ): + x = ( + self.tmix( + self.norm1(x), + freqs=freqs, + text_freqs=text_freqs, + past_key_values=cache, + use_cache=cache is not None, + )[0] + + x + ) + x = self.cmix(self.norm2(x)) + x + return x + + +class DecoderBlockWithOptionalCrossAttention(nn.Module): + def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None): + super().__init__() + + self.decoder_block = decoder_block + self.crossatt = crossatt + + def forward( + self, + x: torch.Tensor, + encoder_output: torch.Tensor | None = None, + freqs: torch.Tensor | None = None, + text_freqs: torch.Tensor | None = None, + cache: Cache | None = None, + selfatt_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = self.decoder_block( + x, + 
freqs=freqs, + cache=cache, + ) + + if self.crossatt is not None: + x = x + self.crossatt( + x, + k=encoder_output, + text_freqs=text_freqs, + mask=crossatt_mask, + cache=cache, + ) + + return x + + +@register_decoder("simple_gla") +class SimpleGLADecoder(nn.Module): + config = SimpleGLADecoderConfig + + def __init__(self, cfg: SimpleGLADecoderConfig): + super().__init__() + + assert cfg.dim % cfg.num_heads == 0 + assert cfg.blind_crossatt != (cfg.listen_read_crossatt is not None) + + if cfg.listen_read_crossatt is not None: + self.text_freqs_embd = ConvPos(cfg.dim) + else: + self.text_freqs_embd = None + self.head_dim = cfg.dim // cfg.num_heads + + def simple_gla_block(i): + conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers + if i in conv_layers: + return ShortConvBlock( + dim=cfg.dim, + kernel_size=4, + ffn_expansion_factor=cfg.ffn_expansion_factor, + layer_idx=i, + use_fast_conv1d=True, + ) + + else: + return SimpleGLABlock( + dim=cfg.dim, + num_heads=cfg.num_heads, + layer_idx=i, + expand_k=cfg.expand_k, + expand_v=cfg.expand_v, + use_short_conv=cfg.use_short_conv, + ffn_expansion_factor=cfg.ffn_expansion_factor, + ) + + def crossatt_block(i): + if cfg.blind_crossatt: + if i in cfg.crossatt_layer_idx: + return BlindCrossAttention( + cfg.dim, + cfg.dim, + cfg.dim, + simple_gla_block(cfg.num_layers), + pos_dim=cfg.dim, + dropout=cfg.crossatt_dropout, + pos_type="convolutional", + layer_idx=cfg.num_layers, + ) + elif cfg.listen_read_crossatt is not None: + if i in cfg.listen_read_crossatt.keys(): + return ListenReadCrossAttention( + cfg.dim, + cfg.dim, + cfg.dim, + cfg.listen_read_crossatt[i], + ) + + else: + if i in cfg.crossatt_layer_idx: + return CrossAttention( + dim=cfg.dim, + num_heads=cfg.crossatt_num_heads, + dropout=cfg.crossatt_dropout, + layer_idx=i, + ) + + self.decoder_layers = nn.ModuleList( + [ + DecoderBlockWithOptionalCrossAttention( + simple_gla_block(i), + crossatt_block(i), + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + crossatt_mask: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + x = decoder_input + if self.text_freqs_embd is not None: + text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :] + text_freqs = self.text_freqs_embd(text_freqs) + else: + text_freqs = None + for layer in self.decoder_layers: + x = maybe_grad_ckpt(layer)( + x, + encoder_output, + text_freqs=text_freqs, + cache=cache, + crossatt_mask=crossatt_mask, + ) + return x + + def init_cache(self, max_seq_len, device): + return FLACache(num_states=len(self.decoder_layers) + 1) + + def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int): + assert self.decoder_layers[layer_idx].crossatt is not None + x = audio_inputs + for _, layer in zip(range(layer_idx - 1), self.decoder_layers): + x = layer(x, None) + return self.decoder_layers[layer_idx].crossatt._query(x) + + def prefill( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: FLACache | None = None, + ): + return self(encoder_output, decoder_input, cache=cache) + + def decode_one( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: Cache, + ): + x = decoder_input + for layer in self.decoder_layers: + x = layer( + x, + encoder_output, + cache=cache, + ) + return x diff --git a/tts/tts/model/transformer.py b/tts/tts/model/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f101a53d23301621a9b0907bf2fc70a133eb9ffb --- 
/dev/null +++ b/tts/tts/model/transformer.py @@ -0,0 +1,223 @@ +import os + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn +from transformers.cache_utils import DynamicCache + +from layers.attention import CrossAttention, SelfAttention, precompute_freqs_cis +from layers.ffn import SwiGLU +from model.cache_utils import FLACache, TransformerDecoderCache +from model.config import TransformerDecoderConfig, TransformerEncoderConfig +from model.registry import register_decoder, register_encoder +from model.shortconv import ShortConvBlock + +if "GRAD_CKPT" in os.environ: + + def maybe_grad_ckpt(f): + def grad_ckpt_f(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + f, *args, **kwargs, use_reentrant=False + ) + + return grad_ckpt_f +else: + + def maybe_grad_ckpt(f): + return f + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + layer_idx: int, + ffn_expansion_factor: int, + is_causal: bool, + ): + super().__init__() + self.tmix = SelfAttention(dim, num_heads, layer_idx, is_causal=is_causal) + self.cmix = SwiGLU(dim, ffn_expansion_factor) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + def forward( + self, + x, + freqs: torch.Tensor | None = None, + cache: DynamicCache | None = None, + mask: torch.Tensor | None = None, + ): + x = self.tmix(self.norm1(x), freqs=freqs, cache=cache, mask=mask) + x + x = self.cmix(self.norm2(x)) + x + return x + + +class DecoderBlockWithOptionalCrossAttention(nn.Module): + def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None): + super().__init__() + + self.decoder_block = decoder_block + self.crossatt = crossatt + + def forward( + self, + x: torch.Tensor, + encoder_output: torch.Tensor | None = None, + freqs: torch.Tensor | None = None, + cache: FLACache | None = None, + selfatt_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + x = self.decoder_block( + x, + freqs=freqs, + cache=cache, + ) + + if self.crossatt is not None: + x = x + self.crossatt( + x, + k=encoder_output, + mask=crossatt_mask, + cache=cache, + ) + + return x + + +@register_decoder("sa_transformer") +class TransformerDecoder(nn.Module): + def __init__(self, cfg: TransformerDecoderConfig): + super().__init__() + + assert cfg.dim % cfg.num_heads == 0 + self.head_dim = cfg.dim // cfg.num_heads + + def transformer_block(i): + conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers + if i in conv_layers: + return ShortConvBlock( + dim=cfg.dim, + kernel_size=4, + ffn_expansion_factor=cfg.ffn_expansion_factor, + layer_idx=i, + use_fast_conv1d=True, + ) + else: + return TransformerBlock( + dim=cfg.dim, + num_heads=cfg.num_heads, + layer_idx=i, + ffn_expansion_factor=cfg.ffn_expansion_factor, + is_causal=True, + ) + + def crossatt_block(i): + return CrossAttention( + dim=cfg.dim, + num_heads=cfg.crossatt_num_heads, + dropout=cfg.crossatt_dropout, + layer_idx=i, + ) + + self.decoder_layers = nn.ModuleList( + [ + DecoderBlockWithOptionalCrossAttention( + transformer_block(i), + crossatt_block(i) if i in cfg.crossatt_layer_idx else None, + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + crossatt_mask: torch.Tensor | None = None, + cache: FLACache | None = None, + ): + x = decoder_input + positions = torch.arange(x.shape[1], device=x.device) + freqs = precompute_freqs_cis(positions, self.head_dim) + for layer in 
self.decoder_layers: + x = maybe_grad_ckpt(layer)( + x, + encoder_output, + freqs=freqs, + crossatt_mask=crossatt_mask, + cache=cache, + ) + return x + + def init_cache(self, max_seq_len, device): + return FLACache(self.head_dim, max_seq_len * 2, device) + + def prefill( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: FLACache | None = None, + ): + return self(encoder_output, decoder_input, cache=cache) + + def decode_one( + self, + encoder_output: torch.Tensor, + decoder_input: torch.Tensor, + cache: FLACache, + ): + x = decoder_input + pos = cache._seen_tokens + freq = cache.freqs[[pos]] + for layer in self.decoder_layers: + x = layer( + x, + encoder_output, + freqs=freq, + cache=cache, + ) + return x + + +@register_encoder("sa_transformer") +class TransformerEncoder(nn.Module): + config = TransformerEncoderConfig + + def __init__(self, cfg: TransformerEncoderConfig): + super().__init__() + assert cfg.dim % cfg.num_heads == 0 + self.head_dim = cfg.dim // cfg.num_heads + self.encoder_layers = nn.ModuleList( + [ + TransformerBlock( + cfg.dim, + cfg.num_heads, + i, + cfg.ffn_expansion_factor, + is_causal=False, + ) + for i in range(cfg.num_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + ): + positions = torch.arange(x.shape[1], device=x.device) + freqs = precompute_freqs_cis(positions, self.head_dim) + if mask is not None: + mask = rearrange(mask, "b n m -> b 1 n m") + mask = torch.logical_or( + mask, + rearrange(torch.eye(mask.shape[-1], device=x.device), "n m ->1 1 n m"), + ) + + for layer in self.encoder_layers: + x = layer(x, freqs=freqs, mask=mask) + return x diff --git a/tts/tts/tools.py b/tts/tts/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..4cbd6784b7ec80424547ac5acb38f411a7b7ddaa --- /dev/null +++ b/tts/tts/tools.py @@ -0,0 +1,77 @@ +from typing import Callable, List, Optional +from itertools import accumulate + +import torch + +default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def pad_2d_sequence(seq, padding_value=0): + max_x, max_y = map(max, zip(*map(lambda x: x.shape, seq))) + pad = lambda x: torch.nn.functional.pad( + x, + (0, max_y - x.shape[1], 0, max_x - x.shape[0]), + value=padding_value, + ) + return torch.stack([pad(x) for x in seq]) + +def packmask_2d(xlen: list[int], ylen: list[int], offset: int=0) -> torch.Tensor: + _, ybound = map(lambda x: [0] + list(accumulate(x, int.__add__)), (xlen, ylen)) + lb, hb = [], [] + + for n, l, h in zip(xlen, ybound[:-1], ybound[1:]): + lb += [l]*n + hb += [h]*n + + lb, hb = map(torch.tensor, (lb, hb)) + if offset: + lb -= offset + hb += offset + + rge = torch.arange(ybound[-1]) + + lm = rge.unsqueeze(0) >= lb.unsqueeze(1) + hm = rge.unsqueeze(0) < hb.unsqueeze(1) + + return lm * hm + + +def topk_sampling(seq, k=1, temp=1.): + topk = torch.topk(seq, k, dim=-1) + logits = seq / temp + mask = logits < topk.values[:, [-1]] + logits[mask] = -float('Inf') + probs = torch.softmax(logits, dim=-1) + return torch.multinomial(probs, num_samples=1) + +def delay_rvq( + code, + head_token: int = -2, + tail_token: int = -3, +): + q, _ = code.shape + extension = torch.ones((q, q + 1)).tril() * head_token + extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token + extension = torch.flip(extension, (1,)) + extended_code = torch.cat((code, extension), axis=1) + for i in range(q): + extended_code[i, :] = torch.roll(extended_code[i, :], i + 1) + + return extended_code.long() + +def 
undelay_rvq(extended_code): + q, _, n = extended_code.shape + out = [] + for i in range(q): + out.append(torch.roll(extended_code[i], -(i + 1), dims=1)) + out = torch.stack(out, dim=0) + return out[:, :, :-(q+1)] + +def sequence_mask(lengths, max_len=None, device=default_device): + batch_size = lengths.shape[0] + if max_len is None: + max_len = torch.max(lengths).item() + + ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) + mask = ids < lengths.unsqueeze(1).expand(-1, max_len) + + return mask diff --git a/tts/tts/train_tts.py b/tts/tts/train_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0d9a064c2d76048fdd7f78d55705213121a5f9 --- /dev/null +++ b/tts/tts/train_tts.py @@ -0,0 +1,197 @@ +import json +import math +from dataclasses import asdict +from pathlib import Path + +import hydra +import numpy as np +import pytorch_lightning as ptl +import torch +from omegaconf import DictConfig +from safetensors.torch import save_file +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from transformers import get_cosine_schedule_with_warmup + +from model.config import TTSConfig +from model.prediction_head import DiffusionHead +from tts import ARTTSModel + + +def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr): + def lr_lambda(step): + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr + + return lr_lambda + + +class TrainARTTS(ptl.LightningModule): + def __init__( + self, + config: TTSConfig, + quant_layer: list[int], + tie_embed: bool = False, + learning_rate: float = 5e-4, + end_learning_rate: float | None = None, + weight_decay: float = 0.1, + betas: tuple[float, float] = (0.9, 0.999), + n_warmup_steps: int = 500, + n_training_steps: int = 300000, + mask_text_p: float = 0.0, + load_weights: str | None = None, + stop_token_weight: float | None = None, + stop_loss_factor: float = 0.1, + stop_loss_warmup: tuple[int, int] | None = None, + ): + super(TrainARTTS, self).__init__() + + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.betas = betas + self.n_warmup_steps = n_warmup_steps + self.n_training_steps = n_training_steps + self.stop_token_weight = stop_token_weight + self.stop_loss_factor = stop_loss_factor + + self.save_hyperparameters() + + self.model = ARTTSModel(config) + + if load_weights is not None: + model = torch.load(load_weights) + self.load_state_dict(model["state_dict"], strict=False) + + def on_train_epoch_start(self): + if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"): + self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch) + + def save_model_weights_and_config( + self, + dir: str | None, + model_filename: str = "model.st", + config_filename: str = "config.json", + ): + cfg = self.hparams.config + Path(dir).mkdir(exist_ok=True) + model_path = Path(dir) / model_filename + save_file(self.model.state_dict(), model_path) + with open(Path(dir) / config_filename, "w") as f: + json.dump(asdict(cfg), f, indent=2) + + def step(self, batch, batch_idx: int, validation: bool = False): + text_token = batch["text_token"] + audio_token = batch["audio_token"].squeeze(2) + crossatt_mask = batch["crossatt_mask"] + text_rel_pos = batch["text_rel_pos"] + encoder_mask = batch["encoder_mask"] + logits_mask = batch["y_mask"] + + 
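# Teacher-forced pass: the decoder consumes audio_token[:, :-1] and the prediction
+        # heads below are trained against the shifted targets audio_token[:, 1:].
+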
pre_logits = self.model( + text_token, + audio_token, + encoder_mask, + logits_mask, + crossatt_mask, + ) + losses = {} + if validation and type(self.model.prediction_head) is DiffusionHead: + t = ( + torch.ones(pre_logits.shape[0], device=pre_logits.device) + * batch_idx + / self.trainer.num_val_batches[0] + ) + losses |= self.model.prediction_head.compute_loss( + pre_logits, + audio_token[:, 1:], + mask=logits_mask[:, 1:] if logits_mask is not None else None, + t=t, + ) + else: + losses |= self.model.prediction_head.compute_loss( + pre_logits, + audio_token[:, 1:], + mask=logits_mask[:, 1:] if logits_mask is not None else None, + ) + + if self.model.stop_prediction_head is not None and logits_mask is not None: + target = nn.functional.pad((~logits_mask)[:, 2:].to(pre_logits), (0, 1)) + mask = logits_mask[:, 1:] + losses |= self.model.stop_prediction_head.compute_loss( + pre_logits[mask], + target[mask], + ) + + return losses + + def training_step(self, batch, idx): + losses = self.step(batch, idx) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True) + if "stop" in name: + if self.hparams.stop_loss_warmup is not None: + alpha, beta = self.hparams.stop_loss_warmup + warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0) + else: + warmup = 1.0 + loss *= self.stop_loss_factor * warmup + total_loss += loss + self.log("train_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def validation_step(self, batch, idx): + losses = self.step(batch, idx, validation=True) + total_loss = 0.0 + for name, loss in losses.items(): + self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True) + total_loss += loss + self.log("val_loss", total_loss, prog_bar=True, sync_dist=True) + return total_loss + + def configure_optimizers(self): + params = [ + { + "params": self.model.parameters(), + "weight_decay": self.weight_decay, + } + ] + opt = torch.optim.AdamW( + params, + lr=self.learning_rate, + betas=self.betas, + ) + # scheduler = get_cosine_schedule_with_warmup( + # opt, + # num_warmup_steps=self.n_warmup_steps, + # num_training_steps=self.n_training_steps, + # ) + scheduler = LambdaLR( + opt, + lr_lambda=cosine_schedule_with_warmup( + warmup_steps=self.hparams.n_warmup_steps, + total_steps=self.hparams.n_training_steps, + start_lr=self.hparams.learning_rate, + end_lr=self.hparams.learning_rate * 0.1, + ), + ) + return [opt], [{"scheduler": scheduler, "interval": "step"}] + + +@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3") +def main(cfg: DictConfig): + ptl.seed_everything(cfg.seed_everything) + + model = hydra.utils.instantiate(cfg.model) + cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}" + datamodule = hydra.utils.instantiate(cfg.data) + trainer = hydra.utils.instantiate(cfg.trainer) + + trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path")) + + +if __name__ == "__main__": + main() diff --git a/tts/tts/tts.py b/tts/tts/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..425c070b749ff441916cb4f86cf2a89d4137b0f2 --- /dev/null +++ b/tts/tts/tts.py @@ -0,0 +1,158 @@ +import json +from pathlib import Path +from typing import Any + +import torch +from safetensors import safe_open +from torch import nn +from tqdm import tqdm + +from model.config import TTSConfig +from model.prediction_head import ( + ContinuousHead, + DiffusionHead, + LogitsHead, + StopPredictionHead, +) +from model.registry import DECODER_REGISTRY, 
ENCODER_REGISTRY + + +class ARTTSModel(nn.Module): + def __init__(self, cfg: TTSConfig): + super().__init__() + self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim) + if cfg.audio_input_type == "discrete": + self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim) + self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size) + elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion: + self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim) + self.prediction_head = DiffusionHead( + cfg.decoder_cfg.dim, + cfg.audio_embed_size, + cfg.diffusion_head_num_layers, + ) + elif cfg.audio_input_type == "continuous": + self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim) + self.prediction_head = ContinuousHead( + cfg.decoder_cfg.dim, + cfg.audio_embed_size, + ) + + self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg) + self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg) + self.stop_prediction_head = ( + StopPredictionHead(cfg.dim) if cfg.stop_prediction_head else None + ) + + @classmethod + def from_pretrained_local( + cls, + path: str, + config_filename: str = "config.json", + model_filename: str = "model.st", + device: str = "cpu", + ): + with open(Path(path) / config_filename, "r") as f: + config = json.load(f) + for k in config.keys(): + if k == "decoder_cfg": + config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k]) + if k == "encoder_cfg": + config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k]) + config = TTSConfig(**config) + model = ARTTSModel(config) + state_dict = {} + with safe_open(Path(path) / model_filename, framework="pt", device=device) as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + model.load_state_dict(state_dict) + return model + + def _get_query(self, x: torch.Tensor, *args): + input_audio_embd = self.audio_embd(x) + return self.audio_decoder._get_query(input_audio_embd, *args) + + def forward( + self, + text_ids: torch.LongTensor, + audio_inputs: torch.Tensor, + text_mask: torch.Tensor | None = None, + audio_mask: torch.Tensor | None = None, + crossatt_mask: torch.Tensor | None = None, + ): + input_text_embd = self.text_embd(text_ids) + input_audio_embd = self.audio_embd(audio_inputs[:, :-1]) + text_hidden_states = self.text_encoder(input_text_embd, mask=text_mask) + pre_logits = self.audio_decoder( + text_hidden_states, + input_audio_embd, + crossatt_mask=crossatt_mask[:, :-1] if crossatt_mask is not None else None, + ) + """ + losses = {} + losses |= self.prediction_head.compute_loss( + pre_logits, + audio_inputs[:, 1:], + mask=audio_mask[:, 1:] if audio_mask is not None else None, + ) + if self.stop_prediction_head is not None and audio_mask is not None: + target = nn.functional.pad((~audio_mask)[:, 2:].to(pre_logits), (0, 1)) + mask = audio_mask[:, 1:] + losses |= self.stop_prediction_head.compute_loss( + pre_logits[mask], + target[mask], + ) + return losses + """ + return pre_logits + + def generate( + self, + text_ids: torch.LongTensor, + prefix: torch.Tensor, + teacher_force: torch.Tensor | None = None, + max_seq_len: int = 200, + device: str = "cuda", + sampling_params: dict | None = None, + stop_threshold: float = 0.5, + ): + if sampling_params is None: + sampling_params = {} + if text_ids.ndim == 1: + text_ids = text_ids.unsqueeze(0) + + input_text_embd = self.text_embd(text_ids) + text_hidden_states = self.text_encoder(input_text_embd) + prefix_embd = self.audio_embd(prefix) + + cache = self.audio_decoder.init_cache(max_seq_len, 
device) + preds = [] + pre_prediction = self.audio_decoder.prefill( + text_hidden_states, + prefix_embd, + cache=cache, + ) + prediction = self.prediction_head.predict(pre_prediction[:, [-1]]) + prediction_embd = self.audio_embd(prediction) + for i in tqdm(range(max_seq_len)): + pre_prediction = self.audio_decoder.decode_one( + input_text_embd, + prediction_embd, + cache, + ) + prediction = self.prediction_head.predict(pre_prediction, **sampling_params) + prediction_embd = self.audio_embd(prediction) + if teacher_force is not None: + b, n, d = teacher_force.shape + if i < n: + prediction_embd = self.audio_embd(teacher_force[:, [i]]) + + preds.append(prediction) + if self.stop_prediction_head is not None: + stop_pred = self.stop_prediction_head(pre_prediction) + if stop_pred > stop_threshold: + break + + full_prediction = torch.cat(preds, dim=1) + + return cache, full_prediction
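+
+
+# Minimal usage sketch (illustrative only: the checkpoint directory, token ids and latent
+# size below are assumptions, not artifacts provided by this diff; it assumes the default
+# continuous config with audio_embed_size=16):
+#
+#   model = ARTTSModel.from_pretrained_local("checkpoints/artts").to("cuda").eval()
+#   text_ids = torch.tensor([5, 12, 42, 7], device="cuda")   # already-tokenized prompt
+#   prefix = torch.zeros(1, 1, 16, device="cuda")            # one frame of audio latents
+#   cache, latents = model.generate(text_ids, prefix, max_seq_len=200)
+#   # `latents` holds the generated audio frames; decode them with the matching codec decoder.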