diff --git a/codec/__init__.py b/codec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2285f83140d9290b9ad285eabb89efa191432cc
--- /dev/null
+++ b/codec/__init__.py
@@ -0,0 +1,2 @@
+from .train_patchvae import TrainPatchVAE
+from .train_wavvae import TrainWavVAE
diff --git a/codec/datamodules.py b/codec/datamodules.py
new file mode 100755
index 0000000000000000000000000000000000000000..11bb24ea7964b834d4ae19cf774529bb822933ee
--- /dev/null
+++ b/codec/datamodules.py
@@ -0,0 +1,249 @@
+import itertools
+import random
+import time
+from dataclasses import dataclass
+from functools import partial
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as ptl
+import torch
+import torchaudio
+from safetensors.torch import safe_open
+from sklearn.model_selection import train_test_split
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader, Dataset
+
+from datasets import load_dataset, load_from_disk
+
+
+@dataclass
+class WavVAEDataConfig:
+ filelist_path: str
+ sampling_rate: int
+ num_samples: int
+ batch_size: int
+ num_workers: int
+
+
+class WavVAEDataModule(ptl.LightningDataModule):
+ def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
+ super().__init__()
+ self.train_config = train_params
+ self.val_config = val_params
+
+    def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool):
+ dataset = WavVAEDataset(cfg, train=train)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=cfg.batch_size,
+ num_workers=cfg.num_workers,
+ shuffle=train,
+ pin_memory=True,
+ )
+ return dataloader
+
+ def train_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.train_config, train=True)
+
+ def val_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.val_config, train=False)
+
+
+class WavVAEDataset(Dataset):
+ def __init__(self, cfg: WavVAEDataConfig, train: bool):
+ with open(cfg.filelist_path) as f:
+ self.filelist = f.read().splitlines()
+ self.sampling_rate = cfg.sampling_rate
+ self.num_samples = cfg.num_samples
+ self.train = train
+
+ def __len__(self) -> int:
+ return len(self.filelist)
+
+ def __getitem__(self, index: int) -> torch.Tensor:
+ audio_path = self.filelist[index]
+ y, sr = torchaudio.load(audio_path)
+ if y.size(0) > 1:
+ # mix to mono
+ y = y.mean(dim=0, keepdim=True)
+        gain = np.random.uniform(-6, -1) if self.train else -3
+ y, _ = torchaudio.sox_effects.apply_effects_tensor(
+ y, sr, [["norm", f"{gain:.2f}"]]
+ )
+ if sr != self.sampling_rate:
+ y = torchaudio.functional.resample(
+ y, orig_freq=sr, new_freq=self.sampling_rate
+ )
+ if y.size(-1) < self.num_samples:
+ pad_length = self.num_samples - y.size(-1)
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+ elif self.train:
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+ y = y[:, start : start + self.num_samples]
+ else:
+ # During validation, take always the first segment for determinism
+ y = y[:, : self.num_samples]
+
+ return y[0]
+
+
+def pad_tensor_list_raw(
+ tensor_list: list[tuple[torch.Tensor, torch.Tensor]], pad_idx: int = 0
+) -> dict[str, torch.Tensor | None]:
+ audio, hubert_maybe = zip(*tensor_list)
+ audio = torch.cat(audio, dim=0)
+ if hubert_maybe[0] is not None:
+ hubert_maybe = torch.stack(hubert_maybe, dim=0)
+ else:
+ hubert_maybe = None
+ return {"audio_z": audio, "hubert": hubert_maybe}
+
+
+class SafeTensorDataset(Dataset):
+ """
+    On __getitem__, opens the safetensors file and uses get_slice() to inspect the stored
+    shape, then either drops too-short sequences (returns None) or returns a random
+    subsequence slice, optionally paired with matching HuBERT features.
+    A short usage sketch follows the class definition.
+ """
+
+ def __init__(
+ self,
+ file_paths: list[str],
+ key: str,
+ hubert_path: str | None = None,
+ hubert_key: str = "layer_9",
+ min_length: int = 1,
+ subseq_length: int | None = None,
+ ):
+ self.file_paths = file_paths
+ self.key = key
+ self.min_length = min_length
+ self.subseq_length = subseq_length
+ self.hubert_path = hubert_path
+ self.hubert_key = hubert_key
+
+ def __len__(self):
+ return len(self.file_paths)
+
+ def __getitem__(self, idx: int) -> torch.Tensor | None:
+ path = self.file_paths[idx]
+ # open file, get a slice wrapper for full tensor
+ with safe_open(path, framework="pt") as f:
+ tensor_slice = f.get_slice(self.key)
+            Q, N, D = tensor_slice.get_shape()  # full shape [Q, N, D]
+
+ # drop too-short
+ if N < self.min_length:
+ return None
+
+ L = self.subseq_length or N
+ if L < N:
+ # sample random start
+ start = torch.randint(0, max(1, N - L - 1), ()).item()
+ start -= start % 2
+            # this yields a torch.Tensor of shape [Q, L, D]
+ seq = tensor_slice[:, start : start + L]
+ else:
+ # full length
+ start = 0
+ seq = tensor_slice[:, :]
+
+ if self.hubert_path is not None:
+ path = Path(self.hubert_path) / Path(path).name
+ with safe_open(path, framework="pt") as f:
+ tensor_slice = f.get_slice(self.hubert_key)
+                hubert_N, hubert_D = tensor_slice.get_shape()  # full shape [N, D]
+ seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2]
+ return (seq, seq_hubert)
+
+ return (seq, None)
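+
+# Illustrative usage sketch (not executed anywhere in this module): builds the dataset
+# from a hypothetical file list and reads one item. The file path and the "audio_z" key
+# are placeholder assumptions taken from the defaults above.
+#
+#   ds = SafeTensorDataset(["latents/utt0.safetensors"], key="audio_z", subseq_length=64)
+#   item = ds[0]          # (seq, hubert) tuple, or None if the sequence is too short
+#   seq, hubert = item    # seq: sliced latent tensor, hubert: paired features or None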
+
+
+class SafeTensorDataModule(ptl.LightningDataModule):
+ """
+ LightningDataModule using raw .safetensors file list + get_slice inside Dataset.
+ """
+
+ def __init__(
+ self,
+ train_file_list: str,
+ val_file_list: str | None = None,
+ hubert_path: str | None = None,
+ key: str = "audio_z",
+ hubert_key: str = "layer_9",
+ val_split: float = 0.1,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ shuffle: bool = True,
+ seed: int = 1234,
+ min_length: int = 1,
+ subseq_length: int | None = None,
+ ):
+ super().__init__()
+ self.train_file_list = train_file_list
+ self.val_file_list = val_file_list
+ self.hubert_path = hubert_path
+ self.key = key
+ self.val_split = val_split
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.shuffle = shuffle
+ self.seed = seed
+ self.min_length = min_length
+ self.subseq_length = subseq_length
+
+ def setup(self, stage=None):
+ with open(self.train_file_list, "r") as f:
+ train_paths = [line.strip() for line in f if line.strip()]
+ val_paths = None
+ if self.val_file_list is not None:
+            with open(self.val_file_list, "r") as f:
+ val_paths = [line.strip() for line in f if line.strip()]
+ # Split into train/val
+ if (
+ isinstance(self.val_split, float)
+ and 0 < self.val_split < 1
+ and val_paths is None
+ ):
+ train_paths, val_paths = train_test_split(
+ train_paths, test_size=self.val_split, random_state=self.seed
+ )
+
+ self.train_ds = SafeTensorDataset(
+ train_paths,
+ key=self.key,
+ min_length=self.min_length,
+ subseq_length=self.subseq_length,
+ hubert_path=self.hubert_path,
+ )
+ self.val_ds = SafeTensorDataset(
+ val_paths,
+ key=self.key,
+ min_length=self.min_length,
+ subseq_length=self.subseq_length,
+ )
+
+ def _collate_fn(
+        self, batch: list[tuple[torch.Tensor, torch.Tensor | None] | None]
+    ) -> dict[str, torch.Tensor | None]:
+ seqs = [s for s in batch if s is not None]
+ return pad_tensor_list_raw(seqs, pad_idx=0)
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_ds,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ collate_fn=self._collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_ds,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ collate_fn=self._collate_fn,
+ )
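+
+
+# Illustrative sketch of wiring the datamodule into a Lightning run (not executed here).
+# The file-list path, batch size and subsequence length are placeholder assumptions.
+#
+#   dm = SafeTensorDataModule("filelists/train.txt", batch_size=16, subseq_length=64)
+#   dm.setup()
+#   batch = next(iter(dm.train_dataloader()))   # {"audio_z": ..., "hubert": ... or None}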
diff --git a/codec/models/__init__.py b/codec/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92b4b9a8a5f448b2d6f900eb72af1fef16b4308
--- /dev/null
+++ b/codec/models/__init__.py
@@ -0,0 +1,2 @@
+from .patchvae.model import PatchVAE, PatchVAEConfig
+from .wavvae.model import WavVAE, WavVAEConfig
diff --git a/codec/models/components/__init__.py b/codec/models/components/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/codec/models/components/convnext.py b/codec/models/components/convnext.py
new file mode 100755
index 0000000000000000000000000000000000000000..8927a3c9f22b2ae22b45b78c41a81b8bb03fea22
--- /dev/null
+++ b/codec/models/components/convnext.py
@@ -0,0 +1,221 @@
+import torch
+from torch import nn
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int | None = None,
+ layer_scale_init_value: float = 0.0,
+ elementwise_affine_ln: bool = True,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
+ )
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.is_causal = is_causal
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
+ gate: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ residual = x
+ if self.is_causal:
+ x = torch.nn.functional.pad(x, (6, 0))
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ x = self.norm(x)
+ if scale_shift is not None:
+ scale, shift = scale_shift
+ x = x * scale[:, None] + shift[:, None]
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ if gate is not None:
+ x = gate[:, None] * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
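+
+# Shape sketch for the block above (channels-first, as used throughout this file):
+# a (B, C, T) tensor goes in and a (B, C, T) tensor comes out, so blocks can be stacked
+# freely. The example sizes are arbitrary.
+#
+#   block = ConvNeXtBlock(dim=256)
+#   y = block(torch.randn(2, 256, 100))   # -> torch.Size([2, 256, 100])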
+
+
+class ConvNextNet(nn.Module):
+ def __init__(self, n_layers, dim, intermediate_dim: int | None = None):
+ super().__init__()
+ self.net = nn.Sequential(
+ *[
+ ConvNeXtBlock(
+ dim,
+ intermediate_dim,
+ )
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class ConvNextPatchEncoder(nn.Module):
+ def __init__(
+ self,
+ patch_sizes: list[int],
+ n_layers_per_patch: int,
+ patch_expansion_factor: float = 1.5,
+ is_decoder: bool = False,
+ ):
+ super().__init__()
+ patch_to_dim = []
+ convnext = []
+ for i, patch_size in enumerate(patch_sizes):
+ in_dim = int((patch_expansion_factor if i > 0 else 1.0) * patch_size)
+ out_dim = int(patch_expansion_factor * patch_size)
+ if is_decoder:
+ in_dim, out_dim = out_dim, in_dim
+ patch_to_dim.append(
+ nn.Linear(
+ in_dim,
+ out_dim,
+ )
+ )
+ convnext += [
+ nn.Sequential(
+ *[
+ ConvNeXtBlock(int(patch_size * patch_expansion_factor))
+ for _ in range(n_layers_per_patch)
+ ]
+ )
+ ]
+ self.is_decoder = is_decoder
+ self.patch_sizes = patch_sizes
+ self.patch_expansion_factor = patch_expansion_factor
+ self.patch_to_dim = nn.ModuleList(patch_to_dim)
+ self.convnext = nn.ModuleList(convnext)
+
+ def forward(self, x):
+ if self.is_decoder:
+ for i, patch_size in reversed(list(enumerate(self.patch_sizes))):
+ B, P, N = x.shape
+ x = x.reshape(B, int(patch_size * self.patch_expansion_factor), -1)
+ x = self.convnext[i](x)
+ x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
+ else:
+ for i, patch_size in enumerate(self.patch_sizes):
+ B, P, N = x.shape
+ patch_expansion_factor_maybe = (
+ self.patch_expansion_factor if i > 0 else 1.0
+ )
+ x = x.reshape(B, int(patch_size * patch_expansion_factor_maybe), -1)
+ x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
+ x = self.convnext[i](x)
+ return x
+
+
+class ConvNextEncoder(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ dim: int,
+ n_layers: int,
+ intermediate_dim: int | None = None,
+ stride: int = 1,
+ ):
+ super().__init__()
+ self.in_proj = nn.Linear(in_dim, dim)
+ if stride > 1:
+ self.stride = nn.Conv1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=(stride * 2) + 1,
+ stride=stride,
+ padding=stride // 2,
+ )
+ else:
+ self.stride = nn.Identity()
+ self.net = ConvNextNet(n_layers, dim, intermediate_dim)
+
+ def forward(self, x):
+ x = self.in_proj(x.transpose(1, 2)).transpose(1, 2)
+ x = self.stride(x)
+ return self.net(x)
+
+
+class ConvNextDecoder(nn.Module):
+ def __init__(
+ self,
+ out_dim: int,
+ dim: int,
+ n_layers: int,
+ intermediate_dim: int | None = None,
+ stride: int = 1,
+ stride_position: str = "before",
+ ):
+ super().__init__()
+ self.out_proj = nn.Linear(dim, out_dim)
+ if stride > 1:
+ self.stride = nn.ConvTranspose1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=(stride * 2) + 1,
+ stride=stride,
+ padding=stride // 2,
+ output_padding=stride // 2,
+ )
+ else:
+ self.stride = nn.Identity()
+ self.stride_position = stride_position
+
+ self.net = ConvNextNet(n_layers, dim, intermediate_dim)
+
+ def forward(self, x):
+ if self.stride_position == "before":
+ x = self.stride(x)
+ x = self.net(x)
+ if self.stride_position == "after":
+ x = self.stride(x)
+ return self.out_proj(x.transpose(1, 2)).transpose(1, 2)
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
+ super().__init__()
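+        # The expansion is divided by 3, following the usual SwiGLU convention of shrinking
+        # the hidden width to offset the doubled input projection (gate + value branches).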
+ self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
+ self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)
+
+ def forward(self, x):
+ gate, x = self.p_in(x).chunk(2, dim=-1)
+ return self.p_out(nn.functional.silu(gate) * x)
diff --git a/codec/models/components/transformer.py b/codec/models/components/transformer.py
new file mode 100755
index 0000000000000000000000000000000000000000..23b8fee796c20881ea2454c80c23a45dd45c54b6
--- /dev/null
+++ b/codec/models/components/transformer.py
@@ -0,0 +1,224 @@
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+
+
+class LocalSelfAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ heads: int,
+ window_len: int = 32,
+ rotary: bool = True,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+ self.heads = heads
+ assert dim % heads == 0, "dim must be divisible by heads"
+ self.qkv = nn.Linear(dim, 3 * dim)
+ self.o = nn.Linear(dim, dim)
+ self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
+ self.is_causal = is_causal
+ self.window_len = window_len
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ pos: Optional[torch.Tensor] = None,
+ cache: Optional[Dict[int, torch.Tensor]] = None,
+ layer_idx: Optional[int] = None,
+ time_step: int = 0,
+ ) -> torch.Tensor:
+ # x: (batch, seq_len, dim)
+        b, t_len, _ = x.shape
+ pad_len = (self.window_len - t_len % self.window_len) % self.window_len
+ padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
+ mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
+ mask = torch.nn.functional.pad(
+ mask, (0, pad_len), value=False
+ ) # False = masked
+ mask = mask.expand(b, -1) # [b, padded_len]
+ mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=self.window_len)
+ qkv = self.qkv(padded_x).chunk(3, dim=-1)
+ q, k, v = [
+ rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=self.window_len)
+ for t in qkv
+ ]
+ if cache is not None:
+ assert layer_idx is not None, "layer_idx must be set when using cache"
+ cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
+ cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
+ k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
+
+ # apply rotary embeddings
+ if self.rotary is not None:
+ if pos is not None:
+ rot = self.rotary(pos) # (b,1,n,head_dim)
+ q = apply_rotary_emb(rot, q)
+ k = apply_rotary_emb(rot, k)
+ else:
+ q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
+ k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
+
+ # scaled dot-product attention
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None if self.is_causal else mask,
+ is_causal=self.is_causal,
+ )
+ y = rearrange(y, "b n h w d -> b (w n) (h d)")
+ y = self.o(y)
+ y = y[:, :t_len]
+ return y
+
+
+class SelfAttention(nn.Module):
+ def __init__(
+ self, dim: int, heads: int, rotary: bool = True, is_causal: bool = False
+ ):
+ super().__init__()
+ self.heads = heads
+ assert dim % heads == 0, "dim must be divisible by heads"
+ self.qkv = nn.Linear(dim, 3 * dim)
+ self.o = nn.Linear(dim, dim)
+ self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
+ self.is_causal = is_causal
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ pos: Optional[torch.Tensor] = None,
+ cache: Optional[Dict[int, torch.Tensor]] = None,
+ layer_idx: Optional[int] = None,
+ time_step: int = 0,
+ ) -> torch.Tensor:
+ # x: (batch, seq_len, dim)
+        b, t_len, _ = x.shape
+ pad_len = (32 - t_len % 32) % 32
+ padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
+ mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
+ mask = torch.nn.functional.pad(
+ mask, (0, pad_len), value=False
+ ) # False = masked
+ mask = mask.expand(b, -1) # [b, padded_len]
+ mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=32)
+ qkv = self.qkv(padded_x).chunk(3, dim=-1)
+ q, k, v = [
+ rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=32) for t in qkv
+ ]
+ # caching for fast autoregressive
+ if cache is not None:
+ assert layer_idx is not None, "layer_idx must be set when using cache"
+ # append new keys/values
+ cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
+ cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
+ k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
+
+ # apply rotary embeddings
+ if self.rotary is not None:
+ if pos is not None:
+ rot = self.rotary(pos) # .unsqueeze(1) # (b,1,n,head_dim)
+ q = apply_rotary_emb(rot, q)
+ k = apply_rotary_emb(rot, k)
+ else:
+ q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
+ k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
+
+ # scaled dot-product attention
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=None if self.is_causal else mask,
+ is_causal=self.is_causal,
+ )
+ y = rearrange(y, "b n h w d -> b (w n) (h d)")
+ y = self.o(y)
+ y = y[:, :t_len]
+ return y
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, d_model: int):
+ super().__init__()
+ hidden = d_model * 4 // 3
+ self.p_in = nn.Linear(d_model, hidden * 2)
+ self.p_out = nn.Linear(hidden, d_model)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ gate, data = self.p_in(x).chunk(2, dim=-1)
+ return self.p_out(F.silu(gate) * data)
+
+
+class TransformerBlock(nn.Module):
+ """
+ Transformer block using custom SelfAttention and SwiGLU FFN.
+
+ Args:
+ dim: embedding dimension
+        head_size: dimension of each attention head (number of heads = dim // head_size)
+        rotary: whether to use rotary embeddings
+        is_causal: whether to apply causal masking
+        elementwise_affine_ln: whether the LayerNorms have learnable affine parameters
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ head_size: int,
+ rotary: bool = True,
+ is_causal: bool = False,
+ elementwise_affine_ln: bool = True,
+ ):
+ super().__init__()
+ assert dim % head_size == 0
+ heads = dim // head_size
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
+ self.attn = LocalSelfAttention(dim, heads, rotary=rotary, is_causal=is_causal)
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
+ self.ffn = SwiGLU(dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ pos: Optional[torch.Tensor] = None,
+ cache: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
+ layer_idx: Optional[int] = None,
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
+        gate: Optional[torch.Tensor] = None,
+ time_step: int = 0,
+ ) -> torch.Tensor:
+ # Self-attention block
+ norm1_x = self.norm1(x)
+ if scale_shift is not None:
+ scale, shift = scale_shift
+ norm1_x = norm1_x * scale[:, None] + shift[:, None]
+
+ attn_out = self.attn(
+ norm1_x,
+ mask=mask,
+ pos=pos,
+ cache=cache,
+ layer_idx=layer_idx,
+ time_step=time_step,
+ )
+ x = x + attn_out
+
+ norm2_x = self.norm2(x)
+ if gate is not None:
+ norm2_x = gate[:, None] * norm2_x
+
+ # Feedforward block
+ ffn_out = self.ffn(norm2_x)
+ x = x + ffn_out
+ return x
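+
+
+# Illustrative sketch (not part of the module's API): a single block applied to a random
+# batch of shape (batch, seq_len, dim). dim=256 with head_size=64 gives 4 heads; all
+# values are arbitrary.
+#
+#   block = TransformerBlock(dim=256, head_size=64)
+#   y = block(torch.randn(2, 100, 256))   # -> torch.Size([2, 100, 256])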
diff --git a/codec/models/pardi_tokenizer.py b/codec/models/pardi_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72b9d344176b4f73eaed42ef1e35c7653c3b89a
--- /dev/null
+++ b/codec/models/pardi_tokenizer.py
@@ -0,0 +1,10 @@
+import torch
+from torch import nn
+
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+from zcodec.models.wavvae.model import WavVAEConfig
+from zcodec.models.zflowae.model import ZFlowAutoEncoderConfig
+
+
+class PardiTokenizer(nn.Module):
+    def __init__(self, wavvae_cfg: WavVAEConfig, zflowae_cfg: ZFlowAutoEncoderConfig):
+        super().__init__()
+        # Stub: wiring of the WavVAE and ZFlowAutoEncoder stages is not implemented yet.
+        self.wavvae_cfg = wavvae_cfg
+        self.zflowae_cfg = zflowae_cfg
diff --git a/codec/models/patchvae/model.py b/codec/models/patchvae/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..a67f863fe76e8a398de9eaed745fdac757d590d6
--- /dev/null
+++ b/codec/models/patchvae/model.py
@@ -0,0 +1,262 @@
+import json
+import math
+import os
+import sys
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+
+import torch
+from safetensors.torch import load_file
+from torch import nn
+from torchdyn.core import NeuralODE
+
+from .modules import AdaLNFlowPredictor, AutoEncoder
+
+
+@contextmanager
+def suppress_stdout():
+ original_stdout = sys.stdout
+ try:
+ sys.stdout = open(os.devnull, "w")
+ yield
+ finally:
+ sys.stdout.close()
+ sys.stdout = original_stdout
+
+
+def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return step / max(1, warmup_steps)
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
+
+ return lr_lambda
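+
+# Illustrative sketch: the lambda above is meant for torch.optim.lr_scheduler.LambdaLR,
+# which multiplies the optimizer's base LR by lr_lambda(step). The optimizer, step counts
+# and LR values below are placeholder assumptions.
+#
+#   opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
+#   sched = torch.optim.lr_scheduler.LambdaLR(
+#       opt, cosine_schedule_with_warmup(1_000, 100_000, start_lr=1e-4, end_lr=1e-6)
+#   )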
+
+
+@dataclass
+class PatchVAEConfig:
+ latent_dim: int
+ hidden_dim: int
+ latent_scaling: tuple[list[float], list[float]] | None
+ flow_factory: str
+ num_flow_layers: int
+ autoencoder_factory: str
+ num_autoencoder_layers: int
+ convnextformer_num_conv_per_transformer: int = 3
+ wavvae_path: str | None = None
+ fsq_levels: list[int] | None = None
+ bottleneck_size: int | None = None
+ latent_stride: int = 2
+ vae: bool = False
+ causal_transformer: bool = False
+ cond_dim: int | None = None
+ is_causal: bool = False
+
+
+class PatchVAE(nn.Module):
+ def __init__(self, cfg: PatchVAEConfig):
+ super().__init__()
+ self.flow_net = AdaLNFlowPredictor(
+ feat_dim=cfg.latent_dim * cfg.latent_stride,
+ dim=cfg.hidden_dim,
+ n_layer=cfg.num_flow_layers,
+ layer_factory=cfg.flow_factory,
+ cond_dim=cfg.cond_dim,
+ is_causal=cfg.is_causal,
+ )
+ self.autoencoder = AutoEncoder(
+ cfg.latent_dim * cfg.latent_stride,
+ cfg.hidden_dim,
+ cfg.num_autoencoder_layers,
+ cfg.autoencoder_factory,
+ out_dim=cfg.cond_dim,
+ vae=cfg.vae,
+ bottleneck_size=cfg.bottleneck_size,
+ convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
+ is_causal=cfg.is_causal,
+ )
+ if cfg.latent_scaling is not None:
+ mean, std = cfg.latent_scaling
+ self.register_buffer("mean_latent_scaling", torch.tensor(mean))
+ self.register_buffer("std_latent_scaling", torch.tensor(std))
+ else:
+ self.mean_latent_scaling = None
+ self.std_latent_scaling = None
+
+ self.latent_stride = cfg.latent_stride
+ self.latent_dim = cfg.latent_dim
+ self.wavvae = None
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ map_location: str = "cpu",
+ ):
+ if Path(pretrained_model_name_or_path).exists():
+ path = pretrained_model_name_or_path
+ else:
+ from huggingface_hub import snapshot_download
+
+ path = snapshot_download(pretrained_model_name_or_path)
+
+ with open(Path(path) / "config.json", "r") as f:
+ config = json.load(f)
+ config = PatchVAEConfig(**config)
+ model = cls(config).to(map_location)
+ state_dict = load_file(
+ Path(path) / "model.st",
+ device=map_location,
+ )
+ model.load_state_dict(state_dict, assign=True)
+ if config.wavvae_path is not None:
+ from .. import WavVAE
+
+ model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
+ else:
+ model.wavvae = None
+
+ return model
+
+ def wavvae_from_pretrained(
+ self,
+ pretrained_model_name_or_path: str,
+ *args,
+ **kwargs,
+ ):
+ from .. import WavVAE
+
+ self.wavvae = WavVAE.from_pretrained(
+ pretrained_model_name_or_path,
+ *args,
+ **kwargs,
+ )
+
+ def encode(self, wav: torch.Tensor):
+ assert self.wavvae is not None, (
+ "please provide WavVAE model to encode from waveform"
+ )
+ z = self.wavvae.encode(wav)
+ zz = self.encode_patch(z)
+ return zz
+
+ def decode(self, patchvae_latent: torch.Tensor, **kwargs):
+ assert self.wavvae is not None, (
+ "please provide WavVAE model to decode to waveform"
+ )
+ z = self.decode_patch(patchvae_latent, **kwargs)
+ wav = self.wavvae.decode(z)
+ return wav
+
+ def normalize_z(self, z: torch.Tensor):
+ if self.mean_latent_scaling is not None:
+ z = (z - self.mean_latent_scaling) / self.std_latent_scaling
+ return z
+
+ def denormalize_z(self, z: torch.Tensor):
+ if self.std_latent_scaling is not None:
+ z = z * self.std_latent_scaling + self.mean_latent_scaling
+ return z
+
+ def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
+ B, T, D = z.shape
+ z = self.normalize_z(z)
+ if self.latent_stride > 1:
+ z = z[:, : T - T % self.latent_stride]
+ z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
+ return self.autoencoder.encode(z, deterministic=deterministic)
+
+ def decode_patch(
+ self,
+ latent: torch.Tensor,
+ cfg: float = 2.0,
+ num_steps: int = 15,
+ solver: str = "euler",
+ sensitivity: str = "adjoint",
+ temperature: float = 1.0,
+ **kwargs,
+ ):
+ with torch.no_grad():
+ z_cond = self.autoencoder.decode(latent).transpose(1, 2)
+ if cfg == 1.0:
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ flow = self.flow_net(Xt, z_cond, t.unsqueeze(0))
+ return flow
+ else:
+ z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ flow = self.flow_net(
+ Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0)
+ )
+ cond, uncond = flow.chunk(2, dim=0)
+
+ return uncond + cfg * (cond - uncond)
+
+ with suppress_stdout():
+ node_ = NeuralODE(
+ solver_fn,
+ solver=solver,
+ sensitivity=sensitivity,
+ **kwargs,
+ )
+ t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
+ patch_dim = self.latent_dim * self.latent_stride
+ x0 = torch.randn(
+ z_cond.shape[0],
+ patch_dim,
+ z_cond.shape[2],
+ device=z_cond.device,
+ )
+ traj = node_.trajectory(
+ x0 * temperature,
+ t_span=t_span,
+ )
+
+ y_hat = traj[-1]
+ y_hat = y_hat.transpose(1, 2)
+ B, T, D = y_hat.shape
+ y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride)
+ y_hat = self.denormalize_z(y_hat)
+ return y_hat
+
+ def forward(
+ self,
+ z: torch.Tensor,
+ t: torch.Tensor,
+ drop_cond_rate: float = 0.0,
+ drop_vae_rate: float = 0.0,
+ sigma: float = 1e-4,
+ ):
+ z = self.normalize_z(z)
+ B, T, D = z.shape
+ if self.latent_stride > 1:
+ z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
+
+ prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)
+
+ if drop_cond_rate > 0.0:
+ to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
+ prior[to_drop] = 0.0
+
+ x0 = torch.randn_like(z)
+ x1 = z
+
+ flow_target = x1 - (1 - sigma) * x0
+
+ alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
+ xt = alpha * x0 + t.view(-1, 1, 1) * x1
+
+ pred = self.flow_net(
+ xt.transpose(1, 2),
+ prior.transpose(1, 2),
+ t,
+ )
+
+ flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)
+
+ return flow_loss, ae_loss, prior
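+
+
+# Illustrative round-trip sketch (the checkpoint path and tensor shapes are placeholders):
+#
+#   model = PatchVAE.from_pretrained("path/to/patchvae_checkpoint")
+#   latent = model.encode_patch(z)                        # z: (B, T, latent_dim) WavVAE latents
+#   z_hat = model.decode_patch(latent, cfg=2.0, num_steps=15)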
diff --git a/codec/models/patchvae/modules.py b/codec/models/patchvae/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9726a1f97c4b7435af954706c340cb369b567f
--- /dev/null
+++ b/codec/models/patchvae/modules.py
@@ -0,0 +1,396 @@
+from random import random
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from vector_quantize_pytorch import FSQ
+
+from zcodec.models.components.transformer import TransformerBlock
+
+
+class AdaLayerNormScale(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.linear = nn.Linear(dim, dim * 3)
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
+
+ def forward(self, x, c):
+ x = self.norm(x)
+ scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
+        shape = [x.shape[0]] + [1] * (x.dim() - 2) + [x.shape[-1]]
+ scale, bias, gate = map(lambda x: x.view(*shape), (scale, bias, gate))
+ x = x * (1 + scale) + bias
+ return x, gate
+
+
+class GaussianFourierTimeEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x[:, None] * self.weight[None, :] * 2 * torch.pi
+ x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
+ return x
+
+
+LAYER_FACTORIES = {}
+
+
+def register_flow_layer_factory(name):
+ def decorator(fn):
+ LAYER_FACTORIES[name] = fn
+ return fn
+
+ return decorator
+
+
+@register_flow_layer_factory("convnext")
+def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
+ return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)
+
+
+@register_flow_layer_factory("mlp")
+def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
+ return AdaLNMLP(dim)
+
+
+@register_flow_layer_factory("sa_transformer")
+def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
+ return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)
+
+
+def init_weights(m: nn.Module):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
+
+def init_adaln_weights(m: nn.Module):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ nn.init.zeros_(m.bias)
+
+
+def modulate(x, scale, shift):
+ return x * (1 + scale[:, None]) + shift[:, None]
+
+
+class AdaLNFlowPredictor(nn.Module):
+ def __init__(
+ self,
+ feat_dim: int,
+ dim: int,
+ n_layer: int,
+ layer_factory: str,
+ cond_dim: int | None = None,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+
+ layer_factory = LAYER_FACTORIES[layer_factory]
+ self.layers = nn.ModuleList(
+ [
+ layer_factory(dim, i, n_layer, is_causal=is_causal)
+ for i in range(n_layer)
+ ]
+ )
+ if cond_dim is None:
+ cond_dim = feat_dim
+ self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
+ self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
+ self.final_adaln_proj = nn.Linear(dim, dim * 2)
+ self.out_proj = nn.Linear(dim, feat_dim)
+ self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
+ self.time_emb = GaussianFourierTimeEmbedding(dim // 2)
+
+ self.apply(init_weights)
+ for l in self.adaln_proj:
+ init_adaln_weights(l)
+ init_adaln_weights(self.final_adaln_proj)
+
+ def forward(
+ self,
+ x_t: torch.Tensor,
+ x_mu: torch.Tensor,
+ t: torch.Tensor,
+ ):
+ x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
+ x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)
+
+ t_emb = self.time_emb(t)
+
+ for i, (l, adaln) in enumerate(zip(self.layers, self.adaln_proj)):
+ scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
+ x = l(x, scale_shift=(scale, shift), gate=gate)
+
+ scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
+ x = self.final_norm(x.transpose(1, 2))
+ x = modulate(x, scale, shift)
+
+ x = self.out_proj(x).transpose(1, 2)
+
+ return x
+
+
+class AdaLNMLP(nn.Module):
+ def __init__(self, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+
+ self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
+ self.mlp = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
+ )
+
+ def forward(self, x, scale_shift, gate):
+ x = x.transpose(-1, -2)
+ h = modulate(self.in_ln(x), *scale_shift)
+ h = self.mlp(h)
+ return (x + gate[:, None] * h).transpose(-1, -2)
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int | None = None,
+ layer_scale_init_value: float = 0.0,
+ elementwise_affine_ln: bool = True,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
+ )
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.is_causal = is_causal
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
+ gate: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ residual = x
+ if self.is_causal:
+ x = torch.nn.functional.pad(x, (6, 0))
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ x = self.norm(x)
+ if scale_shift is not None:
+ scale, shift = scale_shift
+ x = x * scale[:, None] + shift[:, None]
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ if gate is not None:
+ x = gate[:, None] * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class ConvNextNet(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ n_layers: int,
+ intermediate_dim: int | None = None,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+ self.net = nn.Sequential(
+ *[
+ ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(self, x):
+ return self.net(x.transpose(1, 2)).transpose(1, 2)
+
+
+def convnext_factory(dim, n_layers, is_causal=False):
+ return ConvNextNet(dim, n_layers, is_causal=is_causal)
+
+
+def convnextformer_factory(
+ dim, n_layers, n_convnext_per_transformer_block, is_causal=False
+):
+ layers = []
+ for i in range(0, n_layers, n_convnext_per_transformer_block + 1):
+ layers.append(
+ ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
+ )
+ layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
+ return nn.Sequential(*layers)
+
+
+class AutoEncoder(nn.Module):
+ def __init__(
+ self,
+ feat_dim: int,
+ hidden_dim: int,
+ num_layers: int,
+ net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
+ out_dim: int | None = None,
+ convnextformer_num_conv_per_transformer: int = 3,
+ causal_transformer: bool = False,
+ bottleneck_size: int | None = None,
+ vae: bool = False,
+ is_causal: bool = False,
+ ):
+ super().__init__()
+
+ self.embed = nn.Linear(feat_dim, hidden_dim)
+ if out_dim is None:
+ out_dim = feat_dim
+ self.unembed = nn.Linear(hidden_dim, out_dim)
+
+ if net_factory == "convnext":
+ self.encoder_net = convnext_factory(
+ hidden_dim, num_layers, is_causal=is_causal
+ )
+ self.decoder_net = convnext_factory(
+ hidden_dim, num_layers, is_causal=is_causal
+ )
+ elif net_factory == "convnextformer_decoder":
+ self.encoder_net = convnext_factory(
+ hidden_dim, num_layers, is_causal=is_causal
+ )
+ self.decoder_net = convnextformer_factory(
+ hidden_dim,
+ num_layers,
+ convnextformer_num_conv_per_transformer,
+ is_causal=is_causal,
+ )
+ elif net_factory == "convnextformer":
+ self.encoder_net = convnextformer_factory(
+ hidden_dim,
+ num_layers,
+ convnextformer_num_conv_per_transformer,
+ is_causal=is_causal,
+ )
+ self.decoder_net = convnextformer_factory(
+ hidden_dim,
+ num_layers,
+ convnextformer_num_conv_per_transformer,
+ is_causal=is_causal,
+ )
+
+ self.bottleneck = (
+ nn.Linear(hidden_dim, bottleneck_size * (1 + vae))
+ if bottleneck_size is not None
+ else nn.Identity()
+ )
+ self.unbottleneck = (
+ nn.Linear(bottleneck_size, hidden_dim)
+ if bottleneck_size is not None
+ else nn.Identity()
+ )
+ self.vae = vae
+
+ def reparameterize(
+ self,
+ mu: torch.Tensor,
+ logvar: torch.Tensor,
+ deterministic: bool = False,
+ drop_vae_rate: float = 0.0,
+ ) -> torch.Tensor:
+ logvar = torch.clamp(logvar, -30.0, 20.0)
+ std = torch.exp(0.5 * logvar)
+ if drop_vae_rate > 0.0:
+ to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
+ eps = torch.randn_like(std)
+ eps[to_drop] = 0.0
+ else:
+ if deterministic:
+ eps = torch.zeros_like(std)
+ else:
+ eps = torch.randn_like(std)
+ return mu + eps * std
+
+ def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
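+        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over the feature dim and
+        # averaged over the remaining (batch/time) dims.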
+ kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
+ return kl.sum(dim=-1).mean()
+
+    def forward(
+        self, x: torch.Tensor, drop_vae_rate: float = 0.0
+    ) -> tuple[torch.Tensor, dict]:
+ # Encode
+ x = self.embed(x)
+ x = self.encoder_net(x)
+ x = self.bottleneck(x)
+ if self.vae:
+ mu, logvar = x.chunk(2, dim=-1)
+ loss = {
+ "kl_div": self.kl_divergence(mu, logvar),
+ "_mu_mean": mu.mean(),
+ "_mu_std": mu.std(),
+ "_logvar_mean": logvar.mean(),
+ "_logvar_std": logvar.std(),
+ }
+ x = self.reparameterize(
+ mu,
+ logvar,
+ drop_vae_rate=drop_vae_rate,
+ )
+ else:
+ loss = {}
+
+ # Decode
+ x = self.unbottleneck(x)
+ x = self.decoder_net(x)
+ x = self.unembed(x)
+
+ return x, loss
+
+ def encode(self, x: torch.Tensor, deterministic: bool = False):
+ x = self.embed(x)
+ x = self.encoder_net(x)
+ x = self.bottleneck(x)
+
+ if self.vae:
+ x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
+ return x
+
+ def decode(
+ self,
+        latent: torch.Tensor,
+ ):
+ x = self.unbottleneck(latent)
+ x = self.decoder_net(x)
+ x = self.unembed(x)
+ return x
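+
+
+# Illustrative sketch of the encode/decode path in isolation (all sizes are arbitrary):
+#
+#   ae = AutoEncoder(feat_dim=64, hidden_dim=256, num_layers=4, net_factory="convnext",
+#                    bottleneck_size=32, vae=True)
+#   latent = ae.encode(torch.randn(2, 100, 64))   # -> (2, 100, 32)
+#   recon = ae.decode(latent)                     # -> (2, 100, 64)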
diff --git a/codec/models/wavvae/__init__.py b/codec/models/wavvae/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/codec/models/wavvae/dataset.py b/codec/models/wavvae/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..6fab0e0668f1f54027eeebfb16ae22f61c881418
--- /dev/null
+++ b/codec/models/wavvae/dataset.py
@@ -0,0 +1,84 @@
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+import torchaudio
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+torch.set_num_threads(1)
+
+
+@dataclass
+class WavVAEDataConfig:
+ filelist_path: str
+ sampling_rate: int
+ num_samples: int
+ batch_size: int
+ num_workers: int
+
+
+class WavVAEDataModule(LightningDataModule):
+ def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
+ super().__init__()
+ self.train_config = train_params
+ self.val_config = val_params
+
+    def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool):
+ dataset = WavVAEDataset(cfg, train=train)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=cfg.batch_size,
+ num_workers=cfg.num_workers,
+ shuffle=train,
+ pin_memory=True,
+ )
+ return dataloader
+
+ def train_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.train_config, train=True)
+
+ def val_dataloader(self) -> DataLoader:
+        return self._get_dataloader(self.val_config, train=False)
+
+
+class WavVAEDataset(Dataset):
+    def __init__(self, cfg: WavVAEDataConfig, train: bool):
+ with open(cfg.filelist_path) as f:
+ self.filelist = f.read().splitlines()
+ self.sampling_rate = cfg.sampling_rate
+ self.num_samples = cfg.num_samples
+ self.train = train
+
+ def __len__(self) -> int:
+ return len(self.filelist)
+
+ def __getitem__(self, index: int) -> torch.Tensor:
+ audio_path = self.filelist[index]
+ y, sr = torchaudio.load(audio_path)
+ if y.size(0) > 1:
+ # mix to mono
+ y = y.mean(dim=0, keepdim=True)
+        gain = np.random.uniform(-6, -1) if self.train else -3
+ y, _ = torchaudio.sox_effects.apply_effects_tensor(
+ y, sr, [["norm", f"{gain:.2f}"]]
+ )
+ try:
+ if sr != self.sampling_rate:
+ y = torchaudio.functional.resample(
+ y, orig_freq=sr, new_freq=self.sampling_rate
+ )
+        except Exception as e:
+            print(audio_path, y.shape, e)
+ if y.size(-1) < self.num_samples:
+ pad_length = self.num_samples - y.size(-1)
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+ elif self.train:
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+ y = y[:, start : start + self.num_samples]
+ else:
+ # During validation, take always the first segment for determinism
+ y = y[:, : self.num_samples]
+
+ return y[0]
diff --git a/codec/models/wavvae/discriminators.py b/codec/models/wavvae/discriminators.py
new file mode 100755
index 0000000000000000000000000000000000000000..2c62574a78bc963eb35dedffd6ef13a96f4e4193
--- /dev/null
+++ b/codec/models/wavvae/discriminators.py
@@ -0,0 +1,211 @@
+from typing import List, Optional, Tuple
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import Conv2d
+from torch.nn.utils import weight_norm
+from torchaudio.transforms import Spectrogram
+
+
+class MultiPeriodDiscriminator(nn.Module):
+ """
+ Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+ Args:
+ periods (tuple[int]): Tuple of periods for each discriminator.
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+ Defaults to None.
+ """
+
+ def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
+ super().__init__()
+ self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
+
+ def forward(
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorP(nn.Module):
+ def __init__(
+ self,
+ period: int,
+ in_channels: int = 1,
+ kernel_size: int = 5,
+ stride: int = 3,
+ lrelu_slope: float = 0.1,
+ num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+ weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
+ ]
+ )
+ if num_embeddings is not None:
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
+ torch.nn.init.zeros_(self.emb.weight)
+
+ self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+ self.lrelu_slope = lrelu_slope
+
+ def forward(
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ x = x.unsqueeze(1)
+ fmap = []
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for i, l in enumerate(self.convs):
+ x = l(x)
+ x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
+ if i > 0:
+ fmap.append(x)
+ if cond_embedding_id is not None:
+ emb = self.emb(cond_embedding_id)
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+ else:
+ h = 0
+ x = self.conv_post(x)
+ fmap.append(x)
+ x += h
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiResolutionDiscriminator(nn.Module):
+ def __init__(
+ self,
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
+ num_embeddings: Optional[int] = None,
+ ):
+ """
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+ Args:
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+ Defaults to None.
+ """
+
+ super().__init__()
+ self.discriminators = nn.ModuleList(
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
+ )
+
+ def forward(
+        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+
+ for d in self.discriminators:
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorR(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ num_embeddings: Optional[int] = None,
+ channels: int = 32,
+ hop_factor: float = 0.25,
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
+ ):
+ super().__init__()
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.spec_fn = Spectrogram(
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
+ )
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+ convs = lambda: nn.ModuleList(
+ [
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+
+ if num_embeddings is not None:
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+ torch.nn.init.zeros_(self.emb.weight)
+
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
+
+ def spectrogram(self, x):
+ # Remove DC offset
+ x = x - x.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ x = self.spec_fn(x)
+ x = torch.view_as_real(x)
+ x = rearrange(x, "b f t c -> b c t f")
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None):
+ x_bands = self.spectrogram(x)
+ fmap = []
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for i, layer in enumerate(stack):
+ band = layer(band)
+ band = torch.nn.functional.leaky_relu(band, 0.1)
+ if i > 0:
+ fmap.append(band)
+ x.append(band)
+ x = torch.cat(x, dim=-1)
+ if cond_embedding_id is not None:
+ emb = self.emb(cond_embedding_id)
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+ else:
+ h = 0
+ x = self.conv_post(x)
+ fmap.append(x)
+ x += h
+
+ return x, fmap
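+
+
+# Illustrative sketch of running both discriminator banks on a real/generated pair
+# (batch size and signal length are arbitrary):
+#
+#   mpd, mrd = MultiPeriodDiscriminator(), MultiResolutionDiscriminator()
+#   y, y_hat = torch.randn(2, 24000), torch.randn(2, 24000)
+#   real_logits, fake_logits, real_fmaps, fake_fmaps = mpd(y, y_hat)
+#   real_logits_r, fake_logits_r, real_fmaps_r, fake_fmaps_r = mrd(y, y_hat)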
diff --git a/codec/models/wavvae/experiment.py b/codec/models/wavvae/experiment.py
new file mode 100755
index 0000000000000000000000000000000000000000..b28b04f643122b019e912540f228c8ed20be9eeb
--- /dev/null
+++ b/codec/models/wavvae/experiment.py
@@ -0,0 +1,3 @@
+
+
+
diff --git a/codec/models/wavvae/heads.py b/codec/models/wavvae/heads.py
new file mode 100755
index 0000000000000000000000000000000000000000..92c576aa3f8a72f2b2b148d2d162eab6e838ac31
--- /dev/null
+++ b/codec/models/wavvae/heads.py
@@ -0,0 +1,194 @@
+from typing import Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+from .modules import symexp
+from .spectral_ops import IMDCT, ISTFT
+
+
+class FourierHead(nn.Module):
+ """Base class for inverse fourier modules."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+class LinearNoBiasHead(FourierHead):
+ def __init__(self, dim: int, hop_length: int, n_fft: int):
+ super().__init__()
+ self.pre_head = nn.Linear(dim, n_fft + 2)
+ self.head = nn.Linear(n_fft + 2, hop_length, bias=False)
+
+ def forward(self, x):
+ y = self.pre_head(x)
+ y = self.head(y).clamp(min=-1.0, max=1.0)
+ B, _, _ = y.shape
+ return y.reshape(B, -1)
+
+
+class ISTFTHead(FourierHead):
+ """
+ ISTFT Head module for predicting STFT complex coefficients.
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
+ the resolution of the input features.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+ super().__init__()
+ out_dim = n_fft + 2
+ self.out = torch.nn.Linear(dim, out_dim)
+ self.hop_length = hop_length
+ self.istft = ISTFT(
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the ISTFTHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x).transpose(1, 2)
+ mag, p = x.chunk(2, dim=1)
+ mag = torch.exp(mag)
+ mag = torch.clip(
+ mag, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ # wrapping happens here. These two lines produce real and imaginary value
+ x = torch.cos(p)
+ y = torch.sin(p)
+ # recalculating phase here does not produce anything new
+ # only costs time
+ # phase = torch.atan2(y, x)
+ # S = mag * torch.exp(phase * 1j)
+ # better directly produce the complex value
+ S = mag * (x + 1j * y)
+ audio = self.istft(S)
+ audio = nn.functional.pad(audio, (self.hop_length // 2, self.hop_length // 2))
+ return audio
+
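+
+# Shape sketch (illustrative; the names below are placeholders): with n_fft=1024
+# and hop_length=256, a (B, L, dim) feature tensor maps to roughly L * hop_length
+# audio samples.
+#
+#   head = ISTFTHead(dim=768, n_fft=1024, hop_length=256, padding="same")
+#   feats = torch.randn(1, 94, 768)   # (B, L, dim)
+#   wav = head(feats)                 # (1, ~94 * 256) samples
+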
+
+class IMDCTSymExpHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+ based on perceptual scaling. Defaults to None.
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ sample_rate: Optional[int] = None,
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ out_dim = mdct_frame_len // 2
+ self.out = nn.Linear(dim, out_dim)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+ self.clip_audio = clip_audio
+
+ if sample_rate is not None:
+ # optionally init the last layer following mel-scale
+ m_max = _hz_to_mel(sample_rate // 2)
+ m_pts = torch.linspace(0, m_max, out_dim)
+ f_pts = _mel_to_hz(m_pts)
+ scale = 1 - (f_pts / f_pts.max())
+
+ with torch.no_grad():
+ self.out.weight.mul_(scale.view(-1, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTSymExpHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ x = symexp(x)
+ x = torch.clip(
+ x, min=-1e2, max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(x)
+ if self.clip_audio:
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+
+ return audio
+
+
+class IMDCTCosHead(FourierHead):
+ """
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
+
+ Args:
+ dim (int): Hidden dimension of the model.
+ mdct_frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ mdct_frame_len: int,
+ padding: str = "same",
+ clip_audio: bool = False,
+ ):
+ super().__init__()
+ self.clip_audio = clip_audio
+ self.out = nn.Linear(dim, mdct_frame_len)
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the IMDCTCosHead module.
+
+ Args:
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+ L is the sequence length, and H denotes the model dimension.
+
+ Returns:
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+ """
+ x = self.out(x)
+ m, p = x.chunk(2, dim=2)
+ m = torch.exp(m).clip(
+ max=1e2
+ ) # safeguard to prevent excessively large magnitudes
+ audio = self.imdct(m * torch.cos(p))
+ if self.clip_audio:
+            audio = torch.clip(audio, min=-1.0, max=1.0)
+ return audio
diff --git a/codec/models/wavvae/helpers.py b/codec/models/wavvae/helpers.py
new file mode 100755
index 0000000000000000000000000000000000000000..3d303010352ad59dde2996605f124128ee17db36
--- /dev/null
+++ b/codec/models/wavvae/helpers.py
@@ -0,0 +1,71 @@
+import matplotlib
+import numpy as np
+import torch
+from matplotlib import pyplot as plt
+from pytorch_lightning import Callback
+
+matplotlib.use("Agg")
+
+
+def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+ """
+ Save a matplotlib figure to a numpy array.
+
+ Args:
+ fig (Figure): Matplotlib figure object.
+
+ Returns:
+ ndarray: Numpy array representing the figure.
+ """
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ return data
+
+
+def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+ """
+ Plot a spectrogram and convert it to a numpy array.
+
+ Args:
+ spectrogram (ndarray): Spectrogram data.
+
+ Returns:
+ ndarray: Numpy array representing the plotted spectrogram.
+ """
+ spectrogram = spectrogram.astype(np.float32)
+ fig, ax = plt.subplots(figsize=(12, 3))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = save_figure_to_numpy(fig)
+ plt.close()
+ return data
+
+
+class GradNormCallback(Callback):
+ """
+ Callback to log the gradient norm.
+ """
+
+ def on_after_backward(self, trainer, model):
+ model.log("grad_norm", gradient_norm(model))
+
+
+def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+ """
+ Compute the gradient norm.
+
+ Args:
+ model (Module): PyTorch model.
+ norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+ Returns:
+ Tensor: Gradient norm.
+ """
+ grads = [p.grad for p in model.parameters() if p.grad is not None]
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+ return total_norm
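+
+
+# Usage sketch (illustrative; `model` and `dm` are placeholders for a
+# LightningModule and LightningDataModule defined elsewhere in this repo):
+#
+#   import pytorch_lightning as ptl
+#   trainer = ptl.Trainer(max_steps=1000, callbacks=[GradNormCallback()])
+#   trainer.fit(model, datamodule=dm)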
diff --git a/codec/models/wavvae/layers.py b/codec/models/wavvae/layers.py
new file mode 100755
index 0000000000000000000000000000000000000000..686bcdaa25fcc9f41a42aac6dfcb769cd1517161
--- /dev/null
+++ b/codec/models/wavvae/layers.py
@@ -0,0 +1,282 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+
+class VocosDecoder(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ num_layers: int,
+ ):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.convnext = nn.ModuleList(
+ [
+ ConvNeXtBlock(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ layer_scale_init_value=0.0,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = x.transpose(1, 2)
+ for conv_block in self.convnext:
+ x = conv_block(x)
+ x = self.final_layer_norm(x.transpose(1, 2))
+ return x
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+        intermediate_dim (int, optional): Dimensionality of the intermediate layer. Defaults to 3 * dim.
+        layer_scale_init_value (float, optional): Initial value for the layer scale. Values <= 0 disable
+            layer scaling. Defaults to 0.0.
+        elementwise_affine_ln (bool, optional): Whether the LayerNorm uses learnable affine parameters.
+            Defaults to True.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int | None = None,
+ layer_scale_init_value: float = 0.0,
+ elementwise_affine_ln: bool = True,
+ ):
+ super().__init__()
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=3, groups=dim
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
+ )
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
+ gate: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ x = self.norm(x)
+ if scale_shift is not None:
+ scale, shift = scale_shift
+ x = x * scale[:, None] + shift[:, None]
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ if gate is not None:
+ x = gate[:, None] * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ d_model=32,
+ strides=[2, 4, 4, 8],
+ depthwise=False,
+ ):
+ super().__init__()
+ layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
+ for stride in strides:
+ d_model *= 2
+ groups = d_model // 2 if depthwise else 1
+ layers += [EncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
+ groups = d_model if depthwise else 1
+ layers += [
+ WNConv1d(d_model, d_model, kernel_size=7, padding=3, groups=groups),
+ ]
+ self.block = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ input_channel,
+ channels,
+ rates,
+ noise=False,
+ depthwise=False,
+ d_out=1,
+ ):
+ super().__init__()
+ if depthwise:
+ layers = [
+ WNConv1d(
+ input_channel,
+ input_channel,
+ kernel_size=7,
+ padding=3,
+ groups=input_channel,
+ ),
+ WNConv1d(input_channel, channels, kernel_size=1),
+ ]
+ else:
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
+
+ for i, stride in enumerate(rates):
+ input_dim = channels // 2**i
+ output_dim = channels // 2 ** (i + 1)
+ groups = output_dim if depthwise else 1
+ layers.append(
+ DecoderBlock(input_dim, output_dim, stride, noise, groups=groups)
+ )
+
+ layers += [
+ Snake1d(output_dim),
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
+ nn.Tanh(),
+ ]
+ self.model = nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.model(x)
+ return x
+
+
+class ResidualUnit(nn.Module):
+ def __init__(self, dim=16, dilation=1, kernel=7, groups=1):
+ super().__init__()
+ pad = ((kernel - 1) * dilation) // 2
+ self.block = nn.Sequential(
+ Snake1d(dim),
+ WNConv1d(
+ dim,
+ dim,
+ kernel_size=kernel,
+ dilation=dilation,
+ padding=pad,
+ groups=groups,
+ ),
+ Snake1d(dim),
+ WNConv1d(dim, dim, kernel_size=1),
+ )
+
+ def forward(self, x):
+ y = self.block(x)
+ pad = (x.shape[-1] - y.shape[-1]) // 2
+ if pad > 0:
+ x = x[..., pad:-pad]
+ return x + y
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, output_dim=16, input_dim=None, stride=1, groups=1):
+ super().__init__()
+ input_dim = input_dim or output_dim // 2
+ self.block = nn.Sequential(
+ ResidualUnit(input_dim, dilation=1, groups=groups),
+ ResidualUnit(input_dim, dilation=3, groups=groups),
+ ResidualUnit(input_dim, dilation=9, groups=groups),
+ Snake1d(input_dim),
+ WNConv1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ ),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class NoiseBlock(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.linear = WNConv1d(dim, dim, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ B, C, T = x.shape
+ noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
+ h = self.linear(x)
+ n = noise * h
+ x = x + n
+ return x
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, input_dim=16, output_dim=8, stride=1, noise=False, groups=1):
+ super().__init__()
+ layers = [
+ Snake1d(input_dim),
+ WNConvTranspose1d(
+ input_dim,
+ output_dim,
+ kernel_size=2 * stride,
+ stride=stride,
+ padding=math.ceil(stride / 2),
+ output_padding=stride % 2,
+ ),
+ ]
+ if noise:
+ layers.append(NoiseBlock(output_dim))
+ layers.extend(
+ [
+ ResidualUnit(output_dim, dilation=1, groups=groups),
+ ResidualUnit(output_dim, dilation=3, groups=groups),
+ ResidualUnit(output_dim, dilation=9, groups=groups),
+ ]
+ )
+ self.block = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.block(x)
+
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+@torch.jit.script
+def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
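+
+
+# Shape sketch (illustrative): with the default strides [2, 4, 4, 8] the Encoder
+# downsamples time by 2 * 4 * 4 * 8 = 256 and doubles the channel count at each
+# stage, so d_model=48 ends at 48 * 2**4 = 768 channels. For a 1-second clip at
+# 24 kHz (values below are approximate):
+#
+#   enc = Encoder(d_model=48, strides=[2, 4, 4, 8], depthwise=True)
+#   x = torch.randn(1, 1, 24000)   # (B, 1, T)
+#   z = enc(x)                     # (1, 768, ~T // 256)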
diff --git a/codec/models/wavvae/loss.py b/codec/models/wavvae/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..ec3268216302151bf20351116ad7c3d886a8aeac
--- /dev/null
+++ b/codec/models/wavvae/loss.py
@@ -0,0 +1,142 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torchaudio
+from torch import nn
+
+from .modules import safe_log
+
+
+class MelSpecReconstructionLoss(nn.Module):
+ """
+ L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
+ """
+
+ def __init__(
+ self,
+ sample_rate: int = 24000,
+ n_fft: int | None = None,
+ hop_length: int = 256,
+ n_mels: int = 100,
+ f_min: int = 0,
+ f_max: Optional[int] = None,
+ ):
+ super().__init__()
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate,
+ n_fft=hop_length * 4 if n_fft is None else n_fft,
+ hop_length=hop_length,
+ n_mels=n_mels,
+ center=True,
+ power=1,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+ def forward(self, y_hat, y) -> torch.Tensor:
+ """
+ Args:
+ y_hat (Tensor): Predicted audio waveform.
+ y (Tensor): Ground truth audio waveform.
+
+ Returns:
+ Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+ """
+ # B, C, Th = y_hat.shape
+ # B, C, T = y.shape
+ # crop = (Th - T) // 2
+ mel_hat = safe_log(self.mel_spec(y_hat))
+ # mel_hat = safe_log(self.mel_spec(y_hat[..., crop:-crop]))
+ # mel = safe_log(self.mel_spec(y[..., crop:-crop]))
+ mel = safe_log(self.mel_spec(y))
+
+ loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+ return loss
+
+
+class GeneratorLoss(nn.Module):
+ """
+ Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
+ """
+
+ def forward(
+ self, disc_outputs: List[torch.Tensor]
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """
+ Args:
+ disc_outputs (List[Tensor]): List of discriminator outputs.
+
+ Returns:
+ Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
+ the sub-discriminators
+ """
+ loss = torch.zeros(
+ 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
+ )
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean(torch.clamp(1 - dg, min=0))
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
+
+
+class DiscriminatorLoss(nn.Module):
+ """
+ Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
+ """
+
+ def forward(
+ self,
+ disc_real_outputs: List[torch.Tensor],
+ disc_generated_outputs: List[torch.Tensor],
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Args:
+ disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
+ disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
+
+ Returns:
+ Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
+ the sub-discriminators for real outputs, and a list of
+ loss values for generated outputs.
+ """
+ loss = torch.zeros(
+ 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
+ )
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+ loss += r_loss + g_loss
+ r_losses.append(r_loss)
+ g_losses.append(g_loss)
+
+ return loss, r_losses, g_losses
+
+
+class FeatureMatchingLoss(nn.Module):
+ """
+ Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
+ """
+
+ def forward(
+ self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
+ ) -> torch.Tensor:
+ """
+ Args:
+ fmap_r (List[List[Tensor]]): List of feature maps from real samples.
+ fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
+
+ Returns:
+ Tensor: The calculated feature matching loss.
+ """
+ loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss
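+
+
+# Combination sketch (illustrative): these modules implement a hinge-GAN
+# objective with feature matching, in the style of HiFi-GAN / Vocos training.
+# The tensors and weights below are placeholders, not values used in this repo.
+#
+#   d_loss, _, _ = DiscriminatorLoss()(real_scores, fake_scores)   # discriminator step
+#   g_adv, _ = GeneratorLoss()(fake_scores)                        # generator step
+#   g_fm = FeatureMatchingLoss()(fmap_real, fmap_fake)
+#   g_mel = MelSpecReconstructionLoss(sample_rate=24000)(y_hat, y)
+#   g_loss = g_adv + 2.0 * g_fm + 45.0 * g_mel                     # weights are hypothetical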
diff --git a/codec/models/wavvae/model.py b/codec/models/wavvae/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..b5562a30907bd82d36deb388587024e9ca94a2db
--- /dev/null
+++ b/codec/models/wavvae/model.py
@@ -0,0 +1,140 @@
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Literal
+
+import torch
+from safetensors.torch import load_file
+from torch import nn
+
+from .heads import ISTFTHead, LinearNoBiasHead
+from .layers import Encoder, VocosDecoder
+from .modules import ConvNeXtBlock
+
+
+@dataclass
+class WavVAEConfig:
+ conv_dim: int = 48
+ latent_dim: int = 32
+ decoder_hidden_dim: int = 768
+ decoder_intermediate_dim: int = 1536
+ decoder_num_layers: int = 8
+ n_fft: int = 1024
+ hop_length: int = 256
+ padding: str = "center"
+ head_type: Literal["istft", "linear"] = "istft"
+ strides: list[int] = field(default_factory=lambda: [2, 4, 4, 8])
+ learnable_pre_norm: bool = False
+ sampling_rate: int = 24000
+
+
+class WavVAE(nn.Module):
+ def __init__(self, cfg: WavVAEConfig):
+ super().__init__()
+ self.conv_encoder = Encoder(cfg.conv_dim, strides=cfg.strides, depthwise=True)
+ conv_final_dim = cfg.conv_dim * 2 ** len(cfg.strides)
+ self.bottleneck = nn.Linear(conv_final_dim, cfg.latent_dim * 2)
+ self.unbottleneck = nn.Linear(cfg.latent_dim, cfg.decoder_hidden_dim)
+ self.latent_norm = nn.LayerNorm(conv_final_dim)
+ self.vocos_decoder = VocosDecoder(
+ cfg.decoder_hidden_dim,
+ cfg.decoder_intermediate_dim,
+ cfg.decoder_num_layers,
+ )
+ if cfg.head_type == "istft":
+ self.head = ISTFTHead(
+ cfg.decoder_hidden_dim,
+ cfg.n_fft,
+ cfg.hop_length,
+ padding=cfg.padding,
+ )
+ elif cfg.head_type == "linear":
+ self.head = LinearNoBiasHead(
+ cfg.decoder_hidden_dim,
+ cfg.hop_length,
+ cfg.n_fft,
+ )
+
+ self._sampling_rate = cfg.sampling_rate
+ self._strides = cfg.strides
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ @property
+ def sampling_rate(self) -> int:
+ return self._sampling_rate
+
+ @property
+ def hop_length(self) -> int:
+ hop_length = 1
+ for s in self._strides:
+ hop_length *= s
+ return hop_length
+
+ @property
+ def frame_rate(self) -> float:
+ return self.sampling_rate / self.hop_length
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ device: str = "cpu",
+ ):
+ if Path(pretrained_model_name_or_path).exists():
+ path = pretrained_model_name_or_path
+ else:
+ from huggingface_hub import snapshot_download
+
+ path = snapshot_download(pretrained_model_name_or_path)
+
+ with open(Path(path) / "config.json", "r") as f:
+ config = json.load(f)
+ config = WavVAEConfig(**config)
+ model = cls(config)
+ state_dict = load_file(
+ Path(path) / "model.st",
+ device=device,
+ )
+
+ model.load_state_dict(state_dict, assign=True)
+ return model
+
+ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+ logvar = torch.clamp(logvar, -30.0, 20.0)
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return mu + eps * std
+
+ def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+ kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
+ return kl.sum(dim=-1).mean()
+
+ def encode(self, audio: torch.Tensor) -> torch.Tensor:
+ y = self.conv_encoder(audio.unsqueeze(1)).transpose(1, 2)
+ y = self.latent_norm(y)
+ mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
+ z = self.reparameterize(mu, logvar)
+ return z
+
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ y = self.unbottleneck(z)
+ y = self.vocos_decoder(y)
+ return self.head(y)
+
+ def forward(self, audio_input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ y = self.conv_encoder(audio_input.unsqueeze(1)).transpose(1, 2)
+ y = self.latent_norm(y)
+ mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
+ kl_div = self.kl_divergence(mu, logvar)
+ z = self.reparameterize(mu, logvar)
+ y = self.unbottleneck(z)
+ y = self.vocos_decoder(y)
+ audio_output = self.head(y)
+
+ return audio_output, kl_div
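+
+
+# Roundtrip sketch (illustrative): with the default config (strides [2, 4, 4, 8],
+# 24 kHz, hop 256) the latent frame rate is 24000 / 256 = 93.75 Hz and each frame
+# holds latent_dim=32 values.
+#
+#   cfg = WavVAEConfig()
+#   model = WavVAE(cfg).eval()
+#   audio = torch.randn(2, 24000)        # (B, T)
+#   with torch.no_grad():
+#       z = model.encode(audio)          # (B, ~T // 256, 32)
+#       recon = model.decode(z)          # (B, ~T) waveform
+#       recon2, kl = model(audio)        # training path: reconstruction + KL term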
diff --git a/codec/models/wavvae/modules.py b/codec/models/wavvae/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..af1d6db16e2f10cc9af7bcc64434e2c983d756b7
--- /dev/null
+++ b/codec/models/wavvae/modules.py
@@ -0,0 +1,213 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+ Args:
+ dim (int): Number of input channels.
+ intermediate_dim (int): Dimensionality of the intermediate layer.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+ None means non-conditional LayerNorm. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ layer_scale_init_value: float,
+ adanorm_num_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.adanorm = adanorm_num_embeddings is not None
+ if adanorm_num_embeddings:
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+ else:
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ if self.adanorm:
+ assert cond_embedding_id is not None
+ x = self.norm(x, cond_embedding_id)
+ else:
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
+
+ Args:
+ num_embeddings (int): Number of embeddings.
+ embedding_dim (int): Dimension of the embeddings.
+ """
+
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.dim = embedding_dim
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+ torch.nn.init.ones_(self.scale.weight)
+ torch.nn.init.zeros_(self.shift.weight)
+
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+ scale = self.scale(cond_embedding_id)
+ shift = self.shift(cond_embedding_id)
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+ x = x * scale + shift
+ return x
+
+
+class ResBlock1(nn.Module):
+ """
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+ but without upsampling layers.
+
+ Args:
+ dim (int): Number of input channels.
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+ Defaults to (1, 3, 5).
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+ Defaults to 0.1.
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int = 3,
+ dilation: Tuple[int, int, int] = (1, 3, 5),
+ lrelu_slope: float = 0.1,
+ layer_scale_init_value: Optional[float] = None,
+ ):
+ super().__init__()
+ self.lrelu_slope = lrelu_slope
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=self.get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=self.get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ nn.Conv1d(
+ dim,
+ dim,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=self.get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+ ]
+ )
+
+ self.gamma = nn.ParameterList(
+ [
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+ if layer_scale_init_value is not None
+ else None,
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+ xt = c1(xt)
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+ xt = c2(xt)
+ if gamma is not None:
+ xt = gamma * xt
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+ @staticmethod
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+ """
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+ Args:
+ x (Tensor): Input tensor.
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+ Returns:
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
+ """
+ return torch.log(torch.clip(x, min=clip_val))
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * torch.log1p(x.abs())
+
+
+def symexp(x: torch.Tensor) -> torch.Tensor:
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
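+
+
+# Note (illustrative): symlog and symexp are exact inverses, symexp(symlog(x)) == x,
+# so a head can predict targets in a compressed symlog-like range and map them back
+# with symexp (see IMDCTSymExpHead).
+#
+#   x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
+#   assert torch.allclose(symexp(symlog(x)), x, atol=1e-5)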
diff --git a/codec/models/wavvae/spectral_ops.py b/codec/models/wavvae/spectral_ops.py
new file mode 100755
index 0000000000000000000000000000000000000000..a8eda1c8e18a32406aad40415f6a8bf60eb15fea
--- /dev/null
+++ b/codec/models/wavvae/spectral_ops.py
@@ -0,0 +1,192 @@
+import numpy as np
+import scipy
+import torch
+from torch import nn, view_as_real, view_as_complex
+
+
+class ISTFT(nn.Module):
+ """
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+ See issue: https://github.com/pytorch/pytorch/issues/62323
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+ The NOLA constraint is met as we trim padded samples anyway.
+
+ Args:
+ n_fft (int): Size of Fourier transform.
+ hop_length (int): The distance between neighboring sliding window frames.
+ win_length (int): The size of window frame and STFT filter.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.win_length = win_length
+ window = torch.hann_window(win_length)
+ self.register_buffer("window", window)
+
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+ Args:
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+ N is the number of frequency bins, and T is the number of time frames.
+
+ Returns:
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+ """
+ if self.padding == "center":
+ # Fallback to pytorch native implementation
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
+ elif self.padding == "same":
+ pad = (self.win_length - self.hop_length) // 2
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
+ B, N, T = spec.shape
+
+ # Inverse FFT
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+ ifft = ifft * self.window[None, :, None]
+
+ # Overlap and Add
+ output_size = (T - 1) * self.hop_length + self.win_length
+ y = torch.nn.functional.fold(
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ )[:, 0, 0, pad:-pad]
+
+ # Window envelope
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
+ window_envelope = torch.nn.functional.fold(
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
+ ).squeeze()[pad:-pad]
+
+ # Normalize
+ assert (window_envelope > 1e-11).all()
+ y = y / window_envelope
+
+ return y
+
+
+class MDCT(nn.Module):
+ """
+ Modified Discrete Cosine Transform (MDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
+ # https://github.com/pytorch/pytorch/issues/71613
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+ Args:
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+ and T is the length of the audio.
+
+ Returns:
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+ and N is the number of frequency bins.
+ """
+ if self.padding == "center":
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
+ elif self.padding == "same":
+ # hop_length is 1/2 frame_len
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+ N = self.frame_len // 2
+ x = x * self.window.expand(x.shape)
+ X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+ return torch.real(res) * np.sqrt(2)
+
+
+class IMDCT(nn.Module):
+ """
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+ Args:
+ frame_len (int): Length of the MDCT frame.
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+ """
+
+ def __init__(self, frame_len: int, padding: str = "same"):
+ super().__init__()
+ if padding not in ["center", "same"]:
+ raise ValueError("Padding must be 'center' or 'same'.")
+ self.padding = padding
+ self.frame_len = frame_len
+ N = frame_len // 2
+ n0 = (N + 1) / 2
+        window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
+ self.register_buffer("window", window)
+
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+ Args:
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+ L is the number of frames, and N is the number of frequency bins.
+
+ Returns:
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+ """
+ B, L, N = X.shape
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+ Y[..., :N] = X
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+ y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
+ y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
+ result = y * self.window.expand(y.shape)
+ output_size = (1, (L + 1) * N)
+ audio = torch.nn.functional.fold(
+ result.transpose(1, 2),
+ output_size=output_size,
+ kernel_size=(1, self.frame_len),
+ stride=(1, self.frame_len // 2),
+ )[:, 0, 0, :]
+
+ if self.padding == "center":
+ pad = self.frame_len // 2
+ elif self.padding == "same":
+ pad = self.frame_len // 4
+ else:
+ raise ValueError("Padding must be 'center' or 'same'.")
+
+ audio = audio[:, pad:-pad]
+ return audio
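+
+
+# Roundtrip sketch (illustrative): MDCT/IMDCT use a hop of frame_len // 2, so the
+# reconstruction matches the input up to edge effects and a few samples of padding.
+#
+#   mdct = MDCT(frame_len=512, padding="same")
+#   imdct = IMDCT(frame_len=512, padding="same")
+#   audio = torch.randn(1, 24000)
+#   coeffs = mdct(audio)    # (1, ~T // 256 frames, 256 bins)
+#   recon = imdct(coeffs)   # (1, ~T) samples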
diff --git a/codec/scripts/compare_codecs.py b/codec/scripts/compare_codecs.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbea2f8b7d4356897318ef2da5fa0cbf8c2b6a70
--- /dev/null
+++ b/codec/scripts/compare_codecs.py
@@ -0,0 +1,441 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import os
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Any, Tuple
+
+import torch
+from torchaudio import load as ta_load
+from torchaudio.functional import resample as ta_resample
+import torchaudio
+
+# Your libs
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+
+
+# -------------------------
+# Data structures
+# -------------------------
+
+
+@dataclass
+class DecodeParams:
+ num_steps: int = 10
+ cfg: float = 2.0
+
+
+@dataclass
+class ModelPairSpec:
+ name: str
+ wavvae_dir: str
+ zflowae_dir: str
+ decode: DecodeParams
+
+
+# -------------------------
+# Utilities
+# -------------------------
+
+
+def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
+ if path.is_file():
+ try:
+ with path.open("r", encoding="utf-8") as f:
+ return json.load(f)
+ except Exception:
+ return None
+ return None
+
+
+def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
+ """
+ Try to read config.json (or a few common fallbacks) from a checkpoint dir.
+ Returns {} if nothing could be parsed.
+ """
+ cand = [
+ Path(checkpoint_dir) / "config.json",
+ Path(checkpoint_dir)
+ / "config.yaml", # won't parse yaml here, we only display path
+ Path(checkpoint_dir) / "model_config.json",
+ ]
+ for p in cand:
+ if p.exists():
+ if p.suffix == ".json":
+ j = load_json_if_exists(p)
+ if j is not None:
+ return j
+ else:
+ # For YAML or unknown, just show filename rather than failing
+ return {"_config_file": str(p)}
+ return {}
+
+
+def sanitize_name(s: str) -> str:
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
+
+
+def ensure_mono_and_resample(
+ wav: torch.Tensor, sr: int, target_sr: int
+) -> Tuple[torch.Tensor, int]:
+ """
+ wav: (channels, samples)
+ returns mono float32 in [-1,1], resampled to target_sr
+ """
+ if wav.ndim != 2:
+ raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
+ # to mono
+ if wav.size(0) > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+ # resample if needed
+ if sr != target_sr:
+ wav = ta_resample(wav, sr, target_sr)
+ sr = target_sr
+ return wav.to(torch.float32), sr
+
+
+def save_wav(path: Path, wav: torch.Tensor, sr: int):
+ path.parent.mkdir(parents=True, exist_ok=True)
+ # (C, T)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+ # Clamp to [-1,1]
+ wav = wav.clamp(-1, 1).contiguous().cpu()
+ torchaudio.save(
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
+ )
+
+
+# -------------------------
+# Core inference
+# -------------------------
+
+
+@torch.inference_mode()
+def reconstruct_full_pipeline(
+ wav_mono: torch.Tensor,
+ sr: int,
+ wavvae: WavVAE,
+ zflowae: ZFlowAutoEncoder,
+ decode_params: DecodeParams,
+ device: str,
+) -> torch.Tensor:
+ """
+ Full path: audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat
+ """
+ wav_mono = wav_mono.to(device)
+ # WavVAE expects (B, C, T); assume C=1
+ x = wav_mono.unsqueeze(0) # (1, 1, T)
+ # Encode to high-framerate latents
+ z = wavvae.encode(x)
+ # Compress latents
+ y = zflowae.encode(z)
+ # Decompress
+ z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg)
+ # Decode to waveform
+ wav_hat = wavvae.decode(z_hat) # (1, 1, T)
+ # Return mono 1D
+ return wav_hat.squeeze(0).squeeze(0).detach()
+
+
+def load_model_pair(spec: ModelPairSpec, device: str):
+ wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device)
+ zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device)
+ # try to get sampling rate from WavVAE
+ target_sr = getattr(wavvae, "sampling_rate", None)
+ if target_sr is None:
+ # reasonable fallback
+ target_sr = 24000
+ return wavvae, zflowae, int(target_sr)
+
+
+def parse_manifest(path: str) -> List[ModelPairSpec]:
+ """
+ Manifest format (JSON list):
+ [
+ {
+ "name": "zdim32x8",
+ "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
+ "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
+ "decode": {"num_steps": 10, "cfg": 2.0}
+ }
+ ]
+ """
+ with open(path, "r", encoding="utf-8") as f:
+ raw = json.load(f)
+ out: List[ModelPairSpec] = []
+ for item in raw:
+ name = item["name"]
+ wavvae_dir = item["wavvae"]
+ zflowae_dir = item["zflowae"]
+ d = item.get("decode", {})
+ out.append(
+ ModelPairSpec(
+ name=name,
+ wavvae_dir=wavvae_dir,
+ zflowae_dir=zflowae_dir,
+ decode=DecodeParams(
+ num_steps=int(d.get("num_steps", 10)),
+ cfg=float(d.get("cfg", 2.0)),
+ ),
+ )
+ )
+ return out
+
+
+# -------------------------
+# HTML generation
+# -------------------------
+
+
+def html_escape(s: str) -> str:
+ return (
+        s.replace("&", "&amp;")
+        .replace("<", "&lt;")
+        .replace(">", "&gt;")
+        .replace('"', "&quot;")
+        .replace("'", "&#39;")
+    )
+
+
+def make_html(
+ output_dir: Path,
+ audio_files: List[Path],
+ models: List[ModelPairSpec],
+ sr_by_model: Dict[str, int],
+ wavvae_cfg: Dict[str, Dict[str, Any]],
+ zflow_cfg: Dict[str, Dict[str, Any]],
+) -> str:
+ """
+ Build a static HTML page with a table:
+ Row = input audio file
+ Col 1 = Original
+ Col 2..N = each model reconstruction
+ Also shows minimal model config info above the table.
+ """
+
+    def player(src_rel: str, controls: bool = True) -> str:
+        ctrl = " controls" if controls else ""
+        return f'<audio{ctrl} preload="none"><source src="{src_rel}" type="audio/wav"></audio>'
+
+ # Model cards
+ model_cards = []
+ for spec in models:
+ wcfg = wavvae_cfg.get(spec.name, {})
+ zcfg = zflow_cfg.get(spec.name, {})
+ w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
+ :1200
+ ]
+ z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
+ :1200
+ ]
+        card = f"""
+        <div class="model-card">
+          <h3>{html_escape(spec.name)}</h3>
+          <p>Sample rate: {sr_by_model.get(spec.name, "N/A")} Hz</p>
+          <details>
+            <summary>WavVAE config</summary>
+            <pre>{html_escape(w_short)}</pre>
+          </details>
+          <details>
+            <summary>ZFlowAE config</summary>
+            <pre>{html_escape(z_short)}</pre>
+          </details>
+          <p>Decode: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</p>
+        </div>
+        """
+ model_cards.append(card)
+
+ # Table header
+    th = "<tr><th>Input</th><th>Original</th>" + "".join(
+        f"<th>{html_escape(m.name)}</th>" for m in models
+    ) + "</tr>"
+
+    # Rows
+    rows = []
+    for af in audio_files:
+        base = af.stem
+        orig_rel = f"original/{html_escape(af.name)}"
+        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
+        for m in models:
+            rec_rel = f"recon/{html_escape(m.name)}/{html_escape(base)}.wav"
+            tds.append(f"<td>{player(rec_rel)}</td>")
+        rows.append("<tr>" + "".join(tds) + "</tr>")
+
+ # Simple CSS to keep it clean
+ css = """
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
+ h1 { margin-bottom: 0.2rem; }
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
+ table { border-collapse: collapse; width: 100%; }
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
+ th { background: #fafafa; position: sticky; top: 0; }
+ audio { width: 260px; }
+ """
+
+    html = f"""<!DOCTYPE html>
+<html>
+<head>
+  <meta charset="utf-8">
+  <title>Codec Comparison</title>
+  <style>{css}</style>
+</head>
+<body>
+  <h1>Codec Comparison</h1>
+  <p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>
+
+  <h2>Models</h2>
+  <div class="cards">
+    {"".join(model_cards)}
+  </div>
+
+  <h2>Audio</h2>
+  <table>
+    <thead>{th}</thead>
+    <tbody>
+      {"".join(rows)}
+    </tbody>
+  </table>
+</body>
+</html>
+"""
+ out = output_dir / "index.html"
+ out.write_text(html, encoding="utf-8")
+ return str(out)
+
+
+# -------------------------
+# Main
+# -------------------------
+
+
+def main():
+ p = argparse.ArgumentParser(
+ description="Compare Z-Codec configurations and generate a static HTML page."
+ )
+ p.add_argument(
+ "--manifest",
+ type=str,
+ required=True,
+ help="JSON file listing model pairs. See docstring in parse_manifest().",
+ )
+ p.add_argument(
+ "--audio", type=str, nargs="+", required=True, help="List of input audio files."
+ )
+ p.add_argument(
+ "--out",
+ type=str,
+ default="codec_compare_out",
+ help="Output directory for reconstructions and HTML.",
+ )
+ p.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to run inference on (cuda or cpu).",
+ )
+ p.add_argument(
+ "--force",
+ action="store_true",
+ help="Recompute even if target wav already exists.",
+ )
+ args = p.parse_args()
+
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
+ out_dir = Path(args.out)
+ orig_dir = out_dir / "original"
+ recon_dir = out_dir / "recon"
+ orig_dir.mkdir(parents=True, exist_ok=True)
+ recon_dir.mkdir(parents=True, exist_ok=True)
+
+ # Parse models
+ specs = parse_manifest(args.manifest)
+ if not specs:
+ print("No models in manifest.", file=sys.stderr)
+ sys.exit(1)
+
+ # Load models
+ loaded: Dict[str, Dict[str, Any]] = {}
+ sr_by_model: Dict[str, int] = {}
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
+ zflow_cfg: Dict[str, Dict[str, Any]] = {}
+
+ for spec in specs:
+ print(f"[Load] {spec.name}")
+ wavvae, zflowae, target_sr = load_model_pair(spec, device)
+ loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr}
+ sr_by_model[spec.name] = target_sr
+ wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir)
+ zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir)
+
+ # Process audio files
+ audio_files = [Path(a) for a in args.audio]
+ for af in audio_files:
+ if not af.exists():
+ print(f"[Skip] Missing: {af}", file=sys.stderr)
+ continue
+
+        # Store the original as a 16-bit mono wav so browsers can play it directly.
+        out_orig = (orig_dir / af.name).with_suffix(".wav")
+        if args.force or not out_orig.exists():
+            wav, sr = ta_load(str(af))
+            # mix to mono for fair listening; keep the native sample rate
+            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
+            save_wav(out_orig, wav_mono, sr)
+
+        # For each model, run the full pipeline and save the reconstruction
+        base = af.stem
+        # Re-load from disk so every model starts from the same normalized original
+        wav0, sr0 = ta_load(str(out_orig))
+ # Make mono only once; resample per-model to each target SR
+ if wav0.size(0) > 1:
+ wav0 = wav0.mean(dim=0, keepdim=True)
+
+ for spec in specs:
+ mname = spec.name
+ target_sr = sr_by_model[mname]
+ # resample to model's SR
+ if sr0 != target_sr:
+ wav_mono = ta_resample(wav0, sr0, target_sr)
+ else:
+ wav_mono = wav0
+
+ # reconstruct
+ out_path = recon_dir / mname / f"{sanitize_name(base)}.wav"
+ if args.force or not out_path.exists():
+ print(f"[Reconstruct] {mname} ← {base}")
+ wavvae = loaded[mname]["wavvae"]
+ zflowae = loaded[mname]["zflowae"]
+ wav_hat = reconstruct_full_pipeline(
+ wav_mono, target_sr, wavvae, zflowae, spec.decode, device
+ )
+ save_wav(out_path, wav_hat.unsqueeze(0), target_sr)
+
+ # Build HTML
+ # Rebuild the list of files actually present in original/ (use .wav names)
+ actual_audio = sorted([p for p in (orig_dir).glob("*.wav")])
+ html_path = make_html(
+ out_dir,
+ actual_audio,
+ specs,
+ sr_by_model,
+ wavvae_cfg,
+ zflow_cfg,
+ )
+ print(f"\nDone. Open: {html_path}")
+
+
+if __name__ == "__main__":
+ main()
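+
+
+# Example invocation (illustrative; paths and file names are placeholders):
+#
+#   python codec/scripts/compare_codecs.py \
+#       --manifest models.json \
+#       --audio sample1.wav sample2.flac \
+#       --out codec_compare_out \
+#       --device cuda
+#
+# where models.json follows the format documented in parse_manifest().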
diff --git a/codec/scripts/compare_wavvae.py b/codec/scripts/compare_wavvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8af4753ca1ed4f3c58575a7c5bbe4a35aa4aaf4
--- /dev/null
+++ b/codec/scripts/compare_wavvae.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchaudio
+from torchaudio import load as ta_load
+from torchaudio.functional import resample as ta_resample
+from zcodec.models import WavVAE
+
+# -------------------------
+# Data structures
+# -------------------------
+
+
+@dataclass
+class WavVaeSpec:
+ name: str
+ wavvae_dir: str
+
+
+# -------------------------
+# Utilities
+# -------------------------
+
+
+def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
+ if path.is_file():
+ try:
+ return json.load(path.open("r", encoding="utf-8"))
+ except Exception:
+ return None
+ return None
+
+
+def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
+ cand = [
+ Path(checkpoint_dir) / "config.json",
+ Path(checkpoint_dir) / "model_config.json",
+ Path(checkpoint_dir) / "config.yaml", # shown as path only
+ ]
+ for p in cand:
+ if p.exists():
+ if p.suffix == ".json":
+ j = load_json_if_exists(p)
+ if j is not None:
+ return j
+ else:
+ return {"_config_file": str(p)}
+ return {}
+
+
+def sanitize_name(s: str) -> str:
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
+
+
+def ensure_mono_and_resample(
+ wav: torch.Tensor, sr: int, target_sr: int
+) -> Tuple[torch.Tensor, int]:
+ if wav.ndim != 2:
+ raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
+ if wav.size(0) > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+ if sr != target_sr:
+ wav = ta_resample(wav, sr, target_sr)
+ sr = target_sr
+ return wav.to(torch.float32), sr
+
+
+def save_wav(path: Path, wav: torch.Tensor, sr: int):
+ path.parent.mkdir(parents=True, exist_ok=True)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+ wav = wav.clamp(-1, 1).contiguous().cpu()
+ torchaudio.save(
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
+ )
+
+
+def read_audio_manifest(txt_path: str) -> List[Path]:
+ lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
+ files = [
+ Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
+ ]
+ return files
+
+
+def html_escape(s: str) -> str:
+ return (
+        s.replace("&", "&amp;")
+        .replace("<", "&lt;")
+        .replace(">", "&gt;")
+        .replace('"', "&quot;")
+        .replace("'", "&#39;")
+    )
+
+
+def make_html(
+ output_dir: Path,
+ audio_files: List[Path],
+ specs: List[WavVaeSpec],
+ sr_by_model: Dict[str, int],
+ wavvae_cfg: Dict[str, Dict[str, Any]],
+) -> str:
+    def player(src_rel: str) -> str:
+        return f'<audio controls preload="none"><source src="{src_rel}" type="audio/wav"></audio>'
+
+ # cards
+ cards = []
+ for s in specs:
+ cfg = wavvae_cfg.get(s.name, {})
+ cfg_short = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)[
+ :1200
+ ]
+        card = f"""
+        <div class="model-card">
+          <h3>{html_escape(s.name)}</h3>
+          <p>Sample rate: {sr_by_model.get(s.name, "N/A")} Hz</p>
+          <details>
+            <summary>WavVAE config</summary>
+            <pre>{html_escape(cfg_short)}</pre>
+          </details>
+        </div>
+        """
+ cards.append(card)
+
+ css = """
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
+ table { border-collapse: collapse; width: 100%; }
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
+ th { background: #fafafa; position: sticky; top: 0; }
+ audio { width: 260px; }
+ """
+
+    th = "<tr><th>Input</th><th>Original</th>" + "".join(
+        f"<th>{html_escape(s.name)}</th>" for s in specs
+    ) + "</tr>"
+    rows = []
+    for af in audio_files:
+        base = af.stem
+        orig_rel = f"original/{html_escape(af.name)}"
+        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
+        for s in specs:
+            rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
+            tds.append(f"<td>{player(rec_rel)}</td>")
+        rows.append("<tr>" + "".join(tds) + "</tr>")
+
+    html = f"""<!DOCTYPE html>
+<html>
+<head>
+  <meta charset="utf-8">
+  <title>WavVAE Comparison</title>
+  <style>{css}</style>
+</head>
+<body>
+  <h1>WavVAE Comparison</h1>
+  <div class="cards">{"".join(cards)}</div>
+  <table>
+    <thead>{th}</thead>
+    <tbody>{"".join(rows)}</tbody>
+  </table>
+</body>
+</html>
+"""
+ out = output_dir / "index.html"
+ out.write_text(html, encoding="utf-8")
+ return str(out)
+
+
+# -------------------------
+# Core
+# -------------------------
+
+
+@torch.inference_mode()
+def reconstruct_wavvae(
+ wav_mono: torch.Tensor, wavvae: WavVAE, device: str
+) -> torch.Tensor:
+ x = wav_mono.to(device) # (1,T)
+ z = wavvae.encode(x)
+ wav_hat = wavvae.decode(z) # (1,1,T)
+ return wav_hat.squeeze(0).squeeze(0).detach()
+
+
+def parse_models_manifest(path: str) -> List[WavVaeSpec]:
+ """
+ JSON list of:
+ {"name": "...", "wavvae": "/path/to/WavVAE_dir"}
+ """
+ raw = json.loads(Path(path).read_text(encoding="utf-8"))
+ specs = []
+ for it in raw:
+ specs.append(WavVaeSpec(name=it["name"], wavvae_dir=it["wavvae"]))
+ return specs
+
+
+def main():
+ ap = argparse.ArgumentParser(
+ description="Compare WavVAE checkpoints and generate a static HTML page."
+ )
+ ap.add_argument("--models", required=True, help="JSON manifest of WavVAE models.")
+ ap.add_argument(
+ "--audio_manifest", required=True, help="TXT file: one audio path per line."
+ )
+ ap.add_argument("--out", default="compare_wavvae_out")
+ ap.add_argument("--device", default="cuda")
+ ap.add_argument("--force", action="store_true")
+ args = ap.parse_args()
+
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
+ out_dir = Path(args.out)
+ (out_dir / "original").mkdir(parents=True, exist_ok=True)
+ recon_dir = out_dir / "recon"
+ recon_dir.mkdir(parents=True, exist_ok=True)
+
+ specs = parse_models_manifest(args.models)
+ if not specs:
+ print("No models.", file=sys.stderr)
+ sys.exit(1)
+
+ # load models
+ wavvae_by_name: Dict[str, WavVAE] = {}
+ sr_by_model: Dict[str, int] = {}
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
+ for s in specs:
+ print(f"[Load] {s.name}")
+ w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
+ wavvae_by_name[s.name] = w
+ sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
+ wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
+
+ audio_paths = read_audio_manifest(args.audio_manifest)
+ # normalize originals to wav+mono (browser-friendly); keep native sr for original column
+ actual_audio = []
+    for audio_path in audio_paths:
+        if not audio_path.exists():
+            print(f"[Skip missing] {audio_path}", file=sys.stderr)
+            continue
+        wav, sr = ta_load(str(audio_path))
+        wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
+        out_orig = out_dir / "original" / (audio_path.stem + ".wav")
+ if args.force or not out_orig.exists():
+ save_wav(out_orig, wav_mono, sr)
+ actual_audio.append(out_orig)
+
+ # recon per model
+ for out_orig in actual_audio:
+ wav0, sr0 = ta_load(str(out_orig))
+ if wav0.size(0) > 1:
+ wav0 = wav0.mean(dim=0, keepdim=True)
+ for s in specs:
+ target_sr = sr_by_model[s.name]
+ wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
+ out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
+ if args.force or not out_path.exists():
+ print(f"[Reconstruct] {s.name} ← {out_orig.name}")
+ wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device)
+ save_wav(out_path, wav_hat, target_sr)
+
+ html_path = make_html(out_dir, actual_audio, specs, sr_by_model, wavvae_cfg)
+ print(f"Done. Open: {html_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/compare_zcodec.py b/codec/scripts/compare_zcodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..26042fde831a2ab969d00b870c05a51fec26746c
--- /dev/null
+++ b/codec/scripts/compare_zcodec.py
@@ -0,0 +1,312 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchaudio
+from torchaudio import load as ta_load
+from torchaudio.functional import resample as ta_resample
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+
+# -------------------------
+# Data structures
+# -------------------------
+
+
+@dataclass
+class DecodeParams:
+ num_steps: int = 10
+ cfg: float = 2.0
+
+
+@dataclass
+class StackSpec:
+ name: str
+ wavvae_dir: str
+ zflowae_dir: str
+ decode: DecodeParams
+
+
+# -------------------------
+# Utilities (same helpers)
+# -------------------------
+
+
+def load_json_if_exists(path: Path):
+ if path.is_file():
+ try:
+ return json.load(path.open("r", encoding="utf-8"))
+ except Exception:
+ return None
+ return None
+
+
+def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
+ cand = [
+ Path(checkpoint_dir) / "config.json",
+ Path(checkpoint_dir) / "model_config.json",
+ Path(checkpoint_dir) / "config.yaml",
+ ]
+ for p in cand:
+ if p.exists():
+ if p.suffix == ".json":
+ j = load_json_if_exists(p)
+ if j is not None:
+ return j
+ else:
+ return {"_config_file": str(p)}
+ return {}
+
+
+def sanitize_name(s: str) -> str:
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
+
+
+def ensure_mono_and_resample(
+ wav: torch.Tensor, sr: int, target_sr: int
+) -> Tuple[torch.Tensor, int]:
+ if wav.ndim != 2:
+ raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
+ if wav.size(0) > 1:
+ wav = wav.mean(dim=0, keepdim=True)
+ if sr != target_sr:
+ wav = ta_resample(wav, sr, target_sr)
+ sr = target_sr
+ return wav.to(torch.float32), sr
+
+
+def save_wav(path: Path, wav: torch.Tensor, sr: int):
+ path.parent.mkdir(parents=True, exist_ok=True)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+ wav = wav.clamp(-1, 1).contiguous().cpu()
+ torchaudio.save(
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
+ )
+
+
+def read_audio_manifest(txt_path: str) -> List[Path]:
+ lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
+ return [
+ Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
+ ]
+
+
+def html_escape(s: str) -> str:
+ return (
+ s.replace("&", "&")
+ .replace("<", "<")
+ .replace(">", ">")
+ .replace('"', """)
+ .replace("'", "'")
+ )
+
+
+def make_html(
+ output_dir: Path,
+ audio_files: List[Path],
+ specs: List[StackSpec],
+ sr_by_model: Dict[str, int],
+ wavvae_cfg: Dict[str, Dict[str, Any]],
+ zflow_cfg: Dict[str, Dict[str, Any]],
+) -> str:
+ def player(src_rel: str) -> str:
+        return f'<audio controls preload="none" src="{src_rel}"></audio>'
+
+ cards = []
+ for s in specs:
+ wcfg = wavvae_cfg.get(s.name, {})
+ zcfg = zflow_cfg.get(s.name, {})
+ w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
+ :1200
+ ]
+ z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
+ :1200
+ ]
+ card = f"""
+
+
{html_escape(s.name)}
+
Sample rate: {sr_by_model.get(s.name, "N/A")} Hz
+
Decode: steps={s.decode.num_steps}, cfg={s.decode.cfg}
+
WavVAE config
{html_escape(w_short)}
+
ZFlowAE config
{html_escape(z_short)}
+
+ """
+ cards.append(card)
+
+ css = """
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
+ table { border-collapse: collapse; width: 100%; }
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
+ th { background: #fafafa; position: sticky; top: 0; }
+ audio { width: 260px; }
+ """
+
+ th = "Input | Original | " + "".join(
+ f"{html_escape(s.name)} | " for s in specs
+ )
+ rows = []
+ for af in audio_files:
+ base = af.stem
+ orig_rel = f"original/{html_escape(af.name)}"
+ tds = [f"{html_escape(base)} | ", f"{player(orig_rel)} | "]
+ for s in specs:
+ rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
+ tds.append(f"{player(rec_rel)} | ")
+ rows.append("" + "".join(tds) + "
")
+
+ html = f"""
+
+ Stacked Codec Comparison
+
+ WavVAE + ZFlowAE Comparison
+ {"".join(cards)}
+
+ {th}
+ {"".join(rows)}
+
+
+
+"""
+ out = output_dir / "index.html"
+ out.write_text(html, encoding="utf-8")
+ return str(out)
+
+
+# -------------------------
+# Core
+# -------------------------
+
+
+@torch.inference_mode()
+def reconstruct_stack(
+ wav_mono: torch.Tensor,
+ wavvae: WavVAE,
+ zflow: ZFlowAutoEncoder,
+ steps: int,
+ cfg: float,
+ device: str,
+) -> torch.Tensor:
+ x = wav_mono.to(device) # (1,T)
+ z = wavvae.encode(x) # high-framerate latents
+ y, _ = zflow.encode(z) # compressed latents
+ z_hat = zflow.decode(y, num_steps=steps, cfg=cfg)
+ wav_hat = wavvae.decode(z_hat) # (1,1,T)
+ return wav_hat.squeeze(0).squeeze(0).detach()
+
+
+def parse_models_manifest(path: str) -> List[StackSpec]:
+ """
+ JSON list of:
+ {
+ "name": "...",
+ "wavvae": "/path/to/WavVAE_dir",
+ "zflowae": "/path/to/ZFlowAE_dir",
+ "decode": {"num_steps": 10, "cfg": 2.0}
+ }
+ """
+ raw = json.loads(Path(path).read_text(encoding="utf-8"))
+ specs = []
+ for it in raw:
+ d = it.get("decode", {})
+ specs.append(
+ StackSpec(
+ name=it["name"],
+ wavvae_dir=it["wavvae"],
+ zflowae_dir=it["zflowae"],
+ decode=DecodeParams(
+ num_steps=int(d.get("num_steps", 10)), cfg=float(d.get("cfg", 2.0))
+ ),
+ )
+ )
+ return specs
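+
+# Illustrative manifest matching the docstring above (names and paths are placeholders):
+# [
+#   {
+#     "name": "stack_a",
+#     "wavvae": "checkpoints/wavvae/pretrained/WavVAE_example",
+#     "zflowae": "checkpoints/zflowae/pretrained/ZFlowAutoEncoder_example",
+#     "decode": {"num_steps": 10, "cfg": 2.0}
+#   }
+# ]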
+
+
+def main():
+ ap = argparse.ArgumentParser(
+ description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page."
+ )
+ ap.add_argument("--models", required=True, help="JSON manifest of stacks.")
+ ap.add_argument(
+ "--audio_manifest", required=True, help="TXT file: one audio path per line."
+ )
+ ap.add_argument("--out", default="compare_stack_out")
+ ap.add_argument("--device", default="cuda")
+ ap.add_argument("--force", action="store_true")
+ args = ap.parse_args()
+
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
+ out_dir = Path(args.out)
+ (out_dir / "original").mkdir(parents=True, exist_ok=True)
+ recon_dir = out_dir / "recon"
+ recon_dir.mkdir(parents=True, exist_ok=True)
+
+ specs = parse_models_manifest(args.models)
+ if not specs:
+ print("No models.", file=sys.stderr)
+ sys.exit(1)
+
+ # load models
+ wavvae_by_name: Dict[str, WavVAE] = {}
+ zflow_by_name: Dict[str, ZFlowAutoEncoder] = {}
+ sr_by_model: Dict[str, int] = {}
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
+ zflow_cfg: Dict[str, Dict[str, Any]] = {}
+ for s in specs:
+ print(f"[Load] {s.name}")
+ w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
+ z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device)
+ wavvae_by_name[s.name] = w
+ zflow_by_name[s.name] = z
+ sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
+ wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
+ zflow_cfg[s.name] = read_config_any(s.zflowae_dir)
+
+ audio_paths = read_audio_manifest(args.audio_manifest)
+
+ actual_audio = []
+ for ap in audio_paths:
+ if not ap.exists():
+ print(f"[Skip missing] {ap}", file=sys.stderr)
+ continue
+ wav, sr = ta_load(str(ap))
+ wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
+ out_orig = out_dir / "original" / (ap.stem + ".wav")
+ if args.force or not out_orig.exists():
+ save_wav(out_orig, wav_mono, sr)
+ actual_audio.append(out_orig)
+
+ for out_orig in actual_audio:
+ wav0, sr0 = ta_load(str(out_orig))
+ if wav0.size(0) > 1:
+ wav0 = wav0.mean(dim=0, keepdim=True)
+ for s in specs:
+ target_sr = sr_by_model[s.name]
+ wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
+ out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
+ if args.force or not out_path.exists():
+ print(f"[Reconstruct] {s.name} ← {out_orig.name}")
+ wav_hat = reconstruct_stack(
+ wav_in,
+ wavvae_by_name[s.name],
+ zflow_by_name[s.name],
+ s.decode.num_steps,
+ s.decode.cfg,
+ device,
+ )
+ save_wav(out_path, wav_hat, target_sr)
+
+ html_path = make_html(
+ out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg
+ )
+ print(f"Done. Open: {html_path}")
+
+
+if __name__ == "__main__":
+ main()
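+
+# Usage sketch (paths are placeholders); writes original/, recon/<model_name>/ and index.html under --out:
+#   python codec/scripts/compare_zcodec.py --models stacks.json --audio_manifest wavs.txt --out compare_stack_out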
diff --git a/codec/scripts/compute_stats.py b/codec/scripts/compute_stats.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de69b5ced2435ac0a3d596de617f6ea8f145b82
--- /dev/null
+++ b/codec/scripts/compute_stats.py
@@ -0,0 +1,76 @@
+import argparse
+import random
+
+import torch
+from safetensors.torch import safe_open, save_file
+from tqdm import tqdm
+
+
+def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
+ with safe_open(path, framework="pt", device="cpu") as f:
+ return f.get_tensor(key)
+
+
+def compute_global_stats(file_list, key="embedding", length_weighted=True):
+ sum_all = None
+ sum_sq_all = None
+ count_all = 0
+
+ for path in tqdm(file_list, desc="Computing stats"):
+ tensor = load_tensor(path, key) # shape: [B, T, D]
+ flat = tensor.reshape(-1, tensor.shape[-1]) # [B*T, D]
+
+ sum_ = flat.sum(dim=0) # [D]
+ sum_sq = (flat**2).sum(dim=0) # [D]
+ count = flat.shape[0] # B*T
+
+ if sum_all is None:
+ sum_all = sum_
+ sum_sq_all = sum_sq
+ else:
+ sum_all += sum_
+ sum_sq_all += sum_sq
+
+ count_all += count
+
+ mean = sum_all / count_all
+ var = sum_sq_all / count_all - mean**2
+ std = torch.sqrt(torch.clamp(var, min=1e-8))
+
+ return mean, std
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "filelist", type=str, help="Text file with list of safetensors paths"
+ )
+ parser.add_argument("output", type=str, help="Path to output stats.safetensors")
+ parser.add_argument(
+ "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
+ )
+ parser.add_argument(
+ "--max-files", type=int, default=None, help="Max number of files to process"
+ )
+ parser.add_argument(
+ "--seed", type=int, default=42, help="Random seed for shuffling"
+ )
+
+ args = parser.parse_args()
+
+ with open(args.filelist) as f:
+ files = [line.strip() for line in f if line.strip()]
+
+ if args.max_files:
+ random.seed(args.seed)
+ files = random.sample(files, k=min(args.max_files, len(files)))
+
+ mean, std = compute_global_stats(files, key=args.key)
+
+ save_file({"mean": mean, "std": std}, args.output)
+ print(f"✅ Saved to {args.output}")
+ print("Example mean/std:", mean[:5], std[:5])
+
+
+if __name__ == "__main__":
+ main()
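+
+# Usage sketch (file names are placeholders); the output holds per-dimension "mean" and "std" tensors:
+#   python codec/scripts/compute_stats.py latent_filelist.txt stats.safetensors --key audio_z --max-files 5000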
diff --git a/codec/scripts/compute_wer.py b/codec/scripts/compute_wer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1f8398d08892c17e461450505979649393bdd8
--- /dev/null
+++ b/codec/scripts/compute_wer.py
@@ -0,0 +1,48 @@
+import argparse
+import json
+import string
+
+from jiwer import wer
+
+
+def normalize_text(text: str) -> str:
+ """
+ Lowercase and remove punctuation from a string.
+
+ Args:
+ text (str): Input string
+
+ Returns:
+ str: Normalized string
+ """
+ # Lowercase
+ text = text.lower()
+ # Remove punctuation
+ text = text.translate(str.maketrans("", "", string.punctuation))
+ return text
+
+
+def load_transcripts(jsonl_path):
+ originals = []
+ reconstructions = []
+ with open(jsonl_path, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ originals.append(data["original_text"])
+ reconstructions.append(data["reconstructed_text"])
+ return originals, reconstructions
+
+
+def main(args):
+    originals, reconstructions = load_transcripts(args.jsonl)
+    # normalize each utterance, not the whole list
+    originals = [normalize_text(t) for t in originals]
+    reconstructions = [normalize_text(t) for t in reconstructions]
+ score = wer(originals, reconstructions)
+ print(f"WER: {score:.3%}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--jsonl", type=str, required=True, help="Path to the transcript JSONL file"
+ )
+ args = parser.parse_args()
+ main(args)
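+
+# Expected JSONL input, one object per line (illustrative values):
+#   {"original_text": "hello world", "reconstructed_text": "hello word"}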
diff --git a/codec/scripts/compute_wer_from_refs.py b/codec/scripts/compute_wer_from_refs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4304a0134a1ce81ab32a7376332f7c260c3efb
--- /dev/null
+++ b/codec/scripts/compute_wer_from_refs.py
@@ -0,0 +1,64 @@
+import argparse
+import json
+import string
+from pathlib import Path
+
+from jiwer import cer, wer
+
+
+def normalize_text(text: str) -> str:
+ """
+ Lowercase and remove punctuation from a string.
+
+ Args:
+ text (str): Input string
+
+ Returns:
+ str: Normalized string
+ """
+ # Lowercase
+ text = text.lower()
+ # Remove punctuation
+ text = text.translate(str.maketrans("", "", string.punctuation))
+ return text
+
+
+def load_jsonl_dict(path):
+ transcripts = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f:
+ data = json.loads(line)
+ transcripts[Path(data["file"]).name] = data["transcript"]
+ return transcripts
+
+
+def main(args):
+ ref_dict = load_jsonl_dict(args.reference)
+ hyp_dict = load_jsonl_dict(args.hypothesis)
+
+ common_files = set(ref_dict.keys()) & set(hyp_dict.keys())
+
+ if not common_files:
+ print("No common files between reference and hypothesis.")
+ return
+
+ refs = [normalize_text(ref_dict[f]) for f in sorted(common_files)]
+ hyps = [normalize_text(hyp_dict[f]) for f in sorted(common_files)]
+
+ cer_score = cer(refs, hyps)
+ wer_score = wer(refs, hyps)
+ print(f"CER: {cer_score:.3%}")
+ print(f"WER: {wer_score:.3%}")
+ print(f"Evaluated on {len(common_files)} files.")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--reference", type=str, required=True, help="Path to reference JSONL"
+ )
+ parser.add_argument(
+ "--hypothesis", type=str, required=True, help="Path to hypothesis JSONL"
+ )
+ args = parser.parse_args()
+ main(args)
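+
+# Both --reference and --hypothesis are JSONL files matched by file name (illustrative line):
+#   {"file": "/data/audio/0001.wav", "transcript": "hello world"}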
diff --git a/codec/scripts/download_expresso.py b/codec/scripts/download_expresso.py
new file mode 100644
index 0000000000000000000000000000000000000000..08c3683e7732f3b62629ffaced648966ae28a28d
--- /dev/null
+++ b/codec/scripts/download_expresso.py
@@ -0,0 +1,10 @@
+import soundfile as sf
+
+from datasets import load_dataset
+
+dataset = load_dataset("ylacombe/expresso", split="train")
+print(dataset)
+for i, x in enumerate(dataset):
+ audio = x["audio"]
+ wav, sr = audio["array"], audio["sampling_rate"]
+ sf.write(f"expresso/org/{i}.wav", wav, sr)
diff --git a/codec/scripts/download_gigaspeech.py b/codec/scripts/download_gigaspeech.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec3e7dbef9107f03e401687a5ad3cfb532146f57
--- /dev/null
+++ b/codec/scripts/download_gigaspeech.py
@@ -0,0 +1,14 @@
+from random import sample
+
+import soundfile as sf
+from datasets import load_dataset
+
+# dataset = load_dataset("keithito/lj_speech", split="train")
+#dataset = load_dataset("parler-tts/mls_eng", split="train")
+dataset = load_dataset("speechcolab/gigaspeech", "xl", split="train", token=True)
+Is = sample(list(range(len(dataset))), k=100000)
+print(dataset)
+for i, I in enumerate(Is):
+ audio = dataset[I]["audio"]
+ wav, sr = audio["array"], audio["sampling_rate"]
+ sf.write(f"gigaspeech/{I}.wav", wav, sr)
diff --git a/codec/scripts/download_lj.py b/codec/scripts/download_lj.py
new file mode 100644
index 0000000000000000000000000000000000000000..b702a9764e78297294b60f610cf3e2606569bc64
--- /dev/null
+++ b/codec/scripts/download_lj.py
@@ -0,0 +1,9 @@
+import soundfile as sf
+from datasets import load_dataset
+
+dataset = load_dataset("keithito/lj_speech", split="train")
+print(dataset)
+for i, x in enumerate(dataset):
+ audio = x["audio"]
+ wav, sr = audio["array"], audio["sampling_rate"]
+ sf.write(f"ljspeech/{i}.wav", wav, sr)
diff --git a/codec/scripts/download_ltts.py b/codec/scripts/download_ltts.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad7268a5eceb4554e095607f6831346d3fba3fd
--- /dev/null
+++ b/codec/scripts/download_ltts.py
@@ -0,0 +1,16 @@
+from pathlib import Path
+
+import soundfile as sf
+
+from datasets import load_dataset
+
+dataset = load_dataset("mythicinfinity/libritts", "clean")
+for split in dataset.keys():
+ Path(f"libritts/{split}").mkdir(exist_ok=True)
+ for i, x in enumerate(dataset[split]):
+ # audio = x["audio"]
+ text = x["text_normalized"]
+ # wav, sr = audio["array"], audio["sampling_rate"]
+ # sf.write(f"libritts/{split}/{i}.wav", wav, sr)
+ with open(f"libritts/{split}/{i}.txt", "w") as f:
+ f.write(text)
diff --git a/codec/scripts/download_mlseng10k.py b/codec/scripts/download_mlseng10k.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b25809a14587330fef9b6cba7d833fab77e3e52
--- /dev/null
+++ b/codec/scripts/download_mlseng10k.py
@@ -0,0 +1,13 @@
+from random import sample
+
+import soundfile as sf
+from datasets import load_dataset
+
+# dataset = load_dataset("keithito/lj_speech", split="train")
+dataset = load_dataset("parler-tts/mls_eng", split="train")
+Is = sample(list(range(len(dataset))), k=100000)
+print(dataset)
+for i, I in enumerate(Is):
+ audio = dataset[I]["audio"]
+ wav, sr = audio["array"], audio["sampling_rate"]
+ sf.write(f"mls10keng/{i}.wav", wav, sr)
diff --git a/codec/scripts/eval_asr.py b/codec/scripts/eval_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dc131579e9580e0cf93c45949cf86f8cfd339c8
--- /dev/null
+++ b/codec/scripts/eval_asr.py
@@ -0,0 +1,100 @@
+import argparse
+import json
+from pathlib import Path
+
+import nemo.collections.asr as nemo_asr
+import torch
+import yaml
+from jiwer import wer
+from torchaudio import load
+from torchaudio.functional import resample
+from tqdm import tqdm
+
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+
+
+def load_config(config_path):
+ with open(config_path, "r") as f:
+ return yaml.safe_load(f)
+
+
+def transcribe(audio: torch.Tensor, asr_model) -> str:
+ audio = audio.cpu().numpy(force=True)
+ with torch.inference_mode():
+ return asr_model.transcribe([audio[0]])[0].text
+
+
+def main(args):
+ config = load_config(args.config)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load models
+ wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
+ zflowae = (
+ ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
+ )
+
+ # Load ASR model
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(
+ model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
+ )
+
+ # Read file list
+ with open(config["file_list"], "r") as f:
+ wav_files = [line.strip() for line in f if line.strip()]
+
+ results = []
+
+ for wav_path in tqdm(wav_files, desc="Processing files"):
+ wav, sr = load(wav_path)
+ wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)
+
+ with torch.inference_mode():
+ # Transcribe original
+ original_text = transcribe(wav, asr_model)
+
+ # Compress and decompress
+ z = wavvae.encode(wav)
+ zz, _ = zflowae.encode(z)
+ z_hat = zflowae.decode(
+ zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0)
+ )
+ wav_hat = wavvae.decode(z_hat)
+
+ # Transcribe reconstructed
+ reconstructed_text = transcribe(wav_hat, asr_model)
+
+ results.append(
+ {
+ "file": wav_path,
+ "original_text": original_text,
+ "reconstructed_text": reconstructed_text,
+ }
+ )
+
+ # Save output
+ out_path = Path(config.get("output_jsonl", "transcripts.jsonl"))
+ with out_path.open("w") as f:
+ for entry in results:
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+ print(f"\nSaved {len(results)} transcript pairs to {out_path}")
+
+ # Optionally compute WER
+ if args.compute_wer:
+ original_texts = [r["original_text"] for r in results]
+ reconstructed_texts = [r["reconstructed_text"] for r in results]
+ score = wer(original_texts, reconstructed_texts)
+ print(f"WER: {score:.3%}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
+ parser.add_argument(
+ "--compute_wer", action="store_true", help="Compute WER after decoding"
+ )
+ args = parser.parse_args()
+
+ main(args)
diff --git a/codec/scripts/eval_asr_from_filelist.py b/codec/scripts/eval_asr_from_filelist.py
new file mode 100644
index 0000000000000000000000000000000000000000..8764aa43ba1878c63ab9e370181556c66fed27d3
--- /dev/null
+++ b/codec/scripts/eval_asr_from_filelist.py
@@ -0,0 +1,60 @@
+import argparse
+import json
+from pathlib import Path
+
+import nemo.collections.asr as nemo_asr
+import torch
+import yaml
+from torchaudio import load
+from torchaudio.functional import resample
+from tqdm import tqdm
+
+
+def load_config(config_path):
+ with open(config_path, "r") as f:
+ return yaml.safe_load(f)
+
+
+def transcribe(audio: torch.Tensor, asr_model) -> str:
+ audio = audio.cpu().numpy(force=True)
+ with torch.inference_mode():
+ return asr_model.transcribe([audio[0]])[0].text
+
+
+def main(args):
+ config = load_config(args.config)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load ASR model
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(
+ model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
+ )
+
+ # Read file list
+ with open(config["file_list"], "r") as f:
+ wav_files = [line.strip() for line in f if line.strip()]
+
+ results = []
+
+ for wav_path in tqdm(wav_files, desc="Transcribing"):
+ wav, sr = load(wav_path)
+ wav = resample(wav, orig_freq=sr, new_freq=16000).to(device)
+
+ transcript = transcribe(wav, asr_model)
+ results.append({"file": wav_path, "transcript": transcript})
+
+ # Save output
+ out_path = Path(config.get("output_jsonl", "asr_transcripts.jsonl"))
+ with out_path.open("w") as f:
+ for entry in results:
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+ print(f"\nSaved {len(results)} transcripts to {out_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
+ args = parser.parse_args()
+ main(args)
diff --git a/codec/scripts/eval_asr_only.py b/codec/scripts/eval_asr_only.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a3d9e15b37c0ac0604ab5d461d0f1b560c07d75
--- /dev/null
+++ b/codec/scripts/eval_asr_only.py
@@ -0,0 +1,84 @@
+import argparse
+import json
+from pathlib import Path
+
+import nemo.collections.asr as nemo_asr
+import torch
+import yaml
+from torchaudio import load
+from torchaudio.functional import resample
+from tqdm import tqdm
+
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+
+
+def load_config(config_path):
+ with open(config_path, "r") as f:
+ return yaml.safe_load(f)
+
+
+def transcribe(audio: torch.Tensor, asr_model) -> str:
+ audio = audio.cpu().numpy(force=True)
+ with torch.inference_mode():
+ return asr_model.transcribe([audio[0]])[0].text
+
+
+def main(args):
+ config = load_config(args.config)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load models
+ wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
+ zflowae = (
+ ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
+ )
+
+ # Load ASR model
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(
+ model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
+ )
+
+ # Read file list
+ with open(config["file_list"], "r") as f:
+ wav_files = [line.strip() for line in f if line.strip()]
+
+ results = []
+
+ for wav_path in tqdm(wav_files, desc="ASR on reconstructed audio"):
+ wav, sr = load(wav_path)
+ wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)
+
+ with torch.inference_mode():
+ # Compress and decompress
+ z = wavvae.encode(wav)
+ zz, _ = zflowae.encode(z)
+ z_hat = zflowae.decode(
+ zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0)
+ )
+ wav_hat = wavvae.decode(z_hat)
+
+ # Transcribe
+ wav_hat = resample(wav_hat, orig_freq=wavvae.sampling_rate, new_freq=16000)
+ reconstructed_text = transcribe(wav_hat, asr_model)
+
+ results.append(
+ {
+ "file": wav_path,
+ "transcript": reconstructed_text,
+ }
+ )
+
+ # Save output
+ out_path = Path(config.get("output_jsonl", "asr_reconstructed.jsonl"))
+ with out_path.open("w") as f:
+ for entry in results:
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+
+ print(f"\nSaved {len(results)} reconstructed ASR results to {out_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
+ args = parser.parse_args()
+ main(args)
diff --git a/codec/scripts/export_wavvae.py b/codec/scripts/export_wavvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..62454ebea83500efee45fa89f69a4ac511935163
--- /dev/null
+++ b/codec/scripts/export_wavvae.py
@@ -0,0 +1,53 @@
+import argparse
+import hashlib
+from pathlib import Path
+
+from zcodec.trainers import TrainWavVAE
+
+
+def hash_checkpoint_file(path, method="sha256", length=8):
+ h = hashlib.new(method)
+ with open(path, "rb") as f:
+ while chunk := f.read(8192):
+ h.update(chunk)
+ return h.hexdigest()[:length]
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Export WavVAE pretrained checkpoint.")
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
+ )
+ parser.add_argument(
+ "--output",
+ type=Path,
+ default=None,
+ help="Optional output directory (default is based on config)",
+ )
+ args = parser.parse_args()
+
+ # Load Lightning module
+ wavvae = TrainWavVAE.load_from_checkpoint(args.checkpoint)
+ config = wavvae.hparams.config
+
+ # Compute framerate and z_dim
+ frame_rate = wavvae.wavvae.frame_rate
+ z_dim = config.latent_dim
+
+ checkpoint_hash = hash_checkpoint_file(args.checkpoint)
+ # Determine output directory
+ if args.output is None:
+ out_dir = Path(
+ f"checkpoints/wavvae/pretrained/WavVAE_framerate{frame_rate}_zdim{z_dim}_{checkpoint_hash}"
+ )
+ else:
+ out_dir = args.output
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save weights and config
+ wavvae.save_model_weights_and_config(str(out_dir))
+ print(f"✅ Exported model to {out_dir}")
+
+
+if __name__ == "__main__":
+ main()
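+
+# Usage sketch (checkpoint path is a placeholder); without --output the export lands under
+# checkpoints/wavvae/pretrained/WavVAE_framerate<FR>_zdim<D>_<hash>:
+#   python codec/scripts/export_wavvae.py lightning_logs/version_0/checkpoints/last.ckpt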
diff --git a/codec/scripts/export_zflowae.py b/codec/scripts/export_zflowae.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f858e265e2e6164170c3651d932cb8d1cc17e0
--- /dev/null
+++ b/codec/scripts/export_zflowae.py
@@ -0,0 +1,63 @@
+import argparse
+import hashlib
+from pathlib import Path
+
+from zcodec.trainers import TrainZFlowAutoEncoder
+
+
+def hash_checkpoint_file(path, method="sha256", length=8):
+ h = hashlib.new(method)
+ with open(path, "rb") as f:
+ while chunk := f.read(8192):
+ h.update(chunk)
+ return h.hexdigest()[:length]
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Export ZFlowAutoEncoder pretrained checkpoint."
+ )
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to the Lightning checkpoint (.ckpt)"
+ )
+ parser.add_argument(
+ "--output",
+ type=Path,
+ default=None,
+ help="Optional output directory (default is based on config)",
+ )
+ args = parser.parse_args()
+
+ zflowae = TrainZFlowAutoEncoder.load_from_checkpoint(args.checkpoint)
+ checkpoint_hash = hash_checkpoint_file(args.checkpoint)
+ config = zflowae.hparams.config
+ repa = False
+ if hasattr(zflowae.hparams, "ssl_repa_head"):
+ if zflowae.hparams.ssl_repa_head:
+ repa = True
+ stride = config.latent_stride
+ zdim = config.latent_dim
+ ae_factory = config.autoencoder_factory
+    if config.fsq_levels is not None:
+        codebook_size = 1
+        for c in config.fsq_levels:
+            codebook_size *= c
+        bottleneck_type = f"fsq{codebook_size}"
+    elif config.vae:
+        bottleneck_type = f"vae{config.bottleneck_size}"
+    else:
+        # fall back to a generic tag so the directory name is always defined
+        bottleneck_type = "ae"
+
+    if args.output is None:
+        out_dir = Path(
+            f"checkpoints/zflowae/pretrained/ZFlowAutoEncoder_stride{stride}_zdim{zdim}_{bottleneck_type}_{ae_factory}_repa{repa}_{checkpoint_hash}"
+ )
+ else:
+ out_dir = args.output
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ # Save weights and config
+ zflowae.save_model_weights_and_config(str(out_dir))
+ print(f"✅ Exported model to {out_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/infer_hubert.py b/codec/scripts/infer_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..35c89b3b863fb0c0d24e14d2455c3cb74826136b
--- /dev/null
+++ b/codec/scripts/infer_hubert.py
@@ -0,0 +1,101 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import HubertModel
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr=16000):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).float().unsqueeze(0) # shape: [1, T]
+ return wav, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir, keep_layers):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ outputs = model(wav, output_hidden_states=True)
+ hidden_states = outputs.hidden_states # tuple of 13 tensors: [1, T', D]
+ selected = {
+ f"layer_{i}": hs.squeeze(0).cpu()
+ for i, hs in enumerate(hidden_states)
+ if i in keep_layers
+ }
+ out_path = out_dir / (path.stem + ".st")
+ save_file(selected, str(out_path))
+
+
+def parse_layers(layer_str, max_layers):
+ if layer_str.strip().lower() == "all":
+ return set(range(max_layers))
+ return set(int(idx) for idx in layer_str.split(",") if idx.strip().isdigit())
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Infer HuBERT hidden states.")
+ parser.add_argument(
+ "file_list", type=Path, help="Text file with paths to audio files"
+ )
+ parser.add_argument("output_dir", type=Path, help="Directory to save .st files")
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument("--num_workers", type=int, default=2)
+ parser.add_argument(
+ "--layers",
+ type=str,
+ default="all",
+ help="Comma-separated layer indices or 'all'",
+ )
+
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ model = HubertModel.from_pretrained("facebook/hubert-base-ls960").to(device).eval()
+    # hidden_states has num_hidden_layers + 1 entries (index 0 is the embedding output)
+    num_layers = getattr(model.config, "num_hidden_layers", 12) + 1
+ keep_layers = parse_layers(args.layers, num_layers)
+
+ with open(args.file_list, "r") as f:
+ paths = [Path(line.strip()) for line in f if line.strip()]
+
+ dataset = AudioDataset(paths)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=args.num_workers,
+ collate_fn=lambda x: list(zip(*x)),
+ )
+
+ args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ for batch in tqdm(dataloader):
+ try:
+ encode_batch(model, batch, device, args.output_dir, keep_layers)
+ except Exception as e:
+ print(f"❌ Failed on batch: {e}")
+
+
+if __name__ == "__main__":
+ main()
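+
+# Usage sketch (paths are placeholders); saves one <stem>.st per audio file with keys "layer_<i>":
+#   python codec/scripts/infer_hubert.py filelist.txt hubert_feats/ --layers 6,9,10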
diff --git a/codec/scripts/infer_wav2vec.py b/codec/scripts/infer_wav2vec.py
new file mode 100644
index 0000000000000000000000000000000000000000..e973efe546c16958b9d3c69360f613ff66b57f6a
--- /dev/null
+++ b/codec/scripts/infer_wav2vec.py
@@ -0,0 +1,100 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import Wav2Vec2Model
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr=16000):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).float().unsqueeze(0) # shape: [1, T]
+ return wav, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir, keep_layers):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ outputs = model(wav, output_hidden_states=True)
+ hidden_states = outputs.hidden_states # tuple of 25 tensors: [1, T', D]
+ selected = {
+ f"layer_{i}": hs.squeeze(0).cpu()
+ for i, hs in enumerate(hidden_states)
+ if i in keep_layers
+ }
+ out_path = out_dir / (path.stem + ".st")
+ save_file(selected, str(out_path))
+
+
+def parse_layers(layer_str):
+ if layer_str.strip().lower() == "all":
+ return set(range(25))
+ return set(int(idx) for idx in layer_str.split(",") if idx.strip().isdigit())
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Infer Wav2Vec2 hidden states.")
+ parser.add_argument(
+ "file_list", type=Path, help="Text file with paths to audio files"
+ )
+ parser.add_argument("output_dir", type=Path, help="Directory to save .st files")
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument("--num_workers", type=int, default=2)
+ parser.add_argument(
+ "--layers",
+ type=str,
+ default="all",
+ help="Comma-separated layer indices or 'all'",
+ )
+
+ args = parser.parse_args()
+ keep_layers = parse_layers(args.layers)
+ device = torch.device(args.device)
+
+ model = (
+ Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")
+ .to(device)
+ .eval()
+ )
+
+ with open(args.file_list, "r") as f:
+ paths = [Path(line.strip()) for line in f if line.strip()]
+
+ dataset = AudioDataset(paths)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=args.num_workers,
+ collate_fn=lambda x: list(zip(*x)),
+ )
+
+ args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ for batch in tqdm(dataloader):
+ try:
+ encode_batch(model, batch, device, args.output_dir, keep_layers)
+ except Exception as e:
+ print(f"❌ Failed on batch: {e}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/infer_wavvae.py b/codec/scripts/infer_wavvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..6de6a280e15aaa1549a074cf137369e789ea9348
--- /dev/null
+++ b/codec/scripts/infer_wavvae.py
@@ -0,0 +1,93 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from zcodec.models import WavVAE
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T]
+ return wav, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ latent = model.encode(wav).cpu()
+ out_path = out_dir / (path.stem + ".st")
+ save_file({"audio_z": latent}, str(out_path))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Batch encode audio files with WavVAE."
+ )
+ parser.add_argument(
+ "file_list",
+ type=Path,
+ help="Text file listing paths to audio files (one per line)",
+ )
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
+ )
+ parser.add_argument(
+ "output_dir", type=Path, help="Directory to save Safetensors latents"
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument(
+ "--num_workers", type=int, default=3, help="Number of DataLoader workers"
+ )
+
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ # Load model
+ wavvae = WavVAE.from_pretrained_local(args.checkpoint)
+ wavvae = wavvae.to(device).eval()
+ target_sr = wavvae.sampling_rate
+
+ # Prepare dataset and dataloader
+ with open(args.file_list, "r") as f:
+ file_paths = [Path(line.strip()) for line in f if line.strip()]
+ dataset = AudioDataset(file_paths, target_sr)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=args.num_workers,
+ collate_fn=lambda x: list(zip(*x)),
+ )
+
+ args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Inference loop
+ for batch in tqdm(dataloader):
+ try:
+ encode_batch(wavvae, batch, device, args.output_dir)
+ except Exception as e:
+ print(f"❌ Batch failed: {e}")
+
+
+if __name__ == "__main__":
+ main()
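+
+# Usage sketch (paths are placeholders); writes one <stem>.st per input file with the latents under "audio_z":
+#   python codec/scripts/infer_wavvae.py filelist.txt checkpoints/wavvae/pretrained/WavVAE_example latents_out/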
diff --git a/codec/scripts/infer_wavvae_audiocite.py b/codec/scripts/infer_wavvae_audiocite.py
new file mode 100644
index 0000000000000000000000000000000000000000..a91f901a149e5b595301019f74702e201c2bdfce
--- /dev/null
+++ b/codec/scripts/infer_wavvae_audiocite.py
@@ -0,0 +1,73 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+
+from datasets import load_dataset, load_from_disk
+from zcodec.models import WavVAE
+
+
+def load_and_resample(path, target_sr):
+ wav, sr = sf.read(str(path))
+ if sr != target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
+ wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T]
+ return wav
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Encode HF dataset audio with WavVAE using map() (non-batched)."
+ )
+ parser.add_argument("dataset", type=str, help="Path or HF hub ID of dataset")
+ parser.add_argument("path_column", type=str, help="Column name with wav file paths")
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
+ )
+ parser.add_argument(
+ "--split", type=str, default=None, help="Dataset split (if loading from hub)"
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument(
+ "--num_proc", type=int, default=1, help="Number of processes for map()"
+ )
+ args = parser.parse_args()
+
+ device = torch.device(args.device)
+
+ # Load model
+ wavvae = WavVAE.from_pretrained_local(args.checkpoint).to(device).eval()
+ target_sr = wavvae.sampling_rate
+
+ # Load dataset
+ if Path(args.dataset).exists():
+ ds = load_from_disk(args.dataset)
+ else:
+ ds = load_dataset(args.dataset, split=args.split or "train")
+
+ ds = ds.filter(lambda x: x > 1.0, input_columns="duration")
+
+ # Mapping function (non-batched)
+ @torch.no_grad()
+ def encode_example(example):
+ wav = load_and_resample(example[args.path_column], target_sr).to(device)
+ latent = wavvae.encode(wav).cpu().numpy()
+ example["audio_z"] = latent
+ return example
+
+ # Apply map without batching
+ ds = ds.map(
+ encode_example,
+ num_proc=args.num_proc,
+ )
+
+ # Save dataset with new column
+    ds.save_to_disk(str(args.dataset) + "_with_latents")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/infer_wavvae_hf.py b/codec/scripts/infer_wavvae_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..58415077977d545201c6d42d27e6244ccf12bd9d
--- /dev/null
+++ b/codec/scripts/infer_wavvae_hf.py
@@ -0,0 +1,146 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+import torchaudio
+from torchaudio.functional import resample
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from datasets import Audio, load_dataset, load_from_disk
+from zcodec.models import WavVAE
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T]
+ return wav, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ latent = model.encode(wav).cpu()
+ out_path = out_dir / (path.stem + ".st")
+ save_file({"audio_z": latent}, str(out_path))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Batch encode audio files with WavVAE."
+ )
+ parser.add_argument(
+ "input_dataset",
+ type=Path,
+ help="Text file listing paths to audio files (one per line)",
+ )
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to WavVAE checkpoint directory"
+ )
+ parser.add_argument(
+ "output_dataset", type=Path, help="Directory to save Safetensors latents"
+ )
+ parser.add_argument("--in_column", type=str, default="audio")
+ parser.add_argument("--out_column", type=str, default="audio_z")
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument("--split", type=str, default=None)
+ parser.add_argument(
+ "--num_workers", type=int, default=1, help="Number of DataLoader workers"
+ )
+    parser.add_argument(
+        "--num_shards", type=int, default=1, help="Number of dataset shards"
+    )
+    parser.add_argument(
+        "--shard_index", type=int, default=0, help="Index of the shard to process"
+    )
+ parser.add_argument("--from_file", action="store_true")
+ parser.add_argument("--from_files", action="store_true")
+ parser.add_argument("--no_resample", action="store_false")
+ parser.add_argument("--from_hub", action="store_true")
+ parser.add_argument("--file_prefix", type=str, default=None)
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ # Load model
+ wavvae = WavVAE.from_pretrained_local(args.checkpoint)
+ wavvae = wavvae.to(device).eval()
+ target_sr = wavvae.sampling_rate
+ # Prepare dataset and dataloader
+ if args.from_hub:
+        dataset = load_dataset(str(args.input_dataset), split=args.split)
+    else:
+        # load_from_disk takes only a path (no split argument)
+        dataset = load_from_disk(str(args.input_dataset))
+
+ if args.num_shards > 1:
+ dataset = dataset.shard(num_shards=args.num_shards, index=args.shard_index)
+ if args.from_file:
+ def map_fn(audio_file_path):
+ if args.file_prefix is not None:
+ audio_file_path = args.file_prefix + "/" + audio_file_path
+            wav, sr = torchaudio.load(audio_file_path)
+            wav = wav.mean(dim=0, keepdim=True)
+            if not args.no_resample:
+                wav = resample(wav, sr, target_sr)
+ with torch.inference_mode():
+ latent = wavvae.encode(wav.to(device))
+ return {"audio_z": latent}
+
+ dataset = dataset.map(map_fn, input_columns=args.in_column)
+ elif args.from_files:
+ def map_fn(audio_file_paths):
+ if args.file_prefix is not None:
+ audio_file_paths = [args.file_prefix + "/" + x for x in audio_file_paths]
+ wav, sr = torchaudio.load(audio_file_paths[0])
+ wavs = [wav.mean(dim=0, keepdim=True)]
+ wavs = wavs + [torchaudio.load(x)[0].mean(dim=0, keepdim=True) for x in audio_file_paths[1:]]
+ wav = torch.cat(wavs, dim=1)
+ if not args.no_resample:
+ wav = resample(wav, sr, target_sr)
+ with torch.inference_mode():
+ latent = wavvae.encode(wav.to(device))
+ return {"audio_z": latent}
+
+ dataset = dataset.map(map_fn, input_columns=args.in_column)
+
+ else:
+ dataset = dataset.cast_column(args.in_column, Audio(sampling_rate=target_sr))
+ dataset = dataset.with_format(
+ "torch",
+ columns=args.in_column,
+ )
+
+ def map_fn(audio):
+ with torch.inference_mode():
+ wav = audio["array"].unsqueeze(0).to(device)
+ latent = wavvae.encode(wav)
+ return {"audio_z": latent}
+
+ dataset = dataset.map(
+ map_fn, input_columns=args.in_column, remove_columns=args.in_column
+ )
+
+ dataset.save_to_disk(args.output_dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/infer_zflowae.py b/codec/scripts/infer_zflowae.py
new file mode 100644
index 0000000000000000000000000000000000000000..16fddaebea1f848976f12522ceac0526ce014cf5
--- /dev/null
+++ b/codec/scripts/infer_zflowae.py
@@ -0,0 +1,121 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+from safetensors import safe_open
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from zcodec.models import ZFlowAutoEncoder
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T]
+ return wav, path
+
+
+class SafeTensorDataset(Dataset):
+ """
+ On __getitem__, opens the safetensor, uses get_slice() to inspect shape,
+ then either drops too-short files (return None) or returns a random subsequence slice.
+ """
+
+ def __init__(
+ self,
+ file_paths: list[str],
+ key: str = "audio_z",
+ ):
+ self.file_paths = file_paths
+ self.key = key
+
+ def __len__(self):
+ return len(self.file_paths)
+
+ def __getitem__(self, idx: int) -> torch.Tensor | None:
+ path = self.file_paths[idx]
+        # open the file and load the full tensor for the configured key
+ with safe_open(path, framework="pt") as f:
+ tensor = f.get_tensor(self.key)
+
+ return tensor, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir, save_latent=False):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ latent, indices = model.encode(wav)
+ if save_latent:
+ to_save = latent.cpu()
+ else:
+ to_save = indices.cpu()
+ out_path = out_dir / (path.stem + ".st")
+ save_file({"audio_z": to_save}, str(out_path))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Batch encode audio files with ZFlowAutoEncoder."
+ )
+ parser.add_argument(
+ "file_list",
+ type=Path,
+ help="Text file listing paths to audio files (one per line)",
+ )
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to ZFlowAutoEncoder checkpoint directory"
+ )
+ parser.add_argument(
+ "output_dir", type=Path, help="Directory to save Safetensors latents"
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+
+ parser.add_argument("--num_workers", type=int, default=4, help="Num. workers")
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ # Load model
+ zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
+ zflowae = zflowae.to(device).eval()
+
+ # Prepare dataset and dataloader
+ with open(args.file_list, "r") as f:
+ file_paths = [Path(line.strip()) for line in f if line.strip()]
+ dataset = SafeTensorDataset(file_paths)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=args.num_workers,
+ collate_fn=lambda x: list(zip(*x)),
+ )
+
+ args.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Inference loop
+ for batch in tqdm(dataloader):
+ try:
+ encode_batch(zflowae, batch, device, args.output_dir)
+ except Exception as e:
+ print(f"❌ Batch failed: {e}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/scripts/infer_zflowae_hf.py b/codec/scripts/infer_zflowae_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..911f203422320b39336576006797b4f382ba8c8e
--- /dev/null
+++ b/codec/scripts/infer_zflowae_hf.py
@@ -0,0 +1,121 @@
+import argparse
+from pathlib import Path
+
+import librosa
+import soundfile as sf
+import torch
+import torchaudio
+from safetensors.torch import save_file
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from datasets import Audio, load_dataset, load_from_disk
+from zcodec.models import WavVAE, ZFlowAutoEncoder
+
+
+class AudioDataset(Dataset):
+ def __init__(self, file_list, target_sr):
+ self.paths = file_list
+ self.target_sr = target_sr
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, idx):
+ path = self.paths[idx]
+ wav, sr = sf.read(str(path))
+ if sr != self.target_sr:
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.target_sr)
+ wav = torch.tensor(wav).unsqueeze(0).float() # shape: [1, T]
+ return wav, path
+
+
+@torch.no_grad()
+def encode_batch(model, batch, device, out_dir, save_latent=True):
+ wavs, paths = batch
+ for wav, path in zip(wavs, paths):
+ wav = wav.to(device)
+ latent, indices = model.encode(wav)
+ if save_latent:
+ to_save = latent.cpu()
+ else:
+            to_save = indices.cpu()
+ out_path = out_dir / (path.stem + ".st")
+ save_file({"audio_z": to_save}, str(out_path))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Batch encode audio files with WavVAE."
+ )
+ parser.add_argument(
+ "input_dataset",
+ type=Path,
+ help="Text file listing paths to audio files (one per line)",
+ )
+ parser.add_argument(
+ "checkpoint", type=Path, help="Path to zflowae checkpoint directory"
+ )
+ parser.add_argument(
+ "output_dataset", type=Path, help="Directory to save Safetensors latents"
+ )
+ parser.add_argument("--in_column", type=str, default="audio_z")
+ parser.add_argument("--out_column", type=str, default="audio_latent")
+ parser.add_argument(
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
+ parser.add_argument("--split", type=str, default="all")
+ parser.add_argument(
+ "--num_workers", type=int, default=1, help="Number of DataLoader workers"
+ )
+ parser.add_argument("--from_file", action="store_true")
+ parser.add_argument("--from_hub", action="store_true")
+ parser.add_argument("--file_prefix", type=str, default=None)
+ args = parser.parse_args()
+ device = torch.device(args.device)
+
+ # Load model
+ zflowae = ZFlowAutoEncoder.from_pretrained_local(args.checkpoint)
+ zflowae = zflowae.to(device).eval()
+
+ # Prepare dataset and dataloader
+ if args.from_hub:
+ dataset = load_dataset(str(args.input_dataset), args.split)
+ else:
+        # load_from_disk takes only a path (no split argument)
+        dataset = load_from_disk(str(args.input_dataset))
+
+ if args.from_file:
+        raise NotImplementedError
+
+ def map_fn(audio_file_path):
+ if args.file_prefix is not None:
+ audio_file_path = args.file_prefix+"/"+audio_file_path
+ wav, sr = torchaudio.load(audio_file_path)
+ wav = wav.mean(dim=0, keepdim=True)
+ with torch.inference_mode():
+ latent = zflowae.encode(wav.to(device))
+ return {"audio_z": latent}
+
+ dataset = dataset.map(map_fn, input_columns=args.in_column)
+ else:
+ dataset = dataset.with_format(
+ "torch",
+ columns=args.in_column,
+ )
+
+ def map_fn(audio):
+ with torch.inference_mode():
+ audio_z = audio.to(device)
+ latent, _ = zflowae.encode(audio_z)
+ return {args.out_column: latent}
+
+ dataset = dataset.map(
+ map_fn, input_columns=args.in_column, remove_columns=args.in_column
+ )
+
+ dataset.save_to_disk(args.output_dataset)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/codec/train_patchvae.py b/codec/train_patchvae.py
new file mode 100755
index 0000000000000000000000000000000000000000..40bfbf22495f6a30cbd9a4d967e385587b8ca723
--- /dev/null
+++ b/codec/train_patchvae.py
@@ -0,0 +1,232 @@
+import json
+import math
+from dataclasses import asdict
+from pathlib import Path
+from typing import Literal
+
+import pytorch_lightning as pl
+import torch
+from safetensors.torch import safe_open, save_file
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from torchaudio.transforms import Resample
+from transformers import WavLMModel
+
+from .models import PatchVAE, PatchVAEConfig, WavVAE
+from .models.components.convnext import SwiGLU
+from .models.patchvae.modules import convnext_factory
+
+
+def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return step / max(1, warmup_steps)
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
+
+ return lr_lambda
+
+
+class TrainPatchVAE(pl.LightningModule):
+ def __init__(
+ self,
+ config: PatchVAEConfig,
+ lr: float = 1e-4,
+ end_lr: float | None = None,
+ weight_decay: float = 0.01,
+ cfg_drop_rate: float = 0.0,
+ cfg_drop_rate_per_sample: float = 0.1,
+ total_steps: int = 2_000_000,
+ warmup_steps: int = 0,
+ mean_std_from_file: str | None = None,
+ train_sigma: float = 1e-4,
+ kl_div_factor: float = 1e-4,
+ wavvae_pretrained_path: str | None = None,
+ drop_vae_rate: float = 0.0,
+ ssl_repa_factor: float = 1.0,
+ ssl_repa_head: bool = False,
+ ssl_repa_head_type: Literal["convnext", "swiglu"] = "swiglu",
+ ssl_repa_head_target: Literal["flow", "prior"] = "prior",
+ ssl_repa_head_num_layers: int = 4,
+ ssl_repa_source: Literal["wavlm-base-plus"] = "wavlm-base-plus",
+ ssl_feat_dim: int = 768,
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+
+ if mean_std_from_file is not None:
+ with safe_open(mean_std_from_file, framework="pt") as f:
+ config.latent_scaling = (
+ f.get_tensor("mean").tolist(),
+ f.get_tensor("std").tolist(),
+ )
+
+ self.patchvae = PatchVAE(config)
+ self.apply(self._init_weights)
+
+ self.wavvae = None
+ if wavvae_pretrained_path is not None:
+ self.wavvae = WavVAE.from_pretrained_local(wavvae_pretrained_path).eval()
+
+ self.ssl_model = None
+ if ssl_repa_head:
+ if ssl_repa_source == "wavlm-base-plus":
+ self.ssl_model = WavLMModel.from_pretrained(
+ "microsoft/wavlm-base-plus"
+ ).eval()
+ self.ssl_resampling = Resample(
+ orig_freq=self.wavvae.sampling_rate, new_freq=16000
+ )
+
+ for p in self.ssl_model.parameters():
+ p.requires_grad = False
+
+ if ssl_repa_head_type == "convnext":
+ self.ssl_repa_head = nn.Sequential(
+ convnext_factory(
+ config.hidden_dim,
+ ssl_repa_head_num_layers,
+ ),
+ nn.Linear(config.hidden_dim, ssl_feat_dim),
+ )
+ elif ssl_repa_head_type == "swiglu":
+ self.ssl_repa_head = nn.Sequential(
+ *[
+ SwiGLU(config.hidden_dim)
+ for _ in range(ssl_repa_head_num_layers)
+ ],
+ nn.Linear(config.hidden_dim, ssl_feat_dim),
+ )
+
+ else:
+ self.ssl_repa_head = None
+
+ def _init_weights(self, m: nn.Module):
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+ def configure_optimizers(self):
+ params = [{"params": self.patchvae.parameters()}]
+ if self.ssl_repa_head is not None:
+ params += [{"params": self.ssl_repa_head.parameters()}]
+ opt = torch.optim.AdamW(
+ params,
+ lr=self.hparams.lr,
+ weight_decay=self.hparams.weight_decay,
+ )
+
+ if self.hparams.end_lr is not None:
+ scheduler = LambdaLR(
+ opt,
+ lr_lambda=cosine_schedule_with_warmup(
+ warmup_steps=self.hparams.warmup_steps,
+ total_steps=self.hparams.total_steps,
+ start_lr=self.hparams.lr,
+ end_lr=self.hparams.end_lr,
+ ),
+ )
+ scheduler = {"scheduler": scheduler, "interval": "step"}
+ return [opt], [scheduler]
+ return opt
+
+ def save_model_weights_and_config(
+ self,
+ dir: str | None,
+ model_filename: str = "model.st",
+ config_filename: str = "config.json",
+ ):
+ cfg = self.hparams.config
+ model_path = Path(dir) / model_filename
+ save_file(self.patchvae.state_dict(), model_path.with_suffix(".st"))
+ with open(Path(dir) / config_filename, "w") as f:
+ json.dump(asdict(cfg), f, indent=2)
+
+ def training_step(self, batch: dict[str, torch.Tensor], batch_idx):
+ z = batch["audio_z"]
+ t = torch.rand(z.shape[0], device=z.device)
+ drop_cond_rate = self.hparams.cfg_drop_rate_per_sample
+ drop_vae_rate = self.hparams.drop_vae_rate
+ flow_loss, ae_loss, prior = self.patchvae(
+ z,
+ t,
+ sigma=self.hparams.train_sigma,
+ drop_vae_rate=drop_vae_rate,
+ drop_cond_rate=drop_cond_rate,
+ )
+
+ self.log("train_flow_loss", flow_loss, prog_bar=True)
+
+ total_loss = flow_loss
+ if ae_loss.get("kl_div") is not None:
+ kl_div = ae_loss.get("kl_div")
+ self.log("train_kl_div", kl_div, prog_bar=True)
+ total_loss += self.hparams.kl_div_factor * kl_div
+
+ for x in ["_mu_mean", "_mu_std", "_logvar_mean", "_logvar_std"]:
+ stat = ae_loss.get(x)
+ if stat is not None:
+ self.log(x, stat, prog_bar=False)
+
+ if self.ssl_repa_head is not None:
+ target = self.hparams.ssl_repa_head_target
+ if target == "prior":
+ target = prior
+ elif target == "flow":
+ raise NotImplementedError
+ with torch.inference_mode():
+ wav = self.wavvae.decode(z)
+ wav = self.ssl_resampling(wav)
+ wav = torch.nn.functional.pad(wav, (40, 40))
+ ssl_feats = self.ssl_model(wav, output_hidden_states=True).hidden_states
+ ssl_feat = ssl_feats[10]
+ ssl_feat = torch.nn.functional.avg_pool1d(
+ ssl_feat.transpose(-1, -2),
+ kernel_size=8,
+ stride=4,
+ padding=2,
+ ).transpose(-1, -2)
+
+ ssl_feat = ssl_feat.clone()
+
+ B, N, D = ssl_feat.shape
+
+ repa_pred = self.ssl_repa_head(target)
+ ssl_repa_loss = nn.functional.cosine_embedding_loss(
+ repa_pred.reshape(-1, D),
+ ssl_feat.reshape(-1, D),
+ torch.ones(1).to(repa_pred),
+ )
+ total_loss += self.hparams.ssl_repa_factor * ssl_repa_loss
+ self.log("train_repa_loss", ssl_repa_loss, prog_bar=True)
+
+ return total_loss
+
+ def validation_step(self, batch: dict[str, torch.Tensor], batch_idx):
+ z = batch["audio_z"]
+ t = (
+ torch.ones(z.shape[0], device=z.device)
+ * batch_idx
+ / self.trainer.num_val_batches[0]
+ )
+ flow_loss, ae_loss, prior = self.patchvae(
+ z, t, sigma=self.hparams.train_sigma, drop_cond=False
+ )
+ self.log("val_flow_loss", flow_loss, prog_bar=True)
+
+ total_loss = flow_loss
+ if ae_loss.get("kl_div") is not None:
+ kl_div = ae_loss.get("kl_div")
+ self.log("val_kl_div", kl_div, prog_bar=True)
+ total_loss += self.hparams.kl_div_factor * kl_div
+ return total_loss
+
+if __name__ == "__main__":
+ from pytorch_lightning.cli import LightningCLI
+
+ LightningCLI(
+ TrainPatchVAE,
+ save_config_callback=None,
+ parser_kwargs={"parser_mode": "omegaconf"},
+ )
diff --git a/codec/train_wavvae.py b/codec/train_wavvae.py
new file mode 100755
index 0000000000000000000000000000000000000000..0392d3b846ef8760f306363a701d093a28e818e2
--- /dev/null
+++ b/codec/train_wavvae.py
@@ -0,0 +1,304 @@
+import json
+import math
+from dataclasses import asdict
+from pathlib import Path
+from typing import Optional
+
+import pytorch_lightning as pl
+import torch
+import transformers
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.cli import LightningCLI
+from pytorch_lightning.loggers.wandb import WandbLogger
+from safetensors.torch import save_file
+from torch.nn.utils import clip_grad_norm_
+
+from codec.models import WavVAE, WavVAEConfig
+from codec.models.wavvae.discriminators import (MultiPeriodDiscriminator,
+ MultiResolutionDiscriminator)
+from codec.models.wavvae.loss import (DiscriminatorLoss, FeatureMatchingLoss,
+ GeneratorLoss,
+ MelSpecReconstructionLoss)
+
+
+class TrainWavVAE(pl.LightningModule):
+ def __init__(
+ self,
+ config: WavVAEConfig,
+ sample_rate: int,
+ initial_learning_rate: float,
+ num_warmup_steps: int = 0,
+ mel_loss_coeff: float = 45,
+ mrd_loss_coeff: float = 1.0,
+ kl_div_coeff: float = 1e-5,
+ pretrain_mel_steps: int = 0,
+ decay_mel_coeff: bool = False,
+ clip_grad_norm: float | None = None,
+ f_min: int = 0,
+ f_max: Optional[int] = None,
+ mrd_fft_sizes: tuple[int, int, int] = (2048, 1024, 512),
+ mel_hop_length: int = 256,
+ log_audio_every_n_epoch: int = 5,
+ log_n_audio_batches: int = 32,
+ ):
+ super().__init__()
+
+ self.save_hyperparameters()
+ self.wavvae = WavVAE(config)
+ self.multiperioddisc = MultiPeriodDiscriminator()
+ self.multiresddisc = MultiResolutionDiscriminator(
+ fft_sizes=tuple(mrd_fft_sizes)
+ )
+
+ self.disc_loss = DiscriminatorLoss()
+ self.gen_loss = GeneratorLoss()
+ self.feat_matching_loss = FeatureMatchingLoss()
+ self.melspec_loss = MelSpecReconstructionLoss(
+ sample_rate=sample_rate,
+ f_min=f_min,
+ f_max=f_max,
+ hop_length=mel_hop_length,
+ )
+
+ self.train_discriminator = False
+ self.automatic_optimization = False
+ self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
+
+ def save_model_weights_and_config(
+ self,
+ dir: str | None,
+ model_filename: str = "model.st",
+ config_filename: str = "config.json",
+ ):
+ cfg = self.hparams.config
+ model_path = Path(dir) / model_filename
+ save_file(self.wavvae.state_dict(), model_path)
+ with open(Path(dir) / config_filename, "w") as f:
+ json.dump(asdict(cfg), f, indent=2)
+
+ def configure_optimizers(self):
+ disc_params = [
+ {"params": self.multiperioddisc.parameters()},
+ {"params": self.multiresddisc.parameters()},
+ ]
+ gen_params = [
+ {"params": self.wavvae.parameters()},
+ ]
+
+ opt_disc = torch.optim.AdamW(
+ disc_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
+ )
+ opt_gen = torch.optim.AdamW(
+ gen_params, lr=self.hparams.initial_learning_rate, betas=(0.8, 0.9)
+ )
+
+ max_steps = self.trainer.max_steps // 2
+ scheduler_disc = transformers.get_cosine_schedule_with_warmup(
+ opt_disc,
+ num_warmup_steps=self.hparams.num_warmup_steps,
+ num_training_steps=max_steps,
+ )
+ scheduler_gen = transformers.get_cosine_schedule_with_warmup(
+ opt_gen,
+ num_warmup_steps=self.hparams.num_warmup_steps,
+ num_training_steps=max_steps,
+ )
+
+ return (
+ [opt_disc, opt_gen],
+ [
+ {"scheduler": scheduler_disc, "interval": "step"},
+ {"scheduler": scheduler_gen, "interval": "step"},
+ ],
+ )
+
+ def forward(self, audio_input, **kwargs):
+ audio_output, kl_div = self.wavvae(audio_input)
+
+ return audio_output, kl_div
+
+ def training_step(self, batch, batch_idx, **kwargs):
+ audio_input = batch
+
+ opt_disc, opt_gen = self.optimizers()
+
+ if self.train_discriminator:
+ with torch.no_grad():
+ audio_hat, kl_div = self(audio_input, **kwargs)
+
+ real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(
+ y=audio_input,
+ y_hat=audio_hat,
+ **kwargs,
+ )
+ real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(
+ y=audio_input,
+ y_hat=audio_hat,
+ **kwargs,
+ )
+ loss_mp, loss_mp_real, _ = self.disc_loss(
+ disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
+ )
+ loss_mrd, loss_mrd_real, _ = self.disc_loss(
+ disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
+ )
+ loss_mp /= len(loss_mp_real)
+ loss_mrd /= len(loss_mrd_real)
+ loss_disc = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd
+
+ self.log("discriminator/total", loss_disc, prog_bar=True)
+ self.log("discriminator/multi_period_loss", loss_mp)
+ self.log("discriminator/multi_res_loss", loss_mrd)
+
+ opt_disc.zero_grad()
+ self.manual_backward(loss_disc)
+ if self.hparams.clip_grad_norm is not None:
+ max_norm = self.hparams.clip_grad_norm
+ clip_grad_norm_(self.multiperioddisc.parameters(), max_norm=max_norm)
+ clip_grad_norm_(self.multiresddisc.parameters(), max_norm=max_norm)
+ opt_disc.step()
+
+ audio_hat, kl_div = self(audio_input, **kwargs)
+ if self.train_discriminator:
+ _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
+ y=audio_input,
+ y_hat=audio_hat,
+ **kwargs,
+ )
+ _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
+ y=audio_input,
+ y_hat=audio_hat,
+ **kwargs,
+ )
+ loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
+ loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
+ loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
+ loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
+ loss_fm_mp = self.feat_matching_loss(
+ fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp
+ ) / len(fmap_rs_mp)
+ loss_fm_mrd = self.feat_matching_loss(
+ fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd
+ ) / len(fmap_rs_mrd)
+
+ self.log("generator/multi_period_loss", loss_gen_mp)
+ self.log("generator/multi_res_loss", loss_gen_mrd)
+ self.log("generator/feature_matching_mp", loss_fm_mp)
+ self.log("generator/feature_matching_mrd", loss_fm_mrd)
+
+ self.log("generator/kl_div", kl_div)
+
+ mel_loss = self.melspec_loss(audio_hat, audio_input)
+ loss = (
+ loss_gen_mp
+ + self.hparams.mrd_loss_coeff * loss_gen_mrd
+ + loss_fm_mp
+ + self.hparams.mrd_loss_coeff * loss_fm_mrd
+ + self.mel_loss_coeff * mel_loss
+ + self.hparams.kl_div_coeff * kl_div
+ )
+
+ self.log("generator/total_loss", loss, prog_bar=True)
+ self.log("mel_loss_coeff", self.mel_loss_coeff)
+ self.log("generator/mel_loss", mel_loss)
+
+ opt_gen.zero_grad()
+ self.manual_backward(loss)
+ if self.hparams.clip_grad_norm is not None:
+ max_norm = self.hparams.clip_grad_norm
+ clip_grad_norm_(self.wavvae.parameters(), max_norm=max_norm)
+ opt_gen.step()
+
+ def validation_step(self, batch, batch_idx, **kwargs):
+ audio_input = batch
+ audio_hat, _ = self(audio_input, **kwargs)
+
+ if self.current_epoch % self.hparams.log_audio_every_n_epoch == 0:
+ wavs = [x.numpy(force=True) for x in audio_hat.unbind(0)]
+            if batch_idx == 0:
+                self._audios_to_log = list(wavs)
+            elif batch_idx < self.hparams.log_n_audio_batches:
+                self._audios_to_log += wavs
+            elif batch_idx == self.hparams.log_n_audio_batches:
+ self.logger.log_audio(
+ "audio",
+ self._audios_to_log,
+ step=self.global_step,
+ sample_rate=[
+ self.wavvae.sampling_rate
+ for _ in range(len(self._audios_to_log))
+ ],
+ )
+
+ mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+ total_loss = mel_loss
+
+ return {
+ "val_loss": total_loss,
+ "mel_loss": mel_loss,
+ "audio_input": audio_input[0],
+ "audio_pred": audio_hat[0],
+ }
+
+ @property
+ def global_step(self):
+ """
+ Override global_step so that it returns the total number of batches processed
+ """
+ return self.trainer.fit_loop.epoch_loop.total_batch_idx
+
+ def on_train_batch_start(self, *args):
+ if self.global_step >= self.hparams.pretrain_mel_steps:
+ self.train_discriminator = True
+ else:
+ self.train_discriminator = False
+
+ def on_train_batch_end(self, *args):
+ def mel_loss_coeff_decay(current_step, num_cycles=0.5):
+ max_steps = self.trainer.max_steps // 2
+ if current_step < self.hparams.num_warmup_steps:
+ return 1.0
+ progress = float(current_step - self.hparams.num_warmup_steps) / float(
+ max(1, max_steps - self.hparams.num_warmup_steps)
+ )
+ return max(
+ 0.0,
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+ )
+
+ if self.hparams.decay_mel_coeff:
+ self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(
+ self.global_step + 1
+ )
+
+if __name__ == "__main__":
+ class WavVAECLI(LightningCLI):
+ def after_instantiate_classes(self):
+ hparams = self.model.hparams
+ kl_factor = "{:.1e}".format(hparams.kl_div_coeff)
+            latent_dim = hparams.config.latent_dim
+ frame_rate = self.model.wavvae.frame_rate
+ dataset_name = (
+ Path(self.datamodule.train_config.filelist_path).with_suffix("").name
+ )
+
+ name = f"WavVAE_kl{kl_factor}_framerate{frame_rate}hz_latentdim{latent_dim}_dataset{dataset_name}"
+
+            if self.trainer.logger:
+                # replace the default logger with a W&B run named after the hyperparameters
+                self.trainer.logger = WandbLogger(
+ log_model=False,
+ project="codec",
+ name=name,
+ )
+ model_checkpoint_cb = ModelCheckpoint(
+ monitor="generator/mel_loss",
+ dirpath="checkpoints/wavvae",
+ filename=name + "_epoch{epoch:02d}",
+ save_last=True,
+ )
+ self.trainer.callbacks.append(model_checkpoint_cb)
+
+ WavVAECLI(
+ save_config_kwargs={"overwrite": True},
+ parser_kwargs={"parser_mode": "omegaconf"},
+ )
diff --git a/pardi_speech.py b/pardi_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..66488ee926f4990844b6cbfde7bf34942cff4673
--- /dev/null
+++ b/pardi_speech.py
@@ -0,0 +1,208 @@
+import json
+import string
+from dataclasses import asdict, dataclass
+from pathlib import Path
+
+import torch
+from safetensors.torch import load_file
+from torch.nn.utils.rnn import pad_sequence
+
+from codec.models import PatchVAE
+from tts.model.cache_utils import FLACache
+from tts.text_processor import BasicTextProcessor
+from tts.tools import sequence_mask
+from tts.tts import ARTTSModel
+
+
+@dataclass
+class VelocityHeadSamplingParams:
+ """
+ Velocity head sampling parameters
+
+ Attributes:
+
+ cfg (float): CFG factor against unconditional prediction.
+ cfg_ref (float): CFG factor against a reference (to be used with a cache of size 2*batch_size and unfold).
+ temperature (float): scale factor of z0 ~ 𝒩(0,1)
+ num_steps (int): number of ODE steps
+ solver (str): parameter passed to NeuralODE
+ sensitivity (str): parameter passed to NeuralODE
+ """
+ cfg: float = 1.3
+ cfg_ref: float = 1.5
+ temperature: float = 0.9
+ num_steps: int = 13
+ solver: str = "euler"
+ sensitivity: str = "adjoint"
+
+
+@dataclass
+class PatchVAESamplingParams:
+ """
+ PatchVAE sampling parameters
+
+ Attributes:
+
+ cfg (float): CFG factor against unconditional prediction.
+ temperature (float): scale factor of z0 ~ 𝒩(0,1)
+ num_steps (int): number of ODE steps
+ solver (str): parameter passed to NeuralODE
+ sensitivity (str): parameter passed to NeuralODE
+ """
+
+ cfg: float = 2.0
+ temperature: float = 1.0
+ num_steps: int = 10
+ solver: str = "euler"
+ sensitivity: str = "adjoint"
+
+
+class PardiSpeech:
+ tts: ARTTSModel
+ patchvae: PatchVAE
+ text_processor: BasicTextProcessor
+
+ def __init__(
+ self,
+ tts: ARTTSModel,
+ patchvae: PatchVAE,
+ text_processor: BasicTextProcessor,
+ ):
+ self.tts = tts
+ self.patchvae = patchvae
+ self.text_processor = text_processor
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ map_location: str = "cpu",
+ ):
+ if Path(pretrained_model_name_or_path).exists():
+ path = pretrained_model_name_or_path
+ else:
+ from huggingface_hub import snapshot_download
+
+ path = snapshot_download(pretrained_model_name_or_path)
+
+ with open(Path(path) / "config.json", "r") as f:
+ config = json.load(f)
+
+ artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
+ state_dict = load_file(
+ Path(path) / "model.st",
+ device=map_location,
+ )
+ artts_model.load_state_dict(state_dict, assign=True)
+ patchvae = PatchVAE.from_pretrained(
+ artts_config.patchvae_path,
+ map_location=map_location,
+ )
+ text_processor = BasicTextProcessor(
+ str(Path(path) / "pretrained_tokenizer.json")
+ )
+
+ return cls(artts_model, patchvae, text_processor)
+
+ def encode_reference(self, wav: torch.Tensor, sr: int):
+ import torchaudio
+
+ new_freq = self.patchvae.wavvae.sampling_rate
+ wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
+ return self.patchvae.encode(wav)
+
+ @property
+ def sampling_rate(self):
+ return self.patchvae.wavvae.sampling_rate
+
+ def text_to_speech(
+ self,
+ text: str,
+ prefix: tuple[str, torch.Tensor] | None = None,
+ patchvae_sampling_params: PatchVAESamplingParams | None = None,
+ velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
+ prefix_separator: str = ". ",
+ max_seq_len: int = 600,
+ stop_threshold: float = 0.5,
+ cache: FLACache | None = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ text: str
+ The text to synthesize.
+
+ prefix: tuple[str, torch.Tensor] | None
+            A pair (text, speech) consisting of a reference speech excerpt encoded with
+            `encode_reference` and its corresponding text transcription. Synthesis is
+            performed by continuing the prefix. If no prefix is given, the first frame
+            is randomly sampled.
+
+ patchvae_sampling_params: PatchVAESamplingParams
+ PatchVAE sampling parameters
+
+ velocity_head_sampling_params: VelocityHeadSamplingParams
+ VelocityHead sampling parameters (AR sampling)
+
+ prefix_separator: str
+ The separator that joins the prefix text to the target text.
+
+ max_seq_len: int
+            The maximum number of latent frames to generate.
+
+        stop_threshold: float
+            Threshold value at which AR prediction stops.
+
+        cache: FLACache | None
+            Optional generation cache forwarded to the AR model.
+        """
+
+ device = next(self.tts.parameters()).device
+
+ if type(text) is str:
+ text = [text]
+
+ if prefix is not None:
+ prefix_text, prefix_speech = prefix
+ prefix_text = prefix_text.strip().rstrip(string.punctuation)
+ if prefix_text != "":
+ text = [prefix_text + prefix_separator + t for t in text]
+
+ prefix_speech = prefix_speech.repeat(len(text), 1, 1)
+ else:
+ _, audio_latent_sz = self.tts.audio_embd.weight.shape
+ prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device)
+
+ # if self.bos:
+ # text = "[BOS]" + text
+
+ # if self.eos:
+ # text = text + "[EOS]"
+ text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text]
+        text_pre_mask = sequence_mask(
+            torch.tensor([x.shape[0] for x in text_ids])
+        ).to(device)
+        text_mask = text_pre_mask[:, None] * text_pre_mask[..., None]
+        crossatt_mask = text_pre_mask[:, None, None]
+
+ text_ids = pad_sequence(text_ids, batch_first=True)
+
+ if velocity_head_sampling_params is None:
+ velocity_head_sampling_params = VelocityHeadSamplingParams()
+
+ if patchvae_sampling_params is None:
+ patchvae_sampling_params = PatchVAESamplingParams()
+
+
+ with torch.inference_mode():
+ _, predictions = self.tts.generate(
+ text_ids.to(device),
+ text_mask=text_mask,
+ crossatt_mask=crossatt_mask,
+ prefix=prefix_speech.to(device),
+ max_seq_len=max_seq_len,
+ sampling_params=asdict(velocity_head_sampling_params),
+ stop_threshold=stop_threshold,
+ cache=cache,
+ device=device,
+ **kwargs,
+ )
+        wavs = [
+            self.patchvae.decode(p, **asdict(patchvae_sampling_params))
+            for p in predictions
+        ]
+
+ return wavs, predictions
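+
+
+# Minimal end-to-end usage sketch (comment only). The checkpoint identifier and the
+# output path below are placeholders, and saving assumes PatchVAE.decode returns a
+# mono waveform shaped (channels, num_samples).
+#
+#     import torchaudio
+#
+#     model = PardiSpeech.from_pretrained("path/to/pretrained_checkpoint")
+#     wavs, _ = model.text_to_speech("Hello world.")
+#     torchaudio.save("out.wav", wavs[0].cpu(), model.sampling_rate)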
diff --git a/tts/__pycache__/groupdataset.cpython-312.pyc b/tts/__pycache__/groupdataset.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2c40da29cddbd1329fc33fa0051854b863577ea
Binary files /dev/null and b/tts/__pycache__/groupdataset.cpython-312.pyc differ
diff --git a/tts/__pycache__/text_processor.cpython-312.pyc b/tts/__pycache__/text_processor.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15abe599201f2508fe8648fbbaf78ed67b80d501
Binary files /dev/null and b/tts/__pycache__/text_processor.cpython-312.pyc differ
diff --git a/tts/__pycache__/tools.cpython-312.pyc b/tts/__pycache__/tools.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..953486f643e9b0b17bc7711ea8cdf9d8d5b7ea16
Binary files /dev/null and b/tts/__pycache__/tools.cpython-312.pyc differ
diff --git a/tts/__pycache__/train_tts.cpython-312.pyc b/tts/__pycache__/train_tts.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1200571bbdc82d9a0246dbaa6dbadb1178ce38e7
Binary files /dev/null and b/tts/__pycache__/train_tts.cpython-312.pyc differ
diff --git a/tts/__pycache__/tts.cpython-312.pyc b/tts/__pycache__/tts.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2a5d73b089ca43e468ed60c696d8ffca427fa4e
Binary files /dev/null and b/tts/__pycache__/tts.cpython-312.pyc differ
diff --git a/tts/groupdataset.py b/tts/groupdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c03108f382fb2f732ad473407e91db3e8bf646
--- /dev/null
+++ b/tts/groupdataset.py
@@ -0,0 +1,647 @@
+import math
+import random
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass, field
+from functools import reduce
+from itertools import accumulate
+from random import choices
+from typing import List, Optional, Sequence, Tuple
+
+import pytorch_lightning as ptl
+import torch
+from datasets import (DatasetDict, concatenate_datasets, load_dataset,
+ load_from_disk)
+from einops import rearrange
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import (BatchSampler, DataLoader, Sampler,
+ SubsetRandomSampler)
+from transformers import PreTrainedTokenizerFast
+
+from tts.tools import (audio_to_text_partial_neighbor_mask, packmask_2d,
+ pad_2d_sequence, sequence_mask)
+
+
+class BucketSampler(Sampler[List[int]]):
+ def __init__(
+ self,
+ buckets: List[List[int]],
+ batch_sizes: List[int] | int,
+        bucket_sampling_weights: Optional[List[float]] = None,
+ drop_last: bool = True,
+ distributed: bool = True, # TODO - implement not distributed as well
+ sample_bucket: Optional[int] = None,
+ seed: int = 123,
+ epoch_seed: bool = True,
+ ):
+ if type(batch_sizes) is int:
+ batch_sizes = [batch_sizes] * len(buckets)
+ else:
+ assert len(buckets) == len(batch_sizes)
+
+ if bucket_sampling_weights is not None:
+ assert len(bucket_sampling_weights) == len(batch_sizes)
+ self.bucket_sampling_weights = bucket_sampling_weights
+ self.num_replicas = torch.distributed.get_world_size()
+ self.rank = torch.distributed.get_rank()
+ self.buckets = [
+ b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas]
+ for b in buckets
+ ]
+ self.num_samples = [len(b) // self.num_replicas for b in buckets]
+ self.batch_sizes = batch_sizes
+ self.total_sizes = [
+ ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes)
+ ]
+ self.drop_last = drop_last
+ self.seed = seed
+ self.epoch = 0
+ self.sample_bucket = sample_bucket
+ self.epoch_seed = epoch_seed
+ self.batch_size = batch_sizes[0]
+
+ def set_epoch(self, epoch: int):
+ self.epoch = epoch
+
+ def __len__(self):
+ return sum(self.total_sizes)
+
+ def __iter__(self):
+ generator = torch.Generator()
+ generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank)
+ pool = [
+ BatchSampler(
+ SubsetRandomSampler(b, generator=generator),
+ bs,
+ drop_last=self.drop_last,
+ )
+ for b, bs in zip(self.buckets, self.batch_sizes)
+ ]
+ pool = [iter(b) for b in pool]
+ weights = (
+ [w for w in self.bucket_sampling_weights]
+ if self.bucket_sampling_weights is not None
+ else None
+ )
+ while pool: # sample until all buckets are done
+ idx, bucket = choices(list(enumerate(pool)), weights=weights)[0]
+ try:
+ batch = next(bucket)
+ yield batch
+ except StopIteration:
+ pool.pop(idx) # if bucket is done, throw it
+ if weights is not None:
+ weights.pop(idx)
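+
+
+# Usage sketch (comment only; assumes torch.distributed has been initialised).
+# Indices are pre-grouped by length into buckets, and every yielded batch is drawn
+# from a single bucket so that samples of similar duration end up padded together.
+#
+#     buckets = [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]]
+#     sampler = BucketSampler(buckets, batch_sizes=[2, 3])
+#     loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)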
+
+
+class DatasetFactory(ABC):
+ @abstractmethod
+ def build(self):
+ pass
+
+
+class HiFiTTS2_AudioLatent(DatasetFactory):
+ def __init__(
+ self,
+ path: str | list[str] = "hifitts2_vae8_dataset",
+ duration_column: str = "audio_duration",
+ duration_path: str | None = None,
+ expresso_path: str | None = None,
+ min_dur: float = 3.0,
+ max_dur: float = 20.1,
+ framerate: float = 25.0,
+ ):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.path = path
+ self.duration_column = duration_column
+ self.duration_path = duration_path
+ self.framerate = framerate
+ self.expresso_path = expresso_path
+
+ def build(self):
+ if type(self.path) is str:
+ self.path = [self.path]
+ datasets = [load_from_disk(x) for x in self.path]
+ dataset = concatenate_datasets(datasets).with_format("torch")
+ if self.duration_path is not None:
+ duration_dataset = load_from_disk(self.duration_path)
+ dataset = concatenate_datasets(
+ [dataset, duration_dataset], axis=1
+ ).with_format("torch")
+ dataset = dataset.filter(
+ lambda dur: dur > self.min_dur and dur < self.max_dur,
+ input_columns=self.duration_column,
+ )
+
+        if self.duration_column != "audio_duration":
+            dataset = dataset.rename_column(self.duration_column, "audio_duration")
+ # dataset = dataset.map(
+ # lambda x: {"audio_duration": x.shape[1] / self.framerate},
+ # input_columns="audio_latent",
+ # ).filter(
+ # lambda dur: dur > self.min_dur and dur < self.max_dur,
+ # input_columns="audio_duration",
+ # )
+        if self.expresso_path is not None:
+            # NOTE: loaded here but not yet concatenated into the returned dataset
+            expresso_dataset = load_from_disk(self.expresso_path).with_format("torch")
+
+ dataset = dataset.sort("audio_duration")
+ return DatasetDict({"train": dataset})
+
+
+@dataclass
+class SegmentsCollateArgs:
+ abs_style_intensity: bool = False
+ merge_endpoints: bool = True
+ block_crossatt_mask: bool = True
+ alternate_crossatt_pos: bool = False
+ block_crossatt_past_tokens: int = 0
+ block_crossatt_future_tokens: int = 0
+ eos: bool = True
+ bos: bool = True
+
+
+@dataclass
+class CollateArgs:
+ abs_style_intensity: bool = False
+ random_text_segment: bool = False
+ eos: bool = True
+ bos: bool = True
+ num_stop_tokens: int = 1
+
+def random_log_breakpoints(
+ seq: Sequence, a: int, b: int, gap: bool = False
+) -> List[int]:
+ """
+ Generate random breakpoints in a sequence where the gap X between
+ successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)].
+ Gaps are then rounded to the nearest integer in [a, b].
+
+ Parameters
+ ----------
+ seq : Sequence
+ The input sequence in which to place breakpoints.
+ a : int
+ Minimum gap (>= 1).
+ b : int
+ Maximum gap (>= a).
+
+ Returns
+ -------
+ List[int]
+ Sorted list of breakpoint indices (0 < idx < len(seq)).
+ """
+ if a < 1 or b < a:
+ raise ValueError("Require 1 <= a <= b")
+ n = len(seq)
+ breakpoints: List[int] = []
+ pos = 0
+
+ while True:
+ # sample U ~ Uniform(log2(a), log2(b))
+ u = random.uniform(math.log2(a), math.log2(b))
+ # map back to X = 2^U, then round to nearest integer
+ x = 2**u
+ gap = int(math.floor(x + 0.5))
+
+ # enforce integer bounds exactly
+ gap = max(a, min(b, gap))
+
+ pos += gap
+ if pos >= n:
+ if gap:
+ breakpoints.append(n - sum(breakpoints))
+ break
+ if gap:
+ breakpoints.append(gap)
+ else:
+ breakpoints.append(pos)
+
+ return breakpoints
+
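+
+# Behaviour sketch of `random_log_breakpoints` (comment only; outputs are random, the
+# numbers are illustrative). With gap=False the function returns cumulative breakpoint
+# positions strictly inside the sequence; with gap=True it returns the gap lengths,
+# which sum to len(seq).
+#
+#     seq = list(range(500))
+#     positions = random_log_breakpoints(seq, 8, 64)        # e.g. [23, 61, 102, ...]
+#     gaps = random_log_breakpoints(seq, 8, 64, gap=True)   # sum(gaps) == 500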
+
+
+def standalone_collate_latent(
+ batch,
+ tokenizer,
+ abs_style_intensity: bool = False,
+ random_text_segment: bool = False,
+ bos: bool = True,
+ eos: bool = True,
+ num_stop_tokens: int = 1,
+):
+ audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch])
+ audio_latent = [x.squeeze() for x in audio_latent]
+ text_pp = []
+ for t in text:
+ if bos:
+ t = "[BOS]" + t
+ if eos:
+ t = t + "[EOS]"
+ text_pp.append(t)
+ text_token = [torch.LongTensor(tokenizer.encode(x)) for x in text_pp]
+ xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
+
+ stop_token = []
+ text_stop_token = []
+ for x, y in zip(xlen, ylen):
+ tst = torch.zeros(x)
+ st = torch.zeros(y)
+ st_idx = random.randint(1, num_stop_tokens)
+ st[-1] = st_idx
+ tst[-1] = st_idx
+ stop_token.append(st)
+ text_stop_token.append(tst)
+ stop_token = pad_sequence(stop_token, batch_first=True).long()
+ text_stop_token = pad_sequence(text_stop_token, batch_first=True).long()
+
+ x_mask, y_mask = map(
+ lambda x: sequence_mask(x, device="cpu"),
+ (torch.tensor(xlen), torch.tensor(ylen)),
+ )
+
+ text_rel_pos = None
+
+ if random_text_segment:
+ breakpoints = [random_log_breakpoints(t, 32, 256, gap=True) for t in text_token]
+ encoder_mask = pad_2d_sequence([packmask_2d(b, b) for b in breakpoints])
+ text_rel_pos = [torch.cat([torch.arange(bb) for bb in b]) for b in breakpoints]
+ text_rel_pos = pad_sequence(text_rel_pos, batch_first=True)
+ else:
+ encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
+ crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
+
+ audio_latent, text_token = map(
+ lambda x: pad_sequence(x, batch_first=True, padding_value=0.0),
+ (audio_latent, text_token),
+ )
+
+ if abs_style_intensity:
+ abs_style_intensity = [x["abs_style_intensity"] for x in batch]
+ abs_style_intensity = [
+ torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
+ ]
+ abs_style_intensity = torch.stack(abs_style_intensity)
+ else:
+ abs_style_intensity = None
+
+ return {
+ "text_token": text_token,
+ "audio_token": audio_latent,
+ "crossatt_mask": crossatt_mask,
+ "encoder_mask": encoder_mask,
+ "y_mask": y_mask,
+ "stop_token": stop_token,
+ "text_stop_token": text_stop_token,
+ "x_len": xlen,
+ "y_len": ylen,
+ "abs_style_intensity": abs_style_intensity,
+ "text_rel_pos": text_rel_pos,
+ }
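+
+
+# Shape summary for the returned batch (comment only, default configuration, assuming
+# sequence_mask yields a (B, max_len) boolean mask); B = batch size, Tx = padded text
+# length, Ty = padded number of latent frames:
+#     text_token    (B, Tx)          audio_token    (B, Ty, latent_dim)
+#     encoder_mask  (B, Tx, Tx)      crossatt_mask  (B, Ty, Tx)
+#     y_mask        (B, Ty)          stop_token     (B, Ty)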
+
+
+def standalone_collate_latent_segments(
+ batch,
+ tokenizer,
+ abs_style_intensity: bool = False,
+ merge_endpoints: bool = True,
+ block_crossatt_mask: bool = True,
+ block_crossatt_past_tokens: int = 0,
+ block_crossatt_future_tokens: int = 0,
+ alternate_crossatt_pos: bool = False,
+ alternate_crossatt_shift: int = 1000,
+ eos: bool = True,
+ bos: bool = True,
+):
+ audio_latent, text, token_duration = zip(
+ *[(x["audio_latent"], x["text"], x["token_duration"]) for x in batch]
+ )
+ text_pp = []
+ for t in text:
+ if bos:
+ t = "[BOS]" + t
+ if eos:
+ t = t + "[EOS]"
+ text_pp.append(t)
+
+ if merge_endpoints:
+ tokens = [tokenizer.encode(x) for x in text]
+ new_td = []
+ for td in token_duration:
+ begin, end = td[0], td[-1]
+ tdd = td[1:-1]
+ tdd[0] += begin
+ tdd[-1] += end
+ new_td.append(tdd)
+ token_duration = new_td
+ else:
+ tokens = [tokenizer.encode(x) for x in text_pp]
+ segments = [
+ random_segments_from_text_and_durations(t, td.tolist())
+ for t, td in zip(tokens, token_duration)
+ ]
+    bos_ids, eos_ids = map(tokenizer.encode, ("[BOS]", "[EOS]"))
+ audio_segments = []
+ text_segments = []
+ audio_segments_len = []
+ text_segments_len = []
+
+ for aud, seg in zip(audio_latent, segments):
+ tt, at, tt_l, at_l = [], [], [], []
+
+ for i, s in enumerate(seg):
+ ttoken = s["text_token"]
+            if bos:
+                ttoken = bos_ids + ttoken
+            if eos:
+                ttoken = ttoken + eos_ids
+ tt.append(ttoken)
+ a_s = aud[:, s["start"] : s["end"]]
+ at.append(a_s)
+ at_l.append(a_s.shape[1])
+ tt_l.append(len(ttoken))
+
+ audio_segments.append(at)
+ text_segments.append(tt)
+ audio_segments_len.append(at_l)
+ text_segments_len.append(tt_l)
+
+ text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments]
+ audio_latent = [torch.cat(a_ss, dim=1).squeeze(0) for a_ss in audio_segments]
+
+ xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
+ x_mask, y_mask = map(
+ lambda x: sequence_mask(x, device="cpu"),
+ (torch.tensor(xlen), torch.tensor(ylen)),
+ )
+
+ audio_latent, text_token = map(
+ lambda x: pad_sequence(x, batch_first=True, padding_value=0),
+ (audio_latent, text_token),
+ )
+
+ encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
+ if block_crossatt_mask:
+ crossatt_mask = [
+ audio_to_text_partial_neighbor_mask(
+ x,
+ y,
+ past_tokens=block_crossatt_past_tokens,
+ future_tokens=block_crossatt_future_tokens,
+ )
+ for x, y in zip(text_segments_len, audio_segments_len)
+ ]
+ crossatt_mask = pad_2d_sequence(crossatt_mask)
+        # NOTE: computed but not currently applied to crossatt_mask
+        pad_mask = rearrange(torch.arange(max(ylen)), "n -> 1 n 1") >= rearrange(
+            torch.tensor(ylen), "n -> n 1 1"
+        )
+ else:
+ crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
+
+ text_rel_pos = pad_sequence(
+ [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len],
+ batch_first=True,
+ )
+
+ crossatt_rel_pos = None
+ if alternate_crossatt_pos:
+ crossatt_rel_pos = []
+ for tsl in text_segments_len:
+ rel_pos = []
+ random_shift = int(random.random() < 0.5)
+ for i, x in enumerate(tsl):
+ rel_pos.append(
+ torch.arange(x)
+ + ((random_shift + i) % 2) * alternate_crossatt_shift
+ )
+ crossatt_rel_pos.append(torch.cat(rel_pos))
+ crossatt_rel_pos = pad_sequence(crossatt_rel_pos, batch_first=True)
+
+ audio_rel_pos = pad_sequence(
+ [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len],
+ batch_first=True,
+ )
+
+ stop_token = []
+ for asl in audio_segments_len:
+ sts = []
+ for x in asl:
+ st = torch.zeros(x)
+ st[-1] = 1
+ sts.append(st)
+ stop_token.append(torch.cat(sts))
+ stop_token = pad_sequence(stop_token, batch_first=True).int()
+
+ text_stop_token = []
+ for asl in text_segments_len:
+ sts = []
+ for x in asl:
+ st = torch.zeros(x)
+ st[-1] = 1
+ sts.append(st)
+ text_stop_token.append(torch.cat(sts))
+ text_stop_token = pad_sequence(text_stop_token, batch_first=True).int()
+
+ if abs_style_intensity:
+ abs_style_intensity = [x["abs_style_intensity"] for x in batch]
+ abs_style_intensity = [
+ torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
+ ]
+ abs_style_intensity = torch.stack(abs_style_intensity)
+ else:
+ abs_style_intensity = None
+
+ return {
+ "text_token": text_token,
+ "audio_token": audio_latent,
+ "crossatt_mask": crossatt_mask,
+ "encoder_mask": encoder_mask,
+ "y_mask": y_mask,
+ "stop_token": stop_token,
+ "text_stop_token": text_stop_token,
+ "x_mask": x_mask,
+ "x_len": xlen,
+ "y_len": ylen,
+ "abs_style_intensity": abs_style_intensity,
+ "text_rel_pos": text_rel_pos,
+ "crossatt_rel_pos": crossatt_rel_pos,
+ "audio_rel_pos": audio_rel_pos,
+ "segments": segments,
+ }
+
+
+
+def random_segments_from_text_and_durations(
+ text,
+ dur,
+ low_bnd: int = 8,
+ up_bnd: int = 384,
+):
+ b = random_log_breakpoints(text, low_bnd, up_bnd)
+ bounds = [0] + b + [len(text)]
+ segs, durs = [], []
+ for a, b in zip(bounds[:-1], bounds[1:]):
+ segs.append(text[a:b])
+ durs.append(sum(dur[a:b]))
+ bounds = [0] + list(accumulate(durs, int.__add__))
+ segs_dicts = []
+ for t, s, e in zip(segs, bounds[:-1], bounds[1:]):
+ segs_dicts.append(
+ {
+ "start": s,
+ "end": e,
+ "text_token": t,
+ }
+ )
+ segs_dicts[-1]["end"] += 1
+ return segs_dicts
+
+
+class LinaDataModule(ptl.LightningDataModule):
+ def __init__(
+ self,
+ path: str | DatasetFactory,
+ quant_layer: list[int],
+ train_batch_size: int = 8,
+ token_by_batch: int | None = None,
+ n_buckets=5,
+ codec_rate_hz: int = 75,
+ num_workers: int = 8,
+ test_size: int = 2000,
+ val_batch_size: int = 8,
+ seed: int = 123,
+ train_test_seed: int = 123,
+ segments: bool = False,
+        segments_args: SegmentsCollateArgs | None = None,
+        collate_args: CollateArgs | None = None,
+ block_mask_segments: bool = False,
+ tokenizer_file=None,
+ trail_end_frame: int | None = None,
+ split="train",
+ add_columns: str | list[str] | None = None,
+ add_text_tokens: list[str] | None = None,
+ type: str = "latent",
+ ):
+ super().__init__()
+
+ self.path = path
+ self.codec_rate_hz = codec_rate_hz
+ self.num_workers = num_workers
+ self.quant_layer = quant_layer
+ self.seed = seed
+ self.segments = segments
+        self.segments_args = (
+            segments_args if segments_args is not None else SegmentsCollateArgs()
+        )
+        self.collate_args = (
+            collate_args if collate_args is not None else CollateArgs()
+        )
+ self.train_test_seed = train_test_seed
+ self.test_size = test_size
+ self.val_batch_size = val_batch_size
+ self.train_batch_size = train_batch_size
+ self.split = split
+ self.trail_end_frame = trail_end_frame
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+ if add_text_tokens:
+ self.tokenizer.add_tokens(add_text_tokens)
+ self.add_columns = add_columns
+ self.n_buckets = n_buckets
+ self.token_by_batch = token_by_batch
+ self.type = type
+
+ def setup(self, stage):
+ if isinstance(self.path, DatasetFactory):
+ self.dataset = self.path.build()
+ else:
+ self.dataset = load_dataset(self.path)
+
+ split = self.split
+ columns = [
+ "audio_latent" if self.type == "latent" else "audio_token",
+ "text",
+ "audio_duration",
+ ]
+ if self.add_columns is not None:
+ if type(self.add_columns) is str:
+ self.add_columns = [self.add_columns]
+ columns += self.add_columns
+
+ if self.segments:
+ columns += ["token_duration"]
+ self.collate_fn = lambda x: segments_collate(x, self.tokenizer)
+ else:
+ self.collate_fn = lambda x: standalone_collate(
+ x, self.tokenizer, abs_style_intensity="abs_style_intensity" in columns
+ )
+ self.dataset = (
+ self.dataset[split]
+ .train_test_split(test_size=self.test_size, seed=self.train_test_seed)
+ .select_columns(columns)
+ )
+
+ if self.type == "latent":
+ if self.segments:
+ self.collate_fn = lambda x: standalone_collate_latent_segments(
+ x,
+ self.tokenizer,
+                    **asdict(self.segments_args),
+ )
+ else:
+ self.collate_fn = lambda x: standalone_collate_latent(
+ x,
+ self.tokenizer,
+                    **asdict(self.collate_args),
+ )
+
+ def get_buckets_by_quantile(duration, n_quantile, is_sorted=False):
+ if is_sorted:
+ size = len(duration)
+ bucket_size = size // n_quantile
+ buckets = [
+ list(range(i, min(i + bucket_size, size)))
+ for i in range(0, size, bucket_size)
+ ]
+
+ else:
+ idxdur = list(enumerate(duration))
+ idxdur.sort(key=lambda x: x[1])
+ idx, dur = zip(*idxdur)
+ bucket_size = len(idx) // n_quantile
+ buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)]
+ return buckets
+
+ if self.token_by_batch is not None:
+ train_buckets = get_buckets_by_quantile(
+ self.dataset["train"]["audio_duration"], self.n_buckets
+ )
+ max_audio_durations = [
+ self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets
+ ]
+
+ batch_sizes = [
+ int(self.token_by_batch // (self.codec_rate_hz * ad))
+ for ad in max_audio_durations
+ ]
+ self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes)
+
+ def train_dataloader(self):
+ if self.token_by_batch is not None:
+ return DataLoader(
+ self.dataset["train"].with_format("torch"),
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ batch_sampler=self.train_batch_sampler,
+ )
+ else:
+ return DataLoader(
+ self.dataset["train"].with_format("torch"),
+ num_workers=self.num_workers,
+ batch_size=self.train_batch_size,
+ collate_fn=self.collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.dataset["test"].with_format("torch"),
+ batch_size=self.val_batch_size,
+ num_workers=0,
+ collate_fn=self.collate_fn,
+ )
diff --git a/tts/layers/__init__.py b/tts/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tts/layers/__pycache__/__init__.cpython-312.pyc b/tts/layers/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f6997e2139ae796bce914b3e9c64e2f8e5e1c853
Binary files /dev/null and b/tts/layers/__pycache__/__init__.cpython-312.pyc differ
diff --git a/tts/layers/__pycache__/attention.cpython-312.pyc b/tts/layers/__pycache__/attention.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f67df057340e57c0a3b11bdd012c5d92b7d284e4
Binary files /dev/null and b/tts/layers/__pycache__/attention.cpython-312.pyc differ
diff --git a/tts/layers/__pycache__/conv.cpython-312.pyc b/tts/layers/__pycache__/conv.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9176fc8967ee994b3e71cb4cc16f505a4db21d37
Binary files /dev/null and b/tts/layers/__pycache__/conv.cpython-312.pyc differ
diff --git a/tts/layers/__pycache__/ffn.cpython-312.pyc b/tts/layers/__pycache__/ffn.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e122609782808acc10f04294cf8fd069fb9ef186
Binary files /dev/null and b/tts/layers/__pycache__/ffn.cpython-312.pyc differ
diff --git a/tts/layers/attention.py b/tts/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac758ba75c62bdf923d2450988ee39083ed2699a
--- /dev/null
+++ b/tts/layers/attention.py
@@ -0,0 +1,477 @@
+import math
+import os
+import time
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce, repeat
+from fla.models.utils import Cache
+from torch import nn
+
+
+def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor:
+ B, H, Q, KV = mask.shape
+ device = mask.device
+
+ q_idx = torch.arange(Q, device=device).unsqueeze(1) # (Q, 1)
+ k_idx = torch.arange(KV, device=device).unsqueeze(0) # (1, KV)
+
+ lower_bound = q_idx - (window_size - 1) # (Q, 1), may be negative
+ allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound) # (Q, KV), dtype=torch.bool
+
+ allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV)
+
+ orig_dtype = mask.dtype
+ if mask.dtype != torch.bool:
+ mask_bool = mask.to(torch.bool)
+ else:
+ mask_bool = mask
+
+ new_mask = mask_bool & allowed_4d
+
+ if orig_dtype != torch.bool:
+ return new_mask.to(orig_dtype)
+ else:
+ return new_mask
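+
+
+# Illustration (comment only): with window_size=2 and Q = KV = 4, each query position i
+# may attend only to keys i-1 and i, so the allowed pattern per (batch, head) is
+#
+#     [[1, 0, 0, 0],
+#      [1, 1, 0, 0],
+#      [0, 1, 1, 0],
+#      [0, 0, 1, 1]]
+#
+# and any position already disallowed by the input mask stays disallowed.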
+
+
+def precompute_freqs_cis_(
+ t: torch.Tensor,
+ n_elem: int,
+ base: float = 10000,
+) -> torch.Tensor:
+ freqs = 1.0 / (
+ base
+ ** (
+ torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float()
+ / n_elem
+ )
+ )
+ freqs = torch.outer(t, freqs)
+ cache = repeat(freqs, "... d -> ... (d 2)")
+ return cache
+
+
+def precompute_freqs_cis(
+ t: torch.Tensor, # shape: (B, T) or (T,)
+ n_elem: int,
+ base: float = 10000,
+) -> torch.Tensor:
+ """
+ Batched version of precompute_freqs_cis.
+
+ Args:
+ t: torch.Tensor, shape (B, T) or (T,)
+ Timesteps to compute frequencies for.
+ n_elem: int
+ Embedding dimension (must be even).
+ base: float
+ Base for frequency computation (default: 10000).
+
+ Returns:
+ cache: torch.Tensor, shape (B, T, n_elem) if batched,
+ (T, n_elem) if unbatched.
+ """
+ if t.dim() == 1: # unbatched
+ t = t.unsqueeze(0) # (1, T)
+
+ B, T = t.shape
+ device = t.device
+
+ # frequencies (half dimension, then expand back)
+ freqs = 1.0 / (
+ base
+ ** (torch.arange(0, n_elem, 2, device=device)[: (n_elem // 2)].float() / n_elem)
+ ) # shape: (n_elem // 2,)
+
+ # outer product for each batch
+ # (B, T, n_elem//2)
+ freqs = torch.einsum("bt,d->btd", t, freqs)
+
+ # duplicate last dim to interleave sin/cos pairs
+ # (B, T, n_elem)
+ cache = repeat(freqs, "... d -> ... (d 2)")
+
+ # if cache.shape[0] == 1: # if originally unbatched
+ # cache = cache.squeeze(0) # (T, n_elem)
+
+ return cache
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin()
+ return out
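+
+
+# Shape sketch (comment only): with queries/keys laid out as (batch, heads, seq, dim),
+# the rotary cache from `precompute_freqs_cis` has shape (batch, seq, dim) and must be
+# broadcast over the head axis, e.g.
+#
+#     q = torch.randn(2, 8, 100, 64)
+#     freqs = precompute_freqs_cis(torch.arange(100).float().expand(2, -1), 64)
+#     q_rot = apply_rotary_emb(q, freqs.unsqueeze(1))   # (2, 8, 100, 64)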
+
+
+def scaled_dot_product_attention(query, key, value, mask=None):
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ if mask is not None:
+ attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max)
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ return attn_weight @ value, attn_weight
+
+
+class SelfAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ is_causal: bool = False,
+ sliding_window: int | None = None,
+ ):
+ super().__init__()
+ self.qkv = nn.Linear(dim, 3 * dim)
+ assert dim % num_heads == 0
+ self.heads = num_heads
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ self.output_proj = nn.Linear(dim, dim)
+ self.sliding_window = sliding_window
+ if self.sliding_window is not None:
+ self.is_causal = False
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ ):
+ B, T, D = x.shape
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
+ q, k, v = map(
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+ )
+
+ if freqs is not None:
+ q = apply_rotary_emb(q, freqs)
+ k = apply_rotary_emb(k, freqs)
+ if cache is not None:
+ cache.update(attn_state=(k, v), layer_idx=self.layer_idx, offset=T)
+ k, v = cache[self.layer_idx]["attn_state"]
+
+        if self.sliding_window is not None:
+            # build a boolean mask so SDPA treats it as allow/deny rather than an additive
+            # bias, sized by the full key length so it also works with a populated cache
+            kv_len = k.shape[-2]
+            mask = torch.ones(B, 1, kv_len, kv_len, device=x.device, dtype=torch.bool)
+            mask = apply_causal_sliding_window(mask, self.sliding_window)
+            mask = mask[:, :, -T:]
+
+
+ y = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1
+ )
+ y = rearrange(y, "b h n d -> b n (h d)")
+ y = self.output_proj(y)
+ return y
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int | None = None,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ assert dim % num_heads == 0
+ self.pre_norm_q = nn.LayerNorm(dim)
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.layer_idx = layer_idx
+ self.heads = num_heads
+ self.dropout_att = dropout
+
+    def _prepare_kv(self, text_hidden_states: torch.Tensor):
+        # this class has no k/v layer norms; project the text states directly
+        k = self.k(text_hidden_states)
+        v = self.v(text_hidden_states)
+        return {"k": k, "v": v}
+
+    def _query(self, x):
+        return self.q(self.pre_norm_q(x))
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor | None = None,
+ v: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ output_attention: bool = False,
+ cache: Cache | None = None,
+ **kwargs,
+ ):
+ if v is None:
+ v = k
+ q = self.q(self.pre_norm_q(q))
+ if cache is not None:
+ if cache[self.layer_idx] is not None:
+ ca_state = cache[self.layer_idx]["crossatt_state"]
+ if ca_state is not None:
+ k, v = ca_state
+ else:
+ v = self.v(v)
+ k = self.k(k)
+ cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx)
+ else:
+ v = self.v(v)
+ k = self.k(k)
+ q, k, v = map(
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+ )
+
+ if mask is not None:
+ if mask.ndim == 3:
+ mask = mask[:, None]
+
+        if not self.training:
+ x, att = scaled_dot_product_attention(q, k, v, mask=mask)
+ else:
+ x = nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, dropout_p=self.dropout_att
+ )
+ att = None
+ x = rearrange(x, "b h n d -> b n (h d)")
+ if att is not None:
+ if cache is not None:
+ cache.update(crossatt_weights=att, layer_idx=self.layer_idx)
+ else:
+ self.att = att
+ return x
+
+
+class ConvPos(nn.Module):
+ def __init__(self, dim, max_seq_len=1000, kernel_size=7, n_parallel_codebook=2):
+ super().__init__()
+ self.embed = nn.Embedding(max_seq_len * n_parallel_codebook, dim)
+ self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same")
+ self.max_seq_len = max_seq_len
+ self.n_parallel_codebook = n_parallel_codebook
+
+ def forward(self, x, left_shift=0, random_shift=False):
+ # left_pad = 31 if left_shift > 0 else 0
+ # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0),x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0)
+ if random_shift:
+ bias = torch.randint(
+ 0,
+ self.n_parallel_codebook,
+ (x.shape[0],),
+ device=x.device,
+ )
+ x = x + bias * self.max_seq_len
+ y = self.embed(x)
+ y = rearrange(y, "b n c -> b c n")
+ y = self.dw_conv(y)
+ y = rearrange(y, "b c n -> b n c") # [:,left_pad:-31]
+ return y
+
+
+class SinPos(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+    def forward(self, x, left_shift=0, random_shift=False):
+        # extra arguments are accepted for interface parity with ConvPos and ignored here
+ exp = torch.arange(self.dim // 2, device=x.device)
+ exp = 2 * exp / (self.dim)
+ exp = rearrange(exp, "e -> 1 1 e")
+ x = rearrange(x, "b p -> b p 1")
+ pos = x * torch.pow(10000, -exp)
+ pos = torch.cat((pos, pos + math.pi / 2), dim=2)
+ pos = torch.sin(pos)
+ return pos
+
+
+class BlindCrossAttention(nn.Module):
+ def __init__(
+ self,
+ q_dim,
+ k_dim,
+ att_dim,
+ pos_net,
+ dropout=0.1,
+ pos_dim=64,
+ pos_type="sinusoidal",
+ layer_idx: int | None = None,
+ ):
+ super().__init__()
+ self.q = nn.Linear(q_dim, att_dim)
+ self.k = nn.Linear(k_dim, att_dim)
+ self.v = nn.Linear(k_dim, att_dim)
+ self.pos_net = pos_net
+ if pos_type == "sinusoidal":
+ self.pos_embed = SinPos(pos_dim)
+ elif pos_type == "convolutional":
+ self.pos_embed = ConvPos(pos_dim)
+ self.ln_q = nn.LayerNorm(att_dim)
+ self.ln_k = nn.LayerNorm(att_dim)
+ self.ln_v = nn.LayerNorm(att_dim)
+ self.dropout_att = nn.Dropout(dropout)
+ self.layer_idx = layer_idx
+
+ def _prepare_kv(self, text_hidden_states: torch.Tensor):
+ v = self.ln_v(self.v(text_hidden_states))
+ k = self.ln_k(self.k(text_hidden_states))
+ b, h, j, d = k.shape
+ pos = torch.arange(j, device=k.device).unsqueeze(0)
+ pos_emb = self.pos_embed(pos)
+ return {"k": k, "v": v, "pos_emb": pos_emb}
+
+ def _query(self, x):
+ return self.ln_q(self.q(x))
+
+ def forward(
+ self,
+ q,
+ k,
+ kv_cached=None,
+ mask=None,
+ time_step=None,
+ pos=None,
+ left_shift=0,
+ past_key_values=None,
+ cache=None,
+ **kwargs,
+ ):
+ q = self.ln_q(self.q(q))
+ # if kv_cached is None:
+ # v = self.ln_v(self.v(k))
+ # k = self.ln_k(self.k(k))
+ # else:
+ # k, v = kv_cached
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+
+ if cache is not None:
+ if cache[self.layer_idx] is not None:
+ ca_state = cache[self.layer_idx]["crossatt_state"]
+ if ca_state is not None:
+ k, v, pos_emb = ca_state
+ else:
+ # v = self.v(v)
+ # k = self.k(k)
+ v = self.ln_v(self.v(k))
+ k = self.ln_k(self.k(k))
+ pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
+ pos_emb = self.pos_embed(pos, left_shift=left_shift)
+ cache.update(
+ crossatt_state=(k, v, pos_emb), layer_idx=self.layer_idx
+ )
+ else:
+ v = self.ln_v(self.v(k))
+ k = self.ln_k(self.k(k))
+ if pos is None:
+ pos = torch.arange(k.shape[-2], device=k.device).unsqueeze(0)
+ pos_emb = self.pos_embed(pos, left_shift=left_shift)
+
+ q, k, v = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k, v))
+ b, h, j, d = k.shape
+ if self.training:
+ sdpa = lambda q, k, pos: (
+ nn.functional.scaled_dot_product_attention(
+ q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
+ ),
+ None,
+ )
+ else:
+ sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
+
+ x, att1 = sdpa(q, k, pos_emb.unsqueeze(1))
+ x = rearrange(x, "b 1 n d -> b n d")
+ x = self.pos_net(x, cache=cache)
+ x = rearrange(x, "b n d -> b 1 n d")
+ pos_emb = rearrange(pos_emb, "b n d -> b 1 n d")
+ x, att2 = sdpa(x, pos_emb, v)
+ x = rearrange(x, "b 1 n d -> b n d")
+
+ self.att1 = att1
+ self.att2 = att2
+ if att2 is not None:
+ if cache is not None:
+ cache.update(
+ crossatt_weights=torch.cat((att1, att2), dim=1),
+ layer_idx=self.layer_idx,
+ )
+ return x
+
+
+class ListenReadCrossAttention(nn.Module):
+ def __init__(
+ self,
+ q_dim: int,
+ k_dim: int,
+ att_dim: int,
+ crossatt_type: Literal["listen", "read"],
+ num_heads: int = 1,
+ dropout: float = 0.1,
+ layer_idx: int | None = None,
+ ):
+ super().__init__()
+ self.q = nn.Linear(q_dim, att_dim)
+ self.k = nn.Linear(k_dim, att_dim)
+ self.ln_q = nn.LayerNorm(att_dim)
+ self.ln_k = nn.LayerNorm(att_dim)
+ self.dropout_att = nn.Dropout(dropout)
+ self.crossatt_type = crossatt_type
+ self.layer_idx = layer_idx
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ text_freqs: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ past_key_values=None,
+ cache=None,
+ **kwargs,
+ ):
+ q = self.ln_q(self.q(q))
+ k = self.ln_k(self.k(k))
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+
+ q, k = map(lambda x: rearrange(x, "b n d -> b 1 n d"), (q, k))
+ if self.training:
+ sdpa = lambda q, k, pos: (
+ nn.functional.scaled_dot_product_attention(
+ q, k, pos, attn_mask=mask, dropout_p=self.dropout_att.p
+ ),
+ None,
+ )
+ else:
+ sdpa = lambda q, k, pos: scaled_dot_product_attention(q, k, pos, mask=mask)
+
+
+ text_freqs = rearrange(text_freqs, "b n d -> b 1 n d")
+ if self.crossatt_type == "listen":
+ x, att = sdpa(q, k, text_freqs)
+ elif self.crossatt_type == "read":
+ x, att = sdpa(q, text_freqs, k)
+ else:
+ raise ValueError
+
+ x = rearrange(x, "b 1 n d -> b n d")
+ if att is not None:
+ if cache is not None:
+ cache.update(
+ crossatt_weights=att,
+ layer_idx=self.layer_idx,
+ )
+
+ self.att = att
+ return x
diff --git a/tts/layers/conv.py b/tts/layers/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ae2783a7b404dd56d2722a862a3ed549bdddc4
--- /dev/null
+++ b/tts/layers/conv.py
@@ -0,0 +1,56 @@
+import torch
+from torch import nn
+
+
+class ConvNeXtBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int | None = None,
+ layer_scale_init_value: float = 0.0,
+ elementwise_affine_ln: bool = True,
+ kernel_size: int = 5,
+ ):
+ super().__init__()
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
+ )
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+ if layer_scale_init_value > 0
+ else None
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
+ gate: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
+ x = self.norm(x)
+ if scale_shift is not None:
+ scale, shift = scale_shift
+ x = x * scale[:, None] + shift[:, None]
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ if gate is not None:
+ x = gate[:, None] * x
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
+
+ x = residual + x
+ return x
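+
+
+if __name__ == "__main__":
+    # Minimal shape check (illustrative): the block consumes and returns channel-first
+    # tensors; LayerNorm and the pointwise MLP run in channel-last layout internally.
+    block = ConvNeXtBlock(dim=256)
+    x = torch.randn(4, 256, 100)  # (batch, channels, frames)
+    assert block(x).shape == (4, 256, 100)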
diff --git a/tts/layers/ffn.py b/tts/layers/ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5133855ba99e3b0be9257c3383edbceef49514fd
--- /dev/null
+++ b/tts/layers/ffn.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
+ super().__init__()
+ self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
+ self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)
+
+ def forward(self, x):
+ gate, x = self.p_in(x).chunk(2, dim=-1)
+ return self.p_out(nn.functional.silu(gate) * x)
+
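+
+# Dimension sketch (comment only): with d_model=768 and the default expansion factor of
+# 4, the gated hidden size is 768 * 4 // 3 = 1024, so `p_in` maps 768 -> 2048 (gate and
+# value halves) and `p_out` maps 1024 -> 768.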
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class GaussianFourierTimeEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x[:, None] * self.weight[None, :] * 2 * torch.pi
+ x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
+ return x
+
+
+class AdaLNFinalLayer(nn.Module):
+ def __init__(self, hidden_dim, feature_dim):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class AdaLNMLP(nn.Module):
+ def __init__(self, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+
+ self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
+ self.mlp = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
+ )
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
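+
+
+if __name__ == "__main__":
+    # Minimal shape check (illustrative): AdaLN blocks expect hidden states and a
+    # conditioning vector with the same hidden size; the condition broadcasts over time.
+    mlp = AdaLNMLP(hidden_dim=512)
+    x = torch.randn(2, 50, 512)  # hidden states
+    c = torch.randn(2, 1, 512)   # conditioning, broadcast over the time axis
+    assert mlp(x, c).shape == (2, 50, 512)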
diff --git a/tts/model/__init__.py b/tts/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c6330d7c704d41d9874637cd36ca123a8e51db
--- /dev/null
+++ b/tts/model/__init__.py
@@ -0,0 +1,2 @@
+from .simple_gla import SimpleGLADecoder
+from .transformer import TransformerDecoder, TransformerEncoder
diff --git a/tts/model/__pycache__/__init__.cpython-312.pyc b/tts/model/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a861e40669415a0110543fbf29f020337f36e4b2
Binary files /dev/null and b/tts/model/__pycache__/__init__.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/cache_utils.cpython-312.pyc b/tts/model/__pycache__/cache_utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..679e727bc5bc04b093d39d1dc1d7bd7e83de130c
Binary files /dev/null and b/tts/model/__pycache__/cache_utils.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/config.cpython-312.pyc b/tts/model/__pycache__/config.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58a395d83c439f80acfb01e4894c7cd9ca0b2218
Binary files /dev/null and b/tts/model/__pycache__/config.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/prediction_head.cpython-312.pyc b/tts/model/__pycache__/prediction_head.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48399463364d0eb4310728ce97d40a4e7d53e442
Binary files /dev/null and b/tts/model/__pycache__/prediction_head.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/registry.cpython-312.pyc b/tts/model/__pycache__/registry.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..15f02dc37ee8dc6466c66e0624bd59b1fce61f3f
Binary files /dev/null and b/tts/model/__pycache__/registry.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/shortconv.cpython-312.pyc b/tts/model/__pycache__/shortconv.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cb803d4c913d7e2f556eefd907b7bd3cee22d1d
Binary files /dev/null and b/tts/model/__pycache__/shortconv.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/simple_gla.cpython-312.pyc b/tts/model/__pycache__/simple_gla.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8ffa4c6694d8dbeea14d22c5051e914de51f2453
Binary files /dev/null and b/tts/model/__pycache__/simple_gla.cpython-312.pyc differ
diff --git a/tts/model/__pycache__/transformer.cpython-312.pyc b/tts/model/__pycache__/transformer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05056791a82cda49b2c8a660ff452a3340b30430
Binary files /dev/null and b/tts/model/__pycache__/transformer.cpython-312.pyc differ
diff --git a/tts/model/cache.py b/tts/model/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd84372fab1fe59bf02c0a75c55482084558351
--- /dev/null
+++ b/tts/model/cache.py
@@ -0,0 +1,308 @@
+from typing import Any
+
+import torch
+from transformers.cache_utils import Cache, _static_cache_update
+
+
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
+
+    Parameters:
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used. Note that a new instance must be
+            instantiated if a smaller batch size is used.
+        head_dim (`int`):
+            The dimension of each attention head.
+        num_key_value_heads (`int`):
+            The number of key/value heads to allocate buffers for.
+        num_hidden_layers (`int`):
+            The number of layers to allocate key/value buffers for.
+        max_cache_len (`int`, *optional*):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device` or `str`, *optional*):
+            The device on which the cache should be initialized. If you're using more than 1 computation
+            device, you should pass the `layer_device_map` argument instead.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the cache buffers.
+        layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*):
+            Mapping between the layers and their devices. This is required when you are manually
+            initializing the cache and the model is split between different gpus. You can know which
+            layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
+
+
+    Example:
+
+    ```python
+    >>> # illustrative sizes; allocate room for up to 1024 positions per layer and pass
+    >>> # the cache to the decoder's forward
+    >>> past_key_values = StaticCache(
+    ...     max_batch_size=1,
+    ...     head_dim=64,
+    ...     num_key_value_heads=8,
+    ...     num_hidden_layers=12,
+    ...     max_cache_len=1024,
+    ...     device="cuda",
+    ...     dtype=torch.float16,
+    ... )
+    ```
+ """
+
+ is_compileable = True
+
+ def __init__(
+ self,
+ max_batch_size: int,
+ head_dim: int,
+ num_key_value_heads: int,
+ num_hidden_layers: int,
+ max_cache_len: int | None = None,
+ device: torch.device | str | None = None,
+ dtype: torch.dtype = torch.float32,
+ layer_device_map: dict[int, str | torch.device | int] | None = None,
+ ) -> None:
+ super().__init__()
+ self.max_batch_size = max_batch_size
+ self.max_cache_len = max_cache_len
+
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = head_dim
+ self._dtype = dtype
+ self.num_key_value_heads = num_key_value_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.key_cache: list[torch.Tensor] = []
+ self.value_cache: list[torch.Tensor] = []
+ # Note: There will be significant perf decrease if switching to use 5D tensors instead.
+ cache_shape = (
+ self.max_batch_size,
+ self.num_key_value_heads,
+ self.max_cache_len,
+ self.head_dim,
+ )
+ device = torch.device(device) if device is not None else None
+ for idx in range(self.num_hidden_layers):
+ if layer_device_map is not None:
+ layer_device = layer_device_map[idx]
+ else:
+ layer_device = device
+ new_layer_key_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device
+ )
+ new_layer_value_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device
+ )
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+ # preventing compiled graph breaks when updating the cache.
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: dict[str, Any] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ if cache_kwargs is None:
+ cache_kwargs = {}
+
+ key_states = key_states.to(self.key_cache[layer_idx].dtype)
+ value_states = value_states.to(self.value_cache[layer_idx].dtype)
+ return _static_cache_update(
+ self.key_cache[layer_idx],
+ self.value_cache[layer_idx],
+ key_states,
+ value_states,
+ cache_kwargs.get("cache_position"),
+ )
+
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
+ """Returns the sequence length of the cached states that were seen by the model."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+ def get_max_cache_shape(self) -> int | None:
+ return self.max_cache_len
+
+ def reset(self):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ def get_mask_sizes(
+ self, cache_position: torch.Tensor, layer_idx: int
+ ) -> tuple[int, int]:
+ """
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+ the given layer at `layer_idx`.
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+ for each layer.
+ """
+ kv_length = self.get_max_cache_shape()
+ return kv_length, 0
+
+
+class Cache:
+ """
+ A cache used for storing hidden states produced by flash linear attention models.
+
+    It stores the states of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
+ """
+
+ is_compileable = True
+
+    def __init__(self, seen_tokens: int = 0) -> None:
+ super().__init__()
+
+ self.states: list[dict[str, Any]] = []
+
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def __getitem__(self, layer_idx: int) -> dict[str, Any]:
+ if layer_idx < len(self):
+ return self.states[layer_idx]
+ else:
+ raise KeyError(
+ f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
+ )
+
+ def __iter__(self):
+ for state in self.states:
+ yield state
+
+ def __len__(self):
+ return len(self.states)
+
+ def update(
+ self,
+ recurrent_state: torch.Tensor | None = None,
+ attn_state: tuple[torch.Tensor, torch.Tensor] | None = None,
+ conv_state: tuple[torch.Tensor] | None = None,
+ ffn_state: torch.Tensor | None = None,
+ layer_idx: int = 0,
+ offset: int | None = 1,
+ cache_kwargs: dict | None = None,
+ ):
+ """
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
+
+ Args:
+ recurrent_state (`torch.Tensor`, `optional`):
+ The new recurrent state to cache.
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
+ The new attention key/value states to cache.
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
+ The new convolution state to cache.
+ layer_idx (`int`, defaults to 0):
+ The index of the layer to cache the states for.
+ offset (`int`, `optional`, defaults to 1):
+ The number of new tokens being processed.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass.
+
+ Return:
+ Dictionary of the updated state.
+ """
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += offset
+
+ if attn_state is not None:
+ input_size = attn_state[0].shape[-2]
+            window_size = (
+                cache_kwargs.get("window_size") if cache_kwargs is not None else None
+            )
+            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
+ raise ValueError(
+ "`attn_state` must be a tuple of two tensors for key/value states"
+ )
+ if len(self.states) <= layer_idx:
+ if attn_state is not None:
+ if window_size is not None and input_size > window_size:
+ attn_state = (
+ attn_state[0][..., -window_size:, :].contiguous(),
+ attn_state[1][..., -window_size:, :].contiguous(),
+ )
+ state = dict(
+ recurrent_state=recurrent_state,
+ attn_state=attn_state,
+ conv_state=conv_state,
+ ffn_state=ffn_state,
+ )
+ self.states.append(state)
+ else:
+ state = self.states[layer_idx]
+ if recurrent_state is not None:
+ state["recurrent_state"] = recurrent_state
+ if attn_state is not None:
+ key_state, value_state = state["attn_state"]
+ if window_size is not None and key_state.shape[-2] == window_size:
+ # DO NOT allocate new memory if the cache is full
+ # roll the key/value states to the left by `input_size`
+ key_state = key_state.roll(-input_size, -2)
+ value_state = value_state.roll(-input_size, -2)
+ # replace the last `input_size` tokens with the new key/value states
+ key_state[..., -input_size:, :] = attn_state[0]
+ value_state[..., -input_size:, :] = attn_state[1]
+ attn_state = (key_state, value_state)
+ else:
+ attn_state = (
+ torch.cat([key_state, attn_state[0]], -2),
+ torch.cat([value_state, attn_state[1]], -2),
+ )
+ state["attn_state"] = attn_state
+ if conv_state is not None:
+ state["conv_state"] = conv_state
+ if ffn_state is not None:
+ state["ffn_state"] = ffn_state
+
+ return state
+
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ if len(self.states) <= layer_idx:
+ return 0
+ return self._seen_tokens
+
+ def get_max_length(self) -> int | None:
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
+ return None
+
+ def to_legacy_cache(self) -> tuple:
+ return tuple(self.states)
+
+ @classmethod
+ @torch.compiler.disable
+ def from_legacy_cache(
+ cls, past_key_values: tuple | None = None, seen_tokens: int = 0
+ ) -> Cache:
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
+
+ cache = cls(seen_tokens)
+ if isinstance(past_key_values, list):
+ for layer_idx in range(len(past_key_values)):
+ cache.states.append(past_key_values[layer_idx])
+ return cache
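
The sliding-window branch of `Cache.update` avoids reallocating the key/value buffers once they reach `window_size`: it rolls the cached tensor left and overwrites the tail. A minimal standalone sketch of that step using plain tensors rather than the class itself (illustrative only, not part of the patch):

```python
import torch

# Roll-and-overwrite update once the window is full.
window_size, input_size, head_dim = 4, 1, 8
key_state = torch.arange(window_size, dtype=torch.float32)            # positions 0..3
key_state = key_state.view(1, 1, window_size, 1).expand(-1, -1, -1, head_dim).clone()
new_keys = torch.full((1, 1, input_size, head_dim), 99.0)              # the newest token

key_state = key_state.roll(-input_size, -2)                            # drop the oldest position
key_state[..., -input_size:, :] = new_keys                             # write the newest tokens
print(key_state[0, 0, :, 0])                                           # tensor([ 1.,  2.,  3., 99.])
```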
diff --git a/tts/model/cache_utils.py b/tts/model/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0313c655c1fe839a49b4a1edd6ed0a3c3922c518
--- /dev/null
+++ b/tts/model/cache_utils.py
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import transformers
+from transformers.cache_utils import DynamicCache
+
+from tts.layers.attention import precompute_freqs_cis
+
+
+class FLACache(transformers.cache_utils.Cache):
+ """
+ A cache used for storing hidden states produced by flash linear attention models.
+
+    It stores the states of each layer as a tensor of shape `[batch_size, key_dim, value_dim]`.
+ """
+
+ is_compileable = True
+
+ def __init__(
+ self,
+ head_dim: int | None = None,
+ max_seq_len: int | None = None,
+ num_states: int | None = None,
+ device: str = "cuda",
+ seen_tokens: int = 0,
+ ):
+ super().__init__()
+ if head_dim is not None and max_seq_len is not None:
+ self.freqs = precompute_freqs_cis(
+ torch.arange(max_seq_len, device=device), head_dim
+ )
+ self.states: List[Dict[str, Any]] = []
+ if num_states is not None:
+ self.states = [
+ dict(
+ recurrent_state=None,
+ attn_state=None,
+ conv_state=None,
+ short_conv_state=None,
+ ffn_state=None,
+ crossatt_state=None,
+ crossatt_weights=None,
+ )
+ for _ in range(num_states)
+ ]
+
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
+ if layer_idx < len(self):
+ return self.states[layer_idx]
+ else:
+ raise KeyError(
+ f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
+ )
+
+ def __iter__(self):
+ for state in self.states:
+ yield state
+
+ def __len__(self):
+ return len(self.states)
+
+ def update(
+ self,
+ recurrent_state: torch.Tensor = None,
+ attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
+ conv_state: Tuple[torch.Tensor] = None,
+ short_conv_state: Tuple[torch.Tensor] = None,
+ crossatt_state: Tuple[torch.Tensor] = None,
+ crossatt_weights: Tuple[torch.Tensor] = None,
+ ffn_state: torch.Tensor = None,
+ layer_idx: int = 0,
+ offset: Optional[int] = 1,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ """
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
+
+ Args:
+ recurrent_state (`torch.Tensor`, `optional`):
+ The new recurrent state to cache.
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
+ The new attention key/value states to cache.
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
+ The new convolution state to cache.
+ layer_idx (`int`, defaults to 0):
+ The index of the layer to cache the states for.
+ offset (`int`, `optional`, defaults to 1):
+ The number of new tokens being processed.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass.
+
+ Return:
+ Dictionary of the updated state.
+ """
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += offset
+
+ if attn_state is not None:
+ input_size = attn_state[0].shape[-2]
+ window_size = (
+ cache_kwargs.get("window_size", None)
+ if cache_kwargs is not None
+ else None
+ )
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
+ raise ValueError(
+ "`attn_state` must be a tuple of two tensors for key/value states"
+ )
+ if len(self.states) <= layer_idx:
+ if attn_state is not None:
+ if window_size is not None and input_size > window_size:
+ attn_state = (
+ attn_state[0][..., -window_size:, :].contiguous(),
+ attn_state[1][..., -window_size:, :].contiguous(),
+ )
+ state = dict(
+ recurrent_state=recurrent_state,
+ attn_state=attn_state,
+ conv_state=conv_state,
+ short_conv_state=short_conv_state,
+ ffn_state=ffn_state,
+ crossatt_state=crossatt_state,
+ crossatt_weights=crossatt_weights,
+ )
+ self.states.append(state)
+ else:
+ state = self.states[layer_idx]
+ if recurrent_state is not None:
+ state["recurrent_state"] = recurrent_state
+ if crossatt_state is not None:
+ state["crossatt_state"] = crossatt_state
+ if crossatt_weights is not None:
+ if state["crossatt_weights"] is not None:
+ state["crossatt_weights"] = torch.cat(
+ (state["crossatt_weights"], crossatt_weights), dim=-2
+ )
+ else:
+ state["crossatt_weights"] = crossatt_weights
+ if attn_state is not None:
+ key_state, value_state = state["attn_state"]
+ if window_size is not None and key_state.shape[-2] == window_size:
+ # DO NOT allocate new memory if the cache is full
+ # roll the key/value states to the left by `input_size`
+ key_state = key_state.roll(-input_size, -2)
+ value_state = value_state.roll(-input_size, -2)
+ # replace the last `input_size` tokens with the new key/value states
+ key_state[..., -input_size:, :] = attn_state[0]
+ value_state[..., -input_size:, :] = attn_state[1]
+ attn_state = (key_state, value_state)
+ else:
+ attn_state = (
+ torch.cat([key_state, attn_state[0]], -2),
+ torch.cat([value_state, attn_state[1]], -2),
+ )
+ state["attn_state"] = attn_state
+ if conv_state is not None:
+ state["conv_state"] = conv_state
+ if short_conv_state is not None:
+ state["short_conv_state"] = short_conv_state
+ if ffn_state is not None:
+ state["ffn_state"] = ffn_state
+
+ return state
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ if len(self.states) <= layer_idx:
+ return 0
+ return self._seen_tokens
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
+ return None
+
+ def to_legacy_cache(self) -> Tuple:
+ return tuple(self.states)
+
+ @classmethod
+ @torch.compiler.disable
+ def from_legacy_cache(
+ cls, past_key_values: Optional[Tuple] = None, seen_tokens: int = 0
+    ) -> FLACache:
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
+
+ cache = cls(seen_tokens)
+ if isinstance(past_key_values, list):
+ for layer_idx in range(len(past_key_values)):
+ cache.states.append(past_key_values[layer_idx])
+ return cache
+
+
+class TransformerDecoderCache(FLACache):
+ def __init__(
+ self,
+ head_dim: int,
+ max_seq_len: int,
+ device: str,
+ ):
+ super().__init__()
+ self.freqs = precompute_freqs_cis(
+ torch.arange(max_seq_len, device=device), head_dim
+ )
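
A minimal sketch of driving `FLACache` across a prefill step and a single decoding step (illustrative only, not part of the patch; assumes the `tts` package, including `tts.layers.attention`, is importable from the repo root):

```python
import torch

from tts.model.cache_utils import FLACache

# Prefill: register three key/value positions for layer 0 in one update.
cache = FLACache()
k = torch.randn(1, 2, 3, 16)  # (batch, heads, seq_len, head_dim)
v = torch.randn(1, 2, 3, 16)
cache.update(attn_state=(k, v), layer_idx=0, offset=k.shape[-2])
print(cache.get_seq_length())  # 3

# Decode one token: the new key/value pair is concatenated onto the cached state.
state = cache.update(
    attn_state=(torch.randn(1, 2, 1, 16), torch.randn(1, 2, 1, 16)),
    layer_idx=0,
    offset=1,
)
print(state["attn_state"][0].shape)  # torch.Size([1, 2, 4, 16])
```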
diff --git a/tts/model/config.py b/tts/model/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..34270fccb445f78d94e3fe6a6e1583babbcb3981
--- /dev/null
+++ b/tts/model/config.py
@@ -0,0 +1,106 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+
+@dataclass
+class SimpleGLADecoderConfig:
+ name: str = "simple_gla"
+ dim: int = 512
+    use_short_conv: bool = True
+ expand_k: float = 0.5
+ expand_v: float = 1.0
+ num_heads: int = 4
+ num_layers: int = 12
+ ffn_expansion_factor: int = 4
+ conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
+ blind_crossatt: bool = True
+ listen_read_crossatt: dict[int, list[Literal["listen", "read"]]] | None = None
+ crossatt_num_heads: int = 1
+ crossatt_dropout: float = 0.1
+ crossatt_layer_idx: list[int] = field(default_factory=lambda: [5])
+
+
+@dataclass
+class TransformerEncoderConfig:
+ name: str = "sa_transformer"
+ dim: int = 512
+ num_heads: int = 8
+ num_layers: int = 6
+ ffn_expansion_factor: int = 4
+
+
+@dataclass
+class ConvFormerEncoderConfig:
+ name: str = "convformer_encoder"
+ dim: int = 512
+ num_heads: int = 8
+ num_transformer_layers: int = 6
+ num_conv_layers: int = 3
+ ffn_expansion_factor: int = 4
+
+
+@dataclass
+class TransformerDecoderConfig:
+ name: str = "sa_transformer"
+ dim: int = 512
+ num_heads: int = 8
+ num_layers: int = 12
+ conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
+ ffn_expansion_factor: int = 4
+ crossatt_dropout: float = 0.1
+ crossatt_num_heads: int = 1
+ crossatt_layer_idx: list[int] = field(default_factory=lambda: [5, 6])
+
+
+DecoderConfig = TransformerDecoderConfig | SimpleGLADecoderConfig
+EncoderConfig = TransformerEncoderConfig | ConvFormerEncoderConfig
+
+
+@dataclass
+class TTSConfig:
+ dim: int = 512
+ text_vocab_size: int = 256 + 3
+ audio_vocab_size: int = 4096 + 3
+ audio_embed_size: int = 16
+ audio_input_type: Literal["continuous", "discrete"] = "continuous"
+ diffusion_head_num_layers: int = 3
+ encoder_cfg: EncoderConfig = field(default_factory=TransformerEncoderConfig)
+ decoder_cfg: DecoderConfig = field(default_factory=TransformerDecoderConfig)
+ stop_prediction_head: bool = True
+ multi_stop_prediction_head: bool = False
+ multi_stop_num_tokens: int = 8
+ multi_stop_padding_idx: int | None = None
+ multi_stop_tie_embd: bool = True
+ stop_token_embd: bool = False
+ text_stop_token_embd: bool = False
+ continuous_diffusion: bool = True
+ num_sink_tokens: int = 0
+ disabled_crossatt_head_idx: list[tuple[int, int]] | None = None
+ patchvae_path: str | None = None
+ text_tokenizer_path: str | None = None
+ eos: bool = False
+ bos: bool = False
+
+
+@dataclass
+class PlayHeadConfig:
+ selected_cross_attention_heads: list[tuple[int, int]]
+ dim: int = 256
+ num_layers: int = 6
+ num_frame_lag: int = 2
+ num_sink_tokens: int = 4
+ cycle_len: int = 8
+ logits_head: bool = True
+ target_lag: int = 0
+ avg_pool_stride: int = 3
+ circular_head: bool = False
+
+
+@dataclass
+class QueryVCConfig:
+ dim: int = 512
+ semantic_dim: int = 512
+ num_layers: int = 12
+ lag: int = 4
+ audio_embed_size: int = 8
+ diffusion_head_num_layers: int = 3
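
A small sketch of composing these dataclass configs, e.g. swapping the default transformer decoder for the gated-linear-attention one (illustrative only, not part of the patch; assumes the `tts` package is importable):

```python
from tts.model.config import SimpleGLADecoderConfig, TTSConfig

cfg = TTSConfig(
    dim=512,
    audio_input_type="continuous",
    decoder_cfg=SimpleGLADecoderConfig(num_layers=12, crossatt_layer_idx=[5]),
)
print(cfg.decoder_cfg.name)  # simple_gla
```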
diff --git a/tts/model/prediction_head.py b/tts/model/prediction_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f90a59fb4db633659210f8d34483f35410ac6d9d
--- /dev/null
+++ b/tts/model/prediction_head.py
@@ -0,0 +1,283 @@
+import os
+import sys
+from contextlib import contextmanager
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from torchdyn.core import NeuralODE
+
+from tts.layers.ffn import (AdaLNFinalLayer, AdaLNMLP,
+ GaussianFourierTimeEmbedding)
+
+
+@contextmanager
+def suppress_stdout():
+ original_stdout = sys.stdout
+ try:
+ sys.stdout = open(os.devnull, "w")
+ yield
+ finally:
+ sys.stdout.close()
+ sys.stdout = original_stdout
+
+
+def sample_from_logits(
+ logits: torch.Tensor,
+ temperature: float = 1.0,
+ top_k: int = 0,
+ top_p: float = 0.0,
+) -> torch.Tensor:
+ B, N, C = logits.shape
+ logits = logits / temperature
+
+ # Apply top-k
+ if top_k > 0:
+ top_k = min(top_k, C)
+ topk_values, _ = torch.topk(logits, top_k, dim=-1)
+ kth_value = topk_values[..., -1, None]
+ logits = torch.where(
+ logits < kth_value, torch.full_like(logits, float("-inf")), logits
+ )
+
+ # Apply top-p (nucleus) sampling
+ if top_p > 0.0 and top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ probs = F.softmax(sorted_logits, dim=-1)
+ cumulative_probs = torch.cumsum(probs, dim=-1)
+
+ # Create mask for tokens to remove
+ cutoff_mask = cumulative_probs > top_p
+ cutoff_mask[..., 0] = 0 # Always keep at least one token
+ sorted_logits[cutoff_mask] = float("-inf")
+
+ # Map back to original logits shape
+ logits = torch.full_like(logits, float("-inf")).scatter(
+ -1, sorted_indices, sorted_logits
+ )
+
+ # Convert logits to probabilities
+ probs = F.softmax(logits, dim=-1)
+
+ # Sample
+ samples = torch.multinomial(probs.view(-1, C), num_samples=1).view(B, N)
+ return samples
+
+
+class LogitsHead(nn.Module):
+ def __init__(self, hidden_dim: int, vocab_size: int):
+ super().__init__()
+ self.logits_proj = nn.Linear(hidden_dim, vocab_size)
+
+ def forward(self, pre_logits):
+ return self.logits_proj(pre_logits)
+
+ def compute_loss(
+ self,
+ pre_logits: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ):
+ logits = self(pre_logits)
+ if mask is not None:
+ flat_logits = logits[mask]
+ flat_target = target[mask]
+ else:
+ flat_logits = rearrange(logits, "b n l -> (b n) l")
+ flat_target = rearrange(target, "b n -> (b n)")
+
+ loss = nn.functional.cross_entropy(
+ flat_logits,
+ flat_target,
+ )
+
+ return {"cross_entropy": loss}
+
+ def predict(self, x: torch.Tensor, *args, **kwargs):
+ return sample_from_logits(self(x), *args, **kwargs)
+
+
+class ContinuousHead(nn.Module):
+ def __init__(self, hidden_dim: int, feature_dim: int):
+ super().__init__()
+ self.continuous_head = nn.Linear(hidden_dim, feature_dim)
+
+ def forward(self, x: torch.Tensor):
+ return self.continuous_head(x)
+
+ def predict(self, x: torch.Tensor):
+ return self(x)
+
+ def compute_loss(
+ self, pre_logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+ ):
+ if mask is not None:
+ pre_logits = pre_logits[mask]
+ target = target[mask]
+ return {"mse": nn.functional.mse_loss(self(pre_logits), target)}
+
+
+class VelocityHead(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ feature_dim: int,
+ num_layers: int,
+ cond_dim: int | None = None,
+ ):
+ super().__init__()
+ cond_dim = cond_dim if cond_dim is not None else hidden_dim
+ self.feature_embed = nn.Linear(feature_dim, hidden_dim)
+ self.cond_embed = nn.Linear(cond_dim, hidden_dim)
+ self.time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2)
+ self.adaln_mlp = nn.ModuleList(
+ [AdaLNMLP(hidden_dim) for _ in range(num_layers)]
+ )
+ self.adaln_final_layer = AdaLNFinalLayer(hidden_dim, feature_dim)
+ self.feature_dim = feature_dim
+
+ def forward(
+ self,
+ cond: torch.Tensor,
+ x: torch.Tensor,
+ t: torch.Tensor | None = None,
+ cond_drop_mask: torch.BoolTensor | None = None,
+ ):
+ cond = self.cond_embed(cond)
+ if cond_drop_mask is not None:
+ cond[cond_drop_mask] = 0.0
+ cond += self.time_embed(t)[:, None]
+ x = self.feature_embed(x)
+
+ for l in self.adaln_mlp:
+ x = l(x, cond)
+ y = self.adaln_final_layer(x, cond)
+
+ return y
+
+ def compute_loss(
+ self,
+ cond: torch.Tensor,
+ x1: torch.Tensor,
+ mask: torch.Tensor | None,
+ sigma: float = 1e-5,
+ t: torch.Tensor | None = None,
+ x0: torch.Tensor | None = None,
+ cfg_drop_rate: float = 0.1,
+ ):
+ """
+ CFM Loss
+ """
+ if t is None:
+ t = torch.rand(cond.shape[0], device=cond.device)
+
+ if x0 is None:
+ x0 = torch.randn_like(x1, device=x1.device)
+
+ flow_target = x1 - (1 - sigma) * x0
+ alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
+ xt = alpha * x0 + t.view(-1, 1, 1) * x1
+
+ if self.training and cfg_drop_rate > 0.0:
+            cond_drop_mask = torch.rand(cond.shape[:2], device=cond.device) < cfg_drop_rate
+ else:
+ cond_drop_mask = None
+
+ flow_pred = self(cond, xt, t, cond_drop_mask=cond_drop_mask)
+
+ if mask is not None:
+ flow_pred = flow_pred[mask]
+ flow_target = flow_target[mask]
+
+ loss = nn.functional.mse_loss(flow_pred, flow_target)
+
+ return {"diffusion": loss}
+
+ def predict(
+ self,
+ pre_prediction: torch.Tensor,
+ pre_prediction_ref: torch.Tensor | None = None,
+ solver: str = "euler",
+ sensitivity: str = "adjoint",
+ num_steps: int = 10,
+ cfg: float = 1.0,
+ cfg_ref: float = 1.5,
+ temperature: float = 1.0,
+ **kwargs,
+ ):
+ if cfg == 1.0:
+
+ if pre_prediction_ref is None:
+ def solver_fn(t, Xt, *args, **kwargs):
+ return self(pre_prediction, Xt, t.unsqueeze(0))
+ else:
+ raise NotImplementedError
+ else:
+ if pre_prediction_ref is None:
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ cond_uncond = torch.cat(
+ (pre_prediction, torch.zeros_like(pre_prediction)),
+ dim=0,
+ )
+ cond_uncond = self(cond_uncond, Xt.repeat(2, 1, 1), t.unsqueeze(0))
+ cond, uncond = cond_uncond.chunk(2, dim=0)
+ cond_uncond_cfg = uncond + cfg * (cond - uncond)
+
+ return cond_uncond_cfg
+ else:
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ cond_uncond_ref = torch.cat((pre_prediction, pre_prediction_ref, torch.zeros_like(pre_prediction)))
+ #cond_uncond_ref, = torch.cat(
+ # (pre_prediction, torch.zeros_like(pre_prediction), pre_prediction_ref),
+ # dim=0,
+ #)
+ cond_uncond = self(cond_uncond_ref, Xt.repeat(3, 1, 1), t.unsqueeze(0))
+ cond, ref, uncond = cond_uncond.chunk(3, dim=0)
+ #cond_uncond_cfg = uncond + cfg * (cond - uncond)
+ #cond_uncond_cfg_ref_cfg = ref + cfg_ref * (cond_uncond_cfg - ref)
+ cond_uncond_cfg = ref + cfg_ref * (cond - ref)
+ cond_uncond_cfg_ref_cfg = uncond + cfg * (cond_uncond_cfg - uncond)
+
+ return cond_uncond_cfg_ref_cfg
+
+
+ # get rid of torchdyn warning
+ with suppress_stdout():
+ node_ = NeuralODE(solver_fn, solver=solver, sensitivity=sensitivity)
+ t_span = torch.linspace(0, 1, num_steps + 1, device=pre_prediction.device)
+ traj = node_.trajectory(
+ torch.randn(pre_prediction.shape[0], 1, self.feature_dim, device=pre_prediction.device)
+ * temperature,
+ t_span=t_span,
+ )
+ prediction = traj[-1]
+ return prediction
+
+
+class StopPredictionHead(nn.Module):
+ def __init__(self, dim: int, weight_loss: float = 1.0):
+ super().__init__()
+ self.proj = nn.Linear(dim, 1)
+ self.weight_loss = weight_loss
+
+ def forward(self, pre_prediction: torch.Tensor):
+ return torch.sigmoid(self.proj(pre_prediction))
+
+ def predict(self, pre_prediction: torch.Tensor):
+ return torch.sigmoid(self.proj(pre_prediction))
+
+ def compute_loss(
+ self,
+ pre_prediction: torch.Tensor,
+ target: torch.Tensor,
+ ):
+ logits = self.proj(pre_prediction)
+ bce = nn.functional.binary_cross_entropy_with_logits(
+ logits.squeeze(-1),
+ target.to(logits.dtype),
+ weight=torch.ones(logits.shape[0], device=logits.device) * self.weight_loss,
+ )
+ return {"stop_bce": bce}
diff --git a/tts/model/registry.py b/tts/model/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a00c29717c2a7888a38cef182b1240068810762
--- /dev/null
+++ b/tts/model/registry.py
@@ -0,0 +1,18 @@
+ENCODER_REGISTRY = {}
+DECODER_REGISTRY = {}
+
+
+def register_encoder(name):
+ def wrapper(cls):
+ ENCODER_REGISTRY[name] = cls
+ return cls
+
+ return wrapper
+
+
+def register_decoder(name):
+ def wrapper(cls):
+ DECODER_REGISTRY[name] = cls
+ return cls
+
+ return wrapper
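
A minimal sketch of how the registry decorators are meant to be used, so that config names (e.g. "simple_gla", "sa_transformer") resolve to classes at build time (illustrative only, not part of the patch; assumes the `tts` package is importable):

```python
from tts.model.registry import DECODER_REGISTRY, register_decoder

@register_decoder("toy")
class ToyDecoder:
    def __init__(self, cfg):
        self.cfg = cfg

decoder_cls = DECODER_REGISTRY["toy"]
print(decoder_cls is ToyDecoder)  # True
```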
diff --git a/tts/model/shortconv.py b/tts/model/shortconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..4677c1bb58934a2f27aaccb01b55a1a645a09998
--- /dev/null
+++ b/tts/model/shortconv.py
@@ -0,0 +1,39 @@
+import torch
+from fla.modules import ShortConvolution
+from torch import nn
+
+from tts.layers.ffn import SwiGLU
+from tts.model.cache_utils import FLACache
+
+
+class ShortConvBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int,
+ ffn_expansion_factor: int,
+ layer_idx: int,
+ use_fast_conv1d: bool = True,
+ ):
+ super().__init__()
+ self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d)
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.layer_idx = layer_idx
+
+ def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
+ last_state = None
+ if cache is not None and len(cache) > self.layer_idx:
+ last_state = cache[self.layer_idx]["short_conv_state"]
+ x, short_conv_state = self.tmix(
+ self.norm1(x), cache=last_state, output_final_state=cache is not None
+ )
+ if cache is not None:
+ cache.update(
+ short_conv_state=short_conv_state,
+ layer_idx=self.layer_idx,
+ offset=x.shape[1],
+ )
+ x = self.cmix(self.norm2(x))
+ return x
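
`ShortConvBlock` wraps `fla`'s `ShortConvolution` for token mixing; the primitive it relies on is essentially a causal depthwise convolution over the sequence axis. A standalone sketch of that operation in plain PyTorch, with kernel_size=4 as above (illustrative only, not part of the patch):

```python
import torch
import torch.nn.functional as F

B, T, D, K = 2, 10, 16, 4
x = torch.randn(B, T, D)
weight = torch.randn(D, 1, K)  # one length-K kernel per channel

x_c = F.pad(x.transpose(1, 2), (K - 1, 0))  # left-pad so no future tokens leak in
y = F.conv1d(x_c, weight, groups=D).transpose(1, 2)
print(y.shape)  # torch.Size([2, 10, 16])
```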
diff --git a/tts/model/simple_gla.py b/tts/model/simple_gla.py
new file mode 100644
index 0000000000000000000000000000000000000000..9425b7dd12ec50385ea99f086250b749f746a57b
--- /dev/null
+++ b/tts/model/simple_gla.py
@@ -0,0 +1,291 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from fla.layers.simple_gla import SimpleGatedLinearAttention
+from fla.models.utils import Cache
+from torch import nn
+
+from tts.layers.attention import CrossAttention
+from tts.layers.ffn import SwiGLU
+
+from .cache_utils import FLACache
+from .config import SimpleGLADecoderConfig
+from .registry import register_decoder
+from .shortconv import ShortConvBlock
+
+if "GRAD_CKPT" in os.environ:
+
+ def maybe_grad_ckpt(f):
+ def grad_ckpt_f(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(
+ f, *args, **kwargs, use_reentrant=False
+ )
+
+ return grad_ckpt_f
+else:
+
+ def maybe_grad_ckpt(f):
+ return f
+
+
+class SimpleGLABlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ expand_k: float,
+ expand_v: float,
+ use_short_conv: bool,
+ ffn_expansion_factor: int,
+ ):
+ super().__init__()
+ self.tmix = SimpleGatedLinearAttention(
+ hidden_size=dim,
+ num_heads=num_heads,
+ layer_idx=layer_idx,
+ )
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ text_freqs: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ ):
+ x = (
+ self.tmix(
+ self.norm1(x),
+ past_key_values=cache,
+ use_cache=cache is not None,
+ )[0]
+ + x
+ )
+ x = self.cmix(self.norm2(x)) + x
+ return x
+
+
+class DecoderBlockWithOptionalCrossAttention(nn.Module):
+ def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
+ super().__init__()
+
+ self.decoder_block = decoder_block
+ self.crossatt = crossatt
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoder_output: torch.Tensor | None = None,
+ freqs: torch.Tensor | None = None,
+ text_freqs: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ selfatt_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
+ ) -> torch.Tensor:
+ x = self.decoder_block(
+ x,
+ freqs=freqs,
+ cache=cache,
+ )
+        if isinstance(crossatt_mask, list):
+ crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]
+ if self.crossatt is not None:
+ x = x + self.crossatt(
+ x,
+ k=encoder_output,
+ text_freqs=text_freqs,
+ mask=crossatt_mask,
+ cache=cache,
+ )
+
+ return x
+
+
+@register_decoder("simple_gla")
+class SimpleGLADecoder(nn.Module):
+ config = SimpleGLADecoderConfig
+
+ def __init__(self, cfg: SimpleGLADecoderConfig):
+ super().__init__()
+
+ assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim"
+ assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
+ "at most one specialized cross-attention"
+ )
+
+ self.head_dim = cfg.dim // cfg.num_heads
+ self.num_heads = cfg.num_heads
+
+ def simple_gla_block(i):
+ conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
+ if i in conv_layers:
+ return ShortConvBlock(
+ dim=cfg.dim,
+ kernel_size=4,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ layer_idx=i,
+ use_fast_conv1d=True,
+ )
+
+ else:
+ return SimpleGLABlock(
+ dim=cfg.dim,
+ num_heads=cfg.num_heads,
+ layer_idx=i,
+ expand_k=cfg.expand_k,
+ expand_v=cfg.expand_v,
+ use_short_conv=cfg.use_short_conv,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ )
+
+ def crossatt_block(i):
+ if i in cfg.crossatt_layer_idx:
+ return CrossAttention(
+ dim=cfg.dim,
+ num_heads=cfg.crossatt_num_heads,
+ dropout=cfg.crossatt_dropout,
+ layer_idx=i,
+ )
+ else:
+ return None
+
+ self.decoder_layers = nn.ModuleList(
+ [
+ DecoderBlockWithOptionalCrossAttention(
+ simple_gla_block(i),
+ crossatt_block(i),
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
+ text_ids: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ x = decoder_input
+ text_freqs = None
+
+ for layer in self.decoder_layers:
+ x = maybe_grad_ckpt(layer)(
+ x,
+ encoder_output,
+ text_freqs=text_freqs,
+ cache=cache,
+ crossatt_mask=crossatt_mask,
+ )
+ return x
+
+ def init_cache(self, max_seq_len, device):
+ return FLACache(num_states=len(self.decoder_layers) + 1)
+
+ def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
+ return tuple(
+ nn.Parameter(
+ torch.randn(
+ batch_size,
+ self.num_heads,
+ self.head_dim,
+ self.head_dim,
+ device=device,
+ )
+ * scale
+ )
+ for _ in range(len(self.decoder_layers))
+ )
+
+    def init_initial_state_lora(
+        self, lora: int = 1, batch_size: int = 1, scale: float = 1e-2, device: str = "cpu"
+    ):
+ return tuple(
+ (
+ nn.Parameter(
+ torch.randn(
+ batch_size,
+ self.num_heads,
+ self.head_dim,
+ lora,
+ device=device,
+ )
+ * scale
+ ),
+ nn.Parameter(
+ torch.randn(
+ batch_size,
+ self.num_heads,
+ lora,
+ self.head_dim,
+ device=device,
+ )
+ * scale
+ )
+ )
+ for _ in range(len(self.decoder_layers))
+ )
+
+ def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
+ assert self.decoder_layers[layer_idx].crossatt is not None
+ x = audio_inputs
+ for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
+ x = layer(x, None)
+ return self.decoder_layers[layer_idx].crossatt._query(x)
+
+ def forward_first_n_layers(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ n_first_layers: int,
+ crossatt_mask: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ x = decoder_input
+        if getattr(self, "text_freqs_embd", None) is not None:
+ text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
+ text_freqs = self.text_freqs_embd(text_freqs)
+ else:
+ text_freqs = None
+
+ for layer in self.decoder_layers[:n_first_layers]:
+ x = maybe_grad_ckpt(layer)(
+ x,
+ encoder_output,
+ text_freqs=text_freqs,
+ cache=cache,
+ crossatt_mask=crossatt_mask,
+ )
+ return x
+
+ def prefill(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ crossatt_mask: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ return self(encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)
+
+ def decode_one(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: Cache,
+ text_freqs: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ ):
+ x = decoder_input
+ for layer in self.decoder_layers:
+ x = layer(
+ x,
+ encoder_output,
+ text_freqs=text_freqs,
+ cache=cache,
+ crossatt_mask=crossatt_mask,
+ )
+ return x
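
Both decoder files (this one and `transformer.py` below) gate activation checkpointing on the `GRAD_CKPT` environment variable via `maybe_grad_ckpt`. A standalone sketch of the same toggle applied to a plain layer (illustrative only, not part of the patch):

```python
import os

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def maybe_grad_ckpt(f):
    # Checkpoint only when GRAD_CKPT is set; otherwise call the layer directly.
    if "GRAD_CKPT" not in os.environ:
        return f

    def grad_ckpt_f(*args, **kwargs):
        return checkpoint(f, *args, **kwargs, use_reentrant=False)

    return grad_ckpt_f


layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
y = maybe_grad_ckpt(layer)(x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```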
diff --git a/tts/model/transformer.py b/tts/model/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ec1bc0e603a4845c89e2225a68ef264ce9b204
--- /dev/null
+++ b/tts/model/transformer.py
@@ -0,0 +1,282 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from transformers.cache_utils import DynamicCache
+
+from tts.layers.attention import (CrossAttention, SelfAttention,
+ precompute_freqs_cis)
+from tts.layers.conv import ConvNeXtBlock
+from tts.layers.ffn import SwiGLU
+from tts.model.cache_utils import FLACache, TransformerDecoderCache
+from tts.model.config import (ConvFormerEncoderConfig,
+ TransformerDecoderConfig,
+ TransformerEncoderConfig)
+from tts.model.registry import register_decoder, register_encoder
+from tts.model.shortconv import ShortConvBlock
+
+if "GRAD_CKPT" in os.environ:
+
+ def maybe_grad_ckpt(f):
+ def grad_ckpt_f(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(
+ f, *args, **kwargs, use_reentrant=False
+ )
+
+ return grad_ckpt_f
+else:
+
+ def maybe_grad_ckpt(f):
+ return f
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ ffn_expansion_factor: int,
+ is_causal: bool,
+ ):
+ super().__init__()
+ self.tmix = SelfAttention(dim, num_heads, layer_idx, is_causal=is_causal)
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ cache: DynamicCache | None = None,
+ mask: torch.Tensor | None = None,
+ ):
+ x = self.tmix(self.norm1(x), freqs=freqs, cache=cache, mask=mask) + x
+ x = self.cmix(self.norm2(x)) + x
+ return x
+
+
+class DecoderBlockWithOptionalCrossAttention(nn.Module):
+ def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
+ super().__init__()
+
+ self.decoder_block = decoder_block
+ self.crossatt = crossatt
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoder_output: torch.Tensor | None = None,
+ freqs: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ selfatt_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = self.decoder_block(
+ x,
+ freqs=freqs,
+ cache=cache,
+ )
+
+ if self.crossatt is not None:
+ x = x + self.crossatt(
+ x,
+ k=encoder_output,
+ mask=crossatt_mask,
+ cache=cache,
+ )
+
+ return x
+
+
+@register_decoder("sa_transformer")
+class TransformerDecoder(nn.Module):
+ config = TransformerDecoderConfig
+
+    def __init__(self, cfg: TransformerDecoderConfig):
+ super().__init__()
+
+ assert cfg.dim % cfg.num_heads == 0
+ self.head_dim = cfg.dim // cfg.num_heads
+
+ def transformer_block(i):
+ conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
+ if i in conv_layers:
+ return ShortConvBlock(
+ dim=cfg.dim,
+ kernel_size=4,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ layer_idx=i,
+ use_fast_conv1d=True,
+ )
+ else:
+ return TransformerBlock(
+ dim=cfg.dim,
+ num_heads=cfg.num_heads,
+ layer_idx=i,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ is_causal=True,
+ )
+
+ def crossatt_block(i):
+ return CrossAttention(
+ dim=cfg.dim,
+ num_heads=cfg.crossatt_num_heads,
+ dropout=cfg.crossatt_dropout,
+ layer_idx=i,
+ )
+
+ self.decoder_layers = nn.ModuleList(
+ [
+ DecoderBlockWithOptionalCrossAttention(
+ transformer_block(i),
+ crossatt_block(i) if i in cfg.crossatt_layer_idx else None,
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ crossatt_mask: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ x = decoder_input
+ positions = torch.arange(x.shape[1], device=x.device)
+ freqs = precompute_freqs_cis(positions, self.head_dim)
+ for layer in self.decoder_layers:
+ x = maybe_grad_ckpt(layer)(
+ x,
+ encoder_output,
+ freqs=freqs,
+ crossatt_mask=crossatt_mask,
+ cache=cache,
+ )
+ return x
+
+ def init_cache(self, max_seq_len, device):
+ return FLACache(head_dim=self.head_dim, max_seq_len=max_seq_len, device=device)
+
+ def prefill(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: FLACache | None = None,
+ ):
+ return self(encoder_output, decoder_input, cache=cache)
+
+ def decode_one(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: FLACache,
+ crossatt_mask: torch.Tensor | None = None,
+ ):
+ x = decoder_input
+ pos = cache._seen_tokens
+ freq = cache.freqs[:, [pos]]
+ for layer in self.decoder_layers:
+ x = layer(
+ x,
+ encoder_output,
+ freqs=freq,
+ cache=cache,
+ crossatt_mask=crossatt_mask,
+ )
+ return x
+
+
+@register_encoder("sa_transformer")
+class TransformerEncoder(nn.Module):
+ config = TransformerEncoderConfig
+
+ def __init__(self, cfg: TransformerEncoderConfig):
+ super().__init__()
+ assert cfg.dim % cfg.num_heads == 0
+ self.head_dim = cfg.dim // cfg.num_heads
+ self.encoder_layers = nn.ModuleList(
+ [
+ TransformerBlock(
+ cfg.dim,
+ cfg.num_heads,
+ i,
+ cfg.ffn_expansion_factor,
+ is_causal=False,
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ):
+ positions = torch.arange(x.shape[1], device=x.device)
+ freqs = precompute_freqs_cis(positions, self.head_dim)
+ if mask is not None:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ mask = torch.logical_or(
+ mask,
+ rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"),
+ )
+
+ for layer in self.encoder_layers:
+ x = layer(x, freqs=freqs, mask=mask)
+ return x
+
+
+@register_encoder("convformer_encoder")
+class ConvFormerEncoder(nn.Module):
+ config = ConvFormerEncoderConfig
+
+ def __init__(self, cfg: ConvFormerEncoderConfig):
+ super().__init__()
+ assert cfg.dim % cfg.num_heads == 0
+ self.head_dim = cfg.dim // cfg.num_heads
+ self.conv_layers = nn.ModuleList(
+ [ConvNeXtBlock(cfg.dim) for _ in range(cfg.num_conv_layers)]
+ )
+ self.encoder_layers = nn.ModuleList(
+ [
+ TransformerBlock(
+ cfg.dim,
+ cfg.num_heads,
+ i,
+ cfg.ffn_expansion_factor,
+ is_causal=False,
+ )
+ for i in range(cfg.num_transformer_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ text_rel_pos: torch.Tensor | None = None,
+ ):
+ if text_rel_pos is None:
+ text_rel_pos = torch.arange(x.shape[1], device=x.device)
+ freqs = precompute_freqs_cis(text_rel_pos, self.head_dim)
+ else:
+ freqs = precompute_freqs_cis(text_rel_pos, self.head_dim).unsqueeze(1)
+
+ x = x.transpose(1, 2)
+ for layer in self.conv_layers:
+ x = layer(x)
+ x = x.transpose(1, 2)
+ if mask is not None:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ mask = torch.logical_or(
+ mask,
+ rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"),
+ )
+
+ for layer in self.encoder_layers:
+ x = layer(x, freqs=freqs, mask=mask)
+ return x
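
The encoders above OR the pairwise padding mask with the identity before attention. A standalone illustration of what that buys (illustrative only, not part of the patch): even a fully padded row keeps its diagonal entry, so no query is left with zero attendable keys.

```python
import torch
from einops import rearrange

mask = torch.zeros(1, 4, 4, dtype=torch.bool)  # an all-False (fully padded) block
mask = rearrange(mask, "b n m -> b 1 n m")
mask = torch.logical_or(
    mask,
    rearrange(torch.eye(mask.shape[-1], dtype=torch.bool), "n m -> 1 1 n m"),
)
print(mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [0, 1, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 0, 1]])
```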
diff --git a/tts/playhead.py b/tts/playhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..f716be24f61cd7215685c3fbc81a4b92e859723c
--- /dev/null
+++ b/tts/playhead.py
@@ -0,0 +1,103 @@
+import torch
+from torch import nn
+
+from model.cache_utils import FLACache
+from model.config import PlayHeadConfig
+from model.prediction_head import CircularHead, LogitsHead
+from model.simple_gla import SimpleGLABlock
+
+
+class PlayHead(nn.Module):
+ def __init__(self, cfg: PlayHeadConfig):
+ super().__init__()
+ self.cycle_len = cfg.cycle_len
+ self.num_sink_tokens = cfg.num_sink_tokens
+ self.pos_embedding = nn.Embedding(cfg.num_sink_tokens + cfg.cycle_len, cfg.dim)
+ self.avg_pool_stride = cfg.avg_pool_stride
+ self.net = nn.ModuleList(
+ [
+ SimpleGLABlock(
+ dim=cfg.dim,
+ num_heads=cfg.dim // 128,
+ layer_idx=i,
+ expand_k=0.5,
+ expand_v=1.0,
+ use_short_conv=True,
+ ffn_expansion_factor=4,
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ self.logits_head = (
+ LogitsHead(cfg.dim, cfg.cycle_len) if cfg.logits_head else None
+ )
+ self.circular_head = CircularHead(cfg.dim) if cfg.circular_head else None
+
+ def forward(
+ self,
+ cross_attention_weights: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ):
+ B, T, A = cross_attention_weights.shape
+ # if self.cross_attention_reduction == "sum":
+ # cross_attention_weights = cross_attention_weights.sum(1)
+
+ device = cross_attention_weights.device
+ pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len
+ sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len
+ sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
+ x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
+ for block in self.net:
+ x = block(x)
+
+ losses = dict()
+ if self.logits_head is not None:
+ losses |= self.logits_head.compute_loss(x, target.long(), mask=mask)
+ if self.circular_head is not None:
+ losses |= self.circular_head.compute_loss(x, target, mask=mask)
+
+ return losses
+
+ def init_cache(self):
+ return FLACache(num_states=len(self.net))
+
+ def predict(
+ self,
+ cross_attention_weights: torch.Tensor,
+ previous_position: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ avg_pool_ca = torch.nn.functional.avg_pool1d(
+ cross_attention_weights[:, self.num_sink_tokens :].transpose(-1, -2),
+ self.avg_pool_stride,
+ stride=self.avg_pool_stride,
+ ceil_mode=True,
+ ).transpose(-1, -2)
+
+ sink_ca = cross_attention_weights[:, : self.num_sink_tokens]
+ cross_attention_weights = torch.cat((sink_ca, avg_pool_ca), dim=1)
+
+ B, T, A = cross_attention_weights.shape
+ device = cross_attention_weights.device
+ pos = torch.arange(T - self.num_sink_tokens).to(device) % self.cycle_len
+ sink = torch.arange(self.num_sink_tokens).to(device) + self.cycle_len
+ sink_and_pos_embd = self.pos_embedding(torch.cat((sink, pos))[None])
+ x = cross_attention_weights.transpose(-1, -2) @ sink_and_pos_embd
+ for block in self.net:
+ x = block(x, cache=cache)
+ if self.logits_head is not None:
+ logits = self.logits_head(x)
+ pred_position = torch.argmax(logits, -1)
+
+ if previous_position is not None:
+ current_angle, previous_angle = map(
+ lambda x: torch.exp(1j * 2 * torch.pi * x / self.cycle_len),
+ (pred_position, previous_position),
+ )
+ diff = current_angle / previous_angle
+ step = (diff.angle() / (2 * torch.pi / self.cycle_len)).round().long()
+ return pred_position, step
+
+ return pred_position
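
A standalone check of the wrap-around step computation in `PlayHead.predict` (illustrative only, not part of the patch): positions live on a cycle, so the advance is read off the phase of a ratio of unit complex numbers rather than a raw subtraction; moving from 7 to 1 on a cycle of 8 is two steps forward.

```python
import torch

cycle_len = 8
pred_position = torch.tensor([1.0])
previous_position = torch.tensor([7.0])

current_angle, previous_angle = (
    torch.exp(1j * 2 * torch.pi * p / cycle_len)
    for p in (pred_position, previous_position)
)
diff = current_angle / previous_angle
step = (diff.angle() / (2 * torch.pi / cycle_len)).round().long()
print(step)  # tensor([2])
```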
diff --git a/tts/text_processor.py b/tts/text_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..322852d11ec901a84e659b8e46c3bed81b4bd9cb
--- /dev/null
+++ b/tts/text_processor.py
@@ -0,0 +1,79 @@
+from transformers import PreTrainedTokenizerFast
+import re
+import unicodedata
+
+REPLACEMENTS = {
+ # whitespace
+ "\n": " ",
+ "\r": " ",
+ "\t": " ",
+ "\xa0": " ",
+ "\u2009": " ",
+ "\u202f": " ",
+ "\u200b": "",
+ # quotes
+ "‘": "'",
+ "’": "'",
+ "‚": "'",
+ "‛": "'",
+ "“": '"',
+ "”": '"',
+ "„": '"',
+ "«": '"',
+ "»": '"',
+ # dashes
+ "–": "-",
+ "—": "-",
+ "−": "-",
+ "-": "-",
+ # ellipsis
+ "…": "...",
+ # bullets & symbols
+ "•": ".",
+ "∙": ".",
+ "·": ".",
+ # currencies
+ "€": " euros",
+ "$": " dollars",
+ "£": " pounds",
+ "¥": " yen",
+ # misc
+ "°": " degrees",
+ "©": "",
+ "®": "",
+ "™": "",
+}
+
+
+def clean_text(text: str) -> str:
+ text = unicodedata.normalize("NFKC", text)
+ for src, tgt in REPLACEMENTS.items():
+ text = text.replace(src, tgt)
+ text = re.sub(r"\s+", " ", text.strip()) # collapse spaces
+ return text
+
+
+class BasicTextProcessor:
+ """
+ Basic text processor on top of a character level BPE model.
+ """
+
+ def __init__(self, tokenizer_file: str):
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+
+ def normalize(self, text: str) -> str:
+ """Basic pre-normalization: whitespace cleanup, punctuation spacing, etc."""
+ text = clean_text(text)
+ text = text.strip()
+ return text
+
+ def __call__(self, text: str, **tokenizer_kwargs):
+ """Normalize then tokenize."""
+ text = self.normalize(text)
+ return self.tokenizer.encode(text, **tokenizer_kwargs)
+
+ def detokenize(self, token_ids):
+ """Optional: convert back to string."""
+ out = self.tokenizer.decode(token_ids, skip_special_tokens=False)
+ whitespace = "##[WHITESPACE]"
+ return out.replace(" ", whitespace).replace(" ", "").replace(whitespace, " ")
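
A quick sketch of what `clean_text` normalizes: NFKC normalization followed by the replacement table and whitespace collapsing (illustrative only, not part of the patch; assumes the `tts` package is importable):

```python
from tts.text_processor import clean_text

print(clean_text("“Hello”…\u00a0— world"))
# "Hello"... - world
```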
diff --git a/tts/tools.py b/tts/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0e9862e5456386902f0a64abe191483a282f27e
--- /dev/null
+++ b/tts/tools.py
@@ -0,0 +1,250 @@
+from itertools import accumulate
+from typing import Callable, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def widen_alignment(
+ alignment: torch.Tensor, width: int | tuple[int, int], axis: str = "S"
+) -> torch.Tensor:
+ """
+ Widen 1-bands along one axis of an alignment matrix.
+
+ Args:
+ alignment: (B, T, S) binary/bool/int tensor
+ width: int or (left, right) expansion
+ e.g. 2 -> expand ±2
+ (1,3) -> expand -1 on the left, +3 on the right
+ axis: "S" to widen horizontally (across S),
+ "T" to widen vertically (across T)
+
+ Returns:
+ (B, T, S) tensor with widened 1-bands along the chosen axis
+ """
+ assert axis in ("S", "T")
+ orig_dtype = alignment.dtype
+ dev = alignment.device
+
+ # normalize widths
+ if isinstance(width, int):
+ left, right = width, width
+ else:
+ left, right = width
+ ksize = left + right + 1
+ kernel = torch.ones(1, 1, ksize, device=dev)
+
+ if axis == "S":
+ # (B*T, 1, S)
+ x = alignment.view(-1, 1, alignment.size(-1)).float()
+ x = F.pad(x, (left, right)) # explicit asymmetric padding
+ y = F.conv1d(x, kernel)
+ y = (y > 0).view_as(alignment)
+
+ else: # axis == "T"
+ # (B*S, 1, T)
+ x = (
+ alignment.permute(0, 2, 1)
+ .contiguous()
+ .view(-1, 1, alignment.size(1))
+ .float()
+ )
+ x = F.pad(x, (left, right))
+ y = F.conv1d(x, kernel)
+ # Back to (B, T, S)
+ y = (
+ (y > 0)
+ .view(alignment.size(0), alignment.size(2), alignment.size(1))
+ .permute(0, 2, 1)
+ )
+
+    # Cast back to the original dtype (bool masks pass through unchanged)
+    return y if orig_dtype == torch.bool else y.to(orig_dtype)
+
+
+def collect_heads(cache, selected_heads):
+ return torch.stack(
+ [
+ cache[layer]["crossatt_weights"][:, [head], [-1]]
+ for layer, head in selected_heads
+ ],
+ dim=1,
+ )
+
+
+def expand(x, r):
+ b, n, d = x.shape
+ x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
+ return x
+
+
+def path_matrix(positions: torch.Tensor, num_positions: int | None = None) -> torch.Tensor:
+ if num_positions is None:
+ num_positions = positions.max().item() + 1
+ return F.one_hot(positions, num_classes=num_positions).to(torch.int)
+
+
+def pad_2d_sequence(seq, padding_value=0):
+ max_x, max_y = map(max, zip(*map(lambda x: x.shape, seq)))
+ pad = lambda x: torch.nn.functional.pad(
+ x,
+ (0, max_y - x.shape[1], 0, max_x - x.shape[0]),
+ value=padding_value,
+ )
+ return torch.stack([pad(x) for x in seq])
+
+
+def audio_to_text_partial_neighbor_mask(
+ xlen,
+ ylen,
+ *,
+ past_tokens: int = 0,
+ future_tokens: int = 0,
+ device=None,
+ dtype=torch.bool,
+):
+ """
+ Build an (audio_len, text_len) boolean mask where True = allowed to attend.
+ Each audio frame (group g) can attend:
+ - all tokens of text group g (aligned word),
+ - last `past_tokens` tokens of text group g-1 (previous word),
+ - first `future_tokens` tokens of text group g+1 (next word).
+
+ Args:
+ xlen (list[int]): token counts per text word (groups), e.g. [2,1,3]
+ ylen (list[int]): frame counts per audio word (aligned groups), e.g. [4,2,5]
+ past_tokens (int): allow up to this many tokens from end of previous word
+ future_tokens (int): allow up to this many tokens from start of next word
+ device: torch device
+ dtype: output dtype (bool by default)
+
+ Returns:
+ mask: (A, T) boolean tensor (A = sum(ylen), T = sum(xlen))
+ """
+ if len(xlen) != len(ylen):
+ raise ValueError(f"len(xlen)={len(xlen)} must equal len(ylen)={len(ylen)}")
+ if any(l <= 0 for l in xlen) or any(l <= 0 for l in ylen):
+ raise ValueError("All lengths must be positive.")
+ if past_tokens < 0 or future_tokens < 0:
+ raise ValueError("past_tokens and future_tokens must be >= 0.")
+
+ n = len(xlen)
+
+ # Text-side: group id per token and position within its group
+ x_groups = torch.arange(n, device=device).repeat_interleave(
+ torch.tensor(xlen, device=device)
+ ) # (T,)
+ pos_in_group = torch.cat([torch.arange(L, device=device) for L in xlen]) # (T,)
+ # tokens from the end (0 for last token, 1 for second-to-last, ...)
+ pos_from_end = torch.cat(
+ [torch.arange(L - 1, -1, -1, device=device) for L in xlen]
+ ) # (T,)
+
+ T = x_groups.numel()
+
+ # Audio-side: group id per frame
+ y_groups = torch.arange(n, device=device).repeat_interleave(
+ torch.tensor(ylen, device=device)
+ ) # (A,)
+ A = y_groups.numel()
+
+ # Broadcast to (A, T)
+ G_audio = y_groups[:, None] # (A, 1)
+ G_text = x_groups[None, :] # (1, T)
+
+ # Conditions:
+ # 1) aligned word: all tokens
+ aligned = G_text == G_audio
+
+ # 2) previous word: last `past_tokens` tokens only
+ if past_tokens > 0:
+ prev_group = G_text == (G_audio - 1)
+ prev_tail = pos_from_end[None, :] < past_tokens
+ prev_ok = prev_group & prev_tail
+ else:
+ prev_ok = torch.zeros((A, T), dtype=torch.bool, device=device)
+
+ # 3) next word: first `future_tokens` tokens only
+ if future_tokens > 0:
+ next_group = G_text == (G_audio + 1)
+ next_head = pos_in_group[None, :] < future_tokens
+ next_ok = next_group & next_head
+ else:
+ next_ok = torch.zeros((A, T), dtype=torch.bool, device=device)
+
+ mask = (aligned | prev_ok | next_ok).to(dtype=dtype)
+ return mask
+
+
+def packmask_2d(xlen: list[int], ylen: list[int], offset: int = 0) -> torch.Tensor:
+ _, ybound = map(lambda x: [0] + list(accumulate(x, int.__add__)), (xlen, ylen))
+ lb, hb = [], []
+
+ for n, l, h in zip(xlen, ybound[:-1], ybound[1:]):
+ lb += [l] * n
+ hb += [h] * n
+
+ lb, hb = map(torch.tensor, (lb, hb))
+ if offset:
+ lb -= offset
+ hb += offset
+
+ rge = torch.arange(ybound[-1])
+
+ lm = rge.unsqueeze(0) >= lb.unsqueeze(1)
+ hm = rge.unsqueeze(0) < hb.unsqueeze(1)
+
+ return lm * hm
+
+
+def topk_sampling(seq, k=1, temp=1.0):
+ topk = torch.topk(seq, k, dim=-1)
+ logits = seq / temp
+ mask = logits < topk.values[:, [-1]]
+ logits[mask] = -float("Inf")
+ probs = torch.softmax(logits, dim=-1)
+ return torch.multinomial(probs, num_samples=1)
+
+
+def delay_rvq(
+ code,
+ head_token: int = -2,
+ tail_token: int = -3,
+):
+ q, _ = code.shape
+ extension = torch.ones((q, q + 1)).tril() * head_token
+ extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token
+ extension = torch.flip(extension, (1,))
+ extended_code = torch.cat((code, extension), axis=1)
+ for i in range(q):
+ extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)
+
+ return extended_code.long()
+
+
+def undelay_rvq(extended_code):
+ q, _, n = extended_code.shape
+ out = []
+ for i in range(q):
+ out.append(torch.roll(extended_code[i], -(i + 1), dims=1))
+ out = torch.stack(out, dim=0)
+ return out[:, :, : -(q + 1)]
+
+
+def sequence_mask(lengths, max_len=None, **kwargs):
+ batch_size = lengths.shape[0]
+ device = lengths.device
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+ mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
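
A minimal sketch of two of the masking helpers above (illustrative only, not part of the patch; assumes the `tts` package is importable):

```python
import torch

from tts.tools import packmask_2d, sequence_mask

print(sequence_mask(torch.tensor([2, 4])).int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]])

# Text groups of 2 and 1 tokens aligned to audio groups of 3 and 2 frames.
print(packmask_2d([2, 1], [3, 2]).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1]])
```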
diff --git a/tts/train_playhead.py b/tts/train_playhead.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89558b0558487583b8b048cf3f7b49a557a6665
--- /dev/null
+++ b/tts/train_playhead.py
@@ -0,0 +1,203 @@
+import math
+
+import hydra
+import pytorch_lightning as ptl
+import torch
+from omegaconf import DictConfig
+from super_monotonic_align import maximum_path
+from torch.optim.lr_scheduler import LambdaLR
+
+from model.config import PlayHeadConfig
+from playhead import PlayHead
+from train_tts import TrainARTTS
+
+
+def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return step / max(1, warmup_steps)
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
+
+ return lr_lambda
+
+
+def expand(x, r):
+ b, n, d = x.shape
+ return x.unsqueeze(2).repeat(1, 1, r, 1).reshape(b, r * n, d)
+
+
+class TrainPlayHead(ptl.LightningModule):
+ def __init__(
+ self,
+ tts_checkpoint_path: str,
+ playhead_config: PlayHeadConfig,
+ learning_rate: float = 5e-4,
+ end_learning_rate: float | None = None,
+ weight_decay: float = 0.1,
+ betas: tuple[float, float] = (0.9, 0.999),
+ n_warmup_steps: int = 500,
+ n_training_steps: int = 300000,
+ ):
+ super(TrainPlayHead, self).__init__()
+
+ cfg = playhead_config
+ self.learning_rate = learning_rate
+ self.weight_decay = weight_decay
+ self.betas = betas
+ self.n_warmup_steps = n_warmup_steps
+ self.n_training_steps = n_training_steps
+ self.selected_cross_attention_heads = cfg.selected_cross_attention_heads
+ self.avg_pool_stride = cfg.avg_pool_stride
+ self.target_lag = cfg.target_lag
+
+ self.save_hyperparameters()
+
+ self.model = PlayHead(playhead_config)
+ tts_lightning_module = TrainARTTS.load_from_checkpoint(tts_checkpoint_path)
+ self.tts_model = tts_lightning_module.model.eval()
+ for p in self.tts_model.parameters():
+ p.requires_grad = False
+
+ def on_train_epoch_start(self):
+ if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
+ self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
+
+ def save_model_weights_and_config(
+ self,
+ dir: str | None,
+ model_filename: str = "model.st",
+ config_filename: str = "config.json",
+ ):
+ # cfg = self.hparams.config
+ # Path(dir).mkdir(exist_ok=True)
+ # model_path = Path(dir) / model_filename
+ # save_file(self.model.state_dict(), model_path)
+ # with open(Path(dir) / config_filename, "w") as f:
+ # json.dump(asdict(cfg), f, indent=2)
+ pass
+
+ def step(self, batch, batch_idx: int, validation: bool = False):
+ text_token = batch["text_token"]
+ audio_token = batch["audio_token"].squeeze(2)
+ crossatt_mask = batch["crossatt_mask"]
+ text_rel_pos = batch["text_rel_pos"]
+ encoder_mask = batch["encoder_mask"]
+ stop_token = batch.get("stop_token")
+ text_stop_token = batch.get("text_stop_token")
+ crossatt_rel_pos = batch.get("crossatt_rel_pos")
+ logits_mask = batch["y_mask"]
+
+ with torch.inference_mode():
+ _ = self.tts_model(
+ text_ids=text_token,
+ audio_inputs=audio_token,
+ text_mask=encoder_mask,
+ audio_mask=logits_mask,
+ crossatt_mask=crossatt_mask,
+ crossatt_rel_pos=crossatt_rel_pos,
+ stop_tokens=stop_token,
+ text_rel_pos=text_rel_pos,
+ text_stop_tokens=text_stop_token,
+ )
+
+ atts = []
+
+        for layer in self.tts_model.audio_decoder.decoder_layers:
+            if layer.crossatt is not None:
+                atts.append(layer.crossatt.att)
+
+ num_sinks = self.tts_model.num_sink_tokens
+ selected_ca_heads = torch.stack(
+ [
+ atts[i][:, j].transpose(-1, -2)
+ for i, j in self.selected_cross_attention_heads
+ ]
+ )
+
+ summed_ca = selected_ca_heads.sum(0)
+
+ avg_pool_ca = torch.nn.functional.avg_pool1d(
+ summed_ca[:, num_sinks:].transpose(-1, -2),
+ self.avg_pool_stride,
+ stride=self.avg_pool_stride,
+ ceil_mode=True,
+ ).transpose(-1, -2)
+
+ mas_from_avg_pool = maximum_path(
+ avg_pool_ca.clone(),
+ mask=crossatt_mask[:, :-1, :: self.avg_pool_stride].transpose(-1, -2),
+ )
+ target = torch.arange(mas_from_avg_pool.shape[1]).to(mas_from_avg_pool.device)
+ if self.target_lag > 0:
+ lag = self.target_lag
+ mas_from_avg_pool = torch.roll(mas_from_avg_pool, lag, dims=2)
+ mas_from_avg_pool[:, 0, :lag] = 1.0
+ mas_from_avg_pool[:, 1:, :lag] = 0.0
+ # logits_mask[:, :lag] = False
+ target = (mas_from_avg_pool * target[:, None]).max(dim=1).values
+
+ sink_ca = summed_ca[:, :num_sinks]
+
+ input_ca = torch.cat((sink_ca, avg_pool_ca), dim=1)
+ target = target % self.model.cycle_len
+
+ return self.model(input_ca, target, logits_mask[:, :-1]), input_ca, target
+
+ def training_step(self, batch, idx):
+ losses, _, _ = self.step(batch, idx)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
+ total_loss += loss
+ self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def validation_step(self, batch, idx):
+ losses, _, _ = self.step(batch, idx)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
+ total_loss += loss
+ self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def configure_optimizers(self):
+ params = [
+ {
+ "params": self.model.parameters(),
+ "weight_decay": self.weight_decay,
+ }
+ ]
+ opt = torch.optim.AdamW(
+ params,
+ lr=self.learning_rate,
+ betas=self.betas,
+ )
+ scheduler = LambdaLR(
+ opt,
+ lr_lambda=cosine_schedule_with_warmup(
+ warmup_steps=self.hparams.n_warmup_steps,
+ total_steps=self.hparams.n_training_steps,
+ start_lr=self.hparams.learning_rate,
+ end_lr=self.hparams.learning_rate * 0.1,
+ ),
+ )
+ return [opt], [{"scheduler": scheduler, "interval": "step"}]
+
+
+@hydra.main(config_path="playhead_configs/", config_name="config", version_base="1.3")
+def main(cfg: DictConfig):
+ ptl.seed_everything(cfg.seed_everything)
+
+ model = hydra.utils.instantiate(cfg.model)
+    cfg.experiment_name = "PlayHead"
+ datamodule = hydra.utils.instantiate(cfg.data)
+ trainer = hydra.utils.instantiate(cfg.trainer)
+
+ trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tts/train_tts.py b/tts/train_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..76fa749b0c257d52ee5521adaa84cfbf0a92bad0
--- /dev/null
+++ b/tts/train_tts.py
@@ -0,0 +1,227 @@
+import json
+import math
+from dataclasses import asdict
+from pathlib import Path
+
+import hydra
+import numpy as np
+import pytorch_lightning as ptl
+import torch
+from omegaconf import DictConfig, ListConfig, OmegaConf
+from safetensors.torch import save_file
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from transformers import get_cosine_schedule_with_warmup
+
+from .model.config import TTSConfig
+from .model.prediction_head import VelocityHead
+from .tts import ARTTSModel
+
+
+def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return step / max(1, warmup_steps)
+ progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
+
+ return lr_lambda
+
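+# Usage sketch (illustrative; `model` is a stand-in for any nn.Module): the returned
+# lr_lambda is a multiplier on the optimizer's base LR, warming up linearly from 0
+# to 1 over warmup_steps and then decaying along a cosine from 1 to end_lr / start_lr.
+#
+#   >>> opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
+#   >>> sched = LambdaLR(opt, cosine_schedule_with_warmup(500, 300_000, 5e-4, 5e-5))
+#   >>> # call sched.step() after every optimizer step, as in configure_optimizers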
+
+class TrainARTTS(ptl.LightningModule):
+ def __init__(
+ self,
+ config: TTSConfig,
+ quant_layer: list[int],
+ tie_embed: bool = False,
+ learning_rate: float = 5e-4,
+ end_learning_rate: float | None = None,
+ weight_decay: float = 0.1,
+ betas: tuple[float, float] = (0.9, 0.999),
+ n_warmup_steps: int = 500,
+ n_training_steps: int = 300000,
+ mask_text_p: float = 0.0,
+ load_weights: str | None = None,
+ stop_token_weight: float | None = None,
+ stop_loss_factor: float = 0.1,
+ stop_loss_warmup: tuple[int, int] | None = None,
+ ):
+ super(TrainARTTS, self).__init__()
+
+ self.learning_rate = learning_rate
+ self.weight_decay = weight_decay
+ self.betas = betas
+ self.n_warmup_steps = n_warmup_steps
+ self.n_training_steps = n_training_steps
+ self.stop_token_weight = stop_token_weight
+ self.stop_loss_factor = stop_loss_factor
+
+ self.save_hyperparameters()
+
+ self.model = ARTTSModel(config)
+
+        if load_weights is not None:
+            checkpoint = torch.load(load_weights, map_location="cpu")
+            self.load_state_dict(checkpoint["state_dict"], strict=False)
+
+ def on_train_epoch_start(self):
+ if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
+ self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
+
+ def save_model_weights_and_config(
+ self,
+ dir: str | None,
+ model_filename: str = "model.st",
+ config_filename: str = "config.json",
+ ):
+ def to_builtin(obj):
+ if isinstance(obj, dict):
+ return {k: to_builtin(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [to_builtin(v) for v in obj]
+ elif isinstance(obj, ListConfig):
+ return [to_builtin(v) for v in obj]
+ elif isinstance(obj, DictConfig):
+ return {k: to_builtin(v) for k, v in obj.items()}
+ else:
+ return obj
+
+ cfg = asdict(self.hparams.config)
+ cfg = to_builtin(cfg)
+        for k, v in cfg.items():
+            if isinstance(v, (ListConfig, DictConfig)):
+                cfg[k] = OmegaConf.to_container(v, resolve=True)
+ Path(dir).mkdir(exist_ok=True)
+ model_path = Path(dir) / model_filename
+ save_file(self.model.state_dict(), model_path)
+ with open(Path(dir) / config_filename, "w") as f:
+ json.dump(cfg, f, indent=2)
+
+ def step(self, batch, batch_idx: int, validation: bool = False):
+ text_token = batch["text_token"]
+ audio_token = batch["audio_token"].squeeze(2)
+ crossatt_mask = batch.get("crossatt_mask")
+ text_rel_pos = batch.get("text_rel_pos")
+ encoder_mask = batch.get("encoder_mask")
+ stop_token = batch.get("stop_token")
+ text_stop_token = batch.get("text_stop_token")
+ crossatt_rel_pos = batch.get("crossatt_rel_pos")
+ logits_mask = batch.get("y_mask")
+
+ pre_logits = self.model(
+ text_ids=text_token,
+ audio_inputs=audio_token,
+ text_mask=encoder_mask,
+ audio_mask=logits_mask,
+ crossatt_mask=crossatt_mask,
+ crossatt_rel_pos=crossatt_rel_pos,
+ stop_tokens=stop_token,
+ text_rel_pos=text_rel_pos,
+ text_stop_tokens=text_stop_token,
+ )
+ losses = {}
+        if validation and isinstance(self.model.prediction_head, VelocityHead):
+ # deterministic time conditioning during validation
+ t = (
+ torch.ones(pre_logits.shape[0], device=pre_logits.device)
+ * batch_idx
+ / self.trainer.num_val_batches[0]
+ )
+ losses |= self.model.prediction_head.compute_loss(
+ pre_logits,
+ audio_token[:, 1:],
+ mask=logits_mask[:, 1:] if logits_mask is not None else None,
+ t=t,
+ )
+ else:
+ losses |= self.model.prediction_head.compute_loss(
+ pre_logits,
+ audio_token[:, 1:],
+ mask=logits_mask[:, 1:] if logits_mask is not None else None,
+ )
+
+ if self.model.stop_prediction_head is not None and logits_mask is not None:
+ if stop_token is None:
+ stop_token = nn.functional.pad(
+ (~logits_mask)[:, 2:].to(pre_logits), (0, 1)
+ )
+ else:
+ stop_token = stop_token[:, 1:]
+ mask = logits_mask[:, 1:]
+ losses |= self.model.stop_prediction_head.compute_loss(
+ pre_logits[mask],
+ stop_token[mask],
+ )
+
+ return losses
+
+ def training_step(self, batch, idx):
+ losses = self.step(batch, idx)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
+ if "stop" in name:
+ if self.hparams.stop_loss_warmup is not None:
+ alpha, beta = self.hparams.stop_loss_warmup
+ warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
+ else:
+ warmup = 1.0
+ loss *= self.stop_loss_factor * warmup
+ total_loss += loss
+ self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def validation_step(self, batch, idx):
+ losses = self.step(batch, idx, validation=True)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
+ total_loss += loss
+ self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def configure_optimizers(self):
+ params = [
+ {
+ "params": self.model.parameters(),
+ "weight_decay": self.weight_decay,
+ }
+ ]
+ opt = torch.optim.AdamW(
+ params,
+ lr=self.learning_rate,
+ betas=self.betas,
+ )
+ # scheduler = get_cosine_schedule_with_warmup(
+ # opt,
+ # num_warmup_steps=self.n_warmup_steps,
+ # num_training_steps=self.n_training_steps,
+ # )
+ scheduler = LambdaLR(
+ opt,
+ lr_lambda=cosine_schedule_with_warmup(
+ warmup_steps=self.hparams.n_warmup_steps,
+ total_steps=self.hparams.n_training_steps,
+ start_lr=self.hparams.learning_rate,
+ end_lr=self.hparams.learning_rate * 0.1,
+ ),
+ )
+ return [opt], [{"scheduler": scheduler, "interval": "step"}]
+
+
+@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3")
+def main(cfg: DictConfig):
+ ptl.seed_everything(cfg.seed_everything)
+
+ model = hydra.utils.instantiate(cfg.model)
+ cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
+ datamodule = hydra.utils.instantiate(cfg.data)
+ trainer = hydra.utils.instantiate(cfg.trainer)
+
+ trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tts/tts.py b/tts/tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..b082b523ad655f465f227d6b8d4162ccfcca7366
--- /dev/null
+++ b/tts/tts.py
@@ -0,0 +1,497 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+from safetensors.torch import load_file
+from torch import nn
+from tqdm import tqdm
+
+from tts.model.config import TTSConfig
+from tts.model.prediction_head import (ContinuousHead, LogitsHead,
+ StopPredictionHead, VelocityHead)
+from tts.model.registry import DECODER_REGISTRY, ENCODER_REGISTRY
+from tts.tools import path_matrix, widen_alignment
+
+
+def collect_heads(cache, selected_heads, last=True):
+ if last:
+ return torch.cat(
+ [
+ cache[layer]["crossatt_weights"][:, [head], -1]
+ for layer, head in selected_heads
+ ],
+ dim=1,
+ )[:, :, None]
+ else:
+ return torch.cat(
+ [
+ cache[layer]["crossatt_weights"][:, [head]]
+ for layer, head in selected_heads
+ ],
+ dim=1,
+ )
+
+
+def mask_from_abs_pos(abs_pos, text_len, expand_factor, width=(5, 1)):
+ exp_ca_mask = path_matrix(abs_pos, text_len)
+ exp_ca_mask = widen_alignment(exp_ca_mask, width=width, axis="S")
+ exp_ca_mask = expand(exp_ca_mask, expand_factor)
+ return exp_ca_mask
+
+
+def expand(x, r):
+ b, n, d = x.shape
+ x = x.unsqueeze(-1).repeat(1, 1, 1, r).reshape(b, n, r * d)
+ return x
+
+
+class ARTTSModel(nn.Module):
+ def __init__(self, cfg: TTSConfig):
+ super().__init__()
+ self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim)
+ if cfg.audio_input_type == "discrete":
+ self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim)
+ self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size)
+ elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion:
+ self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
+ self.prediction_head = VelocityHead(
+ cfg.decoder_cfg.dim,
+ cfg.audio_embed_size,
+ cfg.diffusion_head_num_layers,
+ )
+ elif cfg.audio_input_type == "continuous":
+ self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
+ self.prediction_head = ContinuousHead(
+ cfg.decoder_cfg.dim,
+ cfg.audio_embed_size,
+ )
+
+ self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg)
+ self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg)
+
+ self.stop_token_embd = None
+ self.stop_prediction_head = None
+
+ if cfg.stop_prediction_head:
+ if cfg.stop_token_embd:
+ self.stop_token_embd = nn.Embedding(2, cfg.dim, padding_idx=0)
+ self.stop_prediction_head = StopPredictionHead(cfg.dim)
+
+ if cfg.num_sink_tokens > 0:
+ self.sink_tokens = nn.Parameter(
+ torch.randn(cfg.num_sink_tokens, cfg.dim) * 0.02, requires_grad=True
+ )
+ else:
+ self.sink_tokens = None
+
+ self.disabled_crossatt_head_idx = cfg.disabled_crossatt_head_idx
+
+ @property
+ def num_sink_tokens(self):
+ if self.sink_tokens is None:
+ return 0
+ else:
+ n_sink, _ = self.sink_tokens.shape
+ return n_sink
+
+ @classmethod
+ def instantiate_from_config(cls, config):
+ for k in config.keys():
+ if k == "decoder_cfg":
+ config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k])
+ if k == "encoder_cfg":
+ config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k])
+ config = TTSConfig(**config)
+ return ARTTSModel(config), config
+
+ @classmethod
+ def from_pretrained_local(
+ cls,
+ path: str,
+ config_filename: str = "config.json",
+ model_filename: str = "model.st",
+ device: str = "cpu",
+ ):
+ with open(Path(path) / config_filename, "r") as f:
+ config = json.load(f)
+ model, config = cls.instantiate_from_config(config)
+ state_dict = load_file(Path(path) / model_filename, device=device)
+ model.load_state_dict(state_dict)
+ return model
+
+ def _get_query(self, x: torch.Tensor, *args):
+ input_audio_embd = self.audio_embd(x)
+ return self.audio_decoder._get_query(input_audio_embd, *args)
+
+ def forward(
+ self,
+ text_ids: torch.LongTensor,
+ audio_inputs: torch.Tensor,
+ text_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ stop_tokens: torch.Tensor | None = None,
+ text_stop_tokens: torch.Tensor | None = None,
+ text_rel_pos: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ crossatt_rel_pos: torch.Tensor | None = None,
+ n_first_layers: int | None = None,
+ cache: Any | None = None,
+ ):
+ input_text_embd = self.text_embd(text_ids)
+ input_audio_embd = self.audio_embd(audio_inputs[:, :-1])
+ if self.stop_token_embd is not None:
+ if stop_tokens is not None:
+ stop_tokens_embd = self.stop_token_embd(stop_tokens)
+ input_audio_embd += stop_tokens_embd[:, :-1]
+
+ text_hidden_states = self.text_encoder(
+ input_text_embd,
+ mask=text_mask,
+ text_rel_pos=text_rel_pos,
+ )
+ if self.disabled_crossatt_head_idx is not None and crossatt_mask is not None:
+ crossatt_mask_list = []
+            n_sink = self.num_sink_tokens
+ for layer in self.audio_decoder.decoder_layers:
+ if layer.crossatt is not None:
+ h = layer.crossatt.heads
+ crossatt_layer_mask = (
+ crossatt_mask.unsqueeze(1).repeat(1, h, 1, 1).clone()
+ )
+ crossatt_layer_mask = torch.nn.functional.pad(
+ crossatt_layer_mask,
+ (n_sink, 0),
+ value=True,
+ )
+ crossatt_mask_list.append(crossatt_layer_mask[:, :, :-1])
+
+ else:
+ crossatt_mask_list.append(None)
+
+ for layer, head in self.disabled_crossatt_head_idx:
+ crossatt_mask_list[layer][:, head, :, n_sink:] = False
+
+ crossatt_mask = crossatt_mask_list
+ else:
+ if self.sink_tokens is not None:
+ n_sink, _ = self.sink_tokens.shape
+ if crossatt_mask is not None:
+ crossatt_mask = torch.nn.functional.pad(
+ crossatt_mask,
+ (n_sink, 0),
+ value=True,
+ )
+ crossatt_mask = crossatt_mask[:, :-1]
+ if self.sink_tokens is not None:
+ sink_tokens = self.sink_tokens[None, :].repeat(
+ text_hidden_states.shape[0], 1, 1
+ )
+ text_hidden_states = torch.cat(
+ (sink_tokens, text_hidden_states),
+ dim=1,
+ )
+
+ if n_first_layers is not None:
+ pre_logits = self.audio_decoder.forward_first_n_layers(
+ text_hidden_states,
+ input_audio_embd,
+ n_first_layers,
+ crossatt_mask=crossatt_mask,
+ )
+ else:
+ pre_logits = self.audio_decoder(
+ text_hidden_states,
+ input_audio_embd,
+ crossatt_mask=crossatt_mask,
+ cache=cache,
+ )
+ return pre_logits
+
+ def generate(
+ self,
+ text_ids: torch.LongTensor,
+ prefix: torch.Tensor,
+ text_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ text_rel_pos: torch.LongTensor | None = None,
+ teacher_force: torch.Tensor | None = None,
+ unfold_ref: bool = False,
+ max_seq_len: int = 200,
+ device: str = "cuda",
+ sampling_params: dict | None = None,
+ stop_threshold: float = 0.5,
+ cache: Any | None = None,
+ do_not_stop: bool = False,
+ ):
+ if sampling_params is None:
+ sampling_params = {}
+ if text_ids.ndim == 1:
+ text_ids = text_ids.unsqueeze(0)
+
+ batch_size = text_ids.shape[0]
+
+ input_text_embd = self.text_embd(text_ids)
+ text_hidden_states = self.text_encoder(
+ input_text_embd,
+ text_rel_pos=text_rel_pos,
+ mask=text_mask,
+ )
+ prefix_embd = self.audio_embd(prefix)
+
+ if self.sink_tokens is not None:
+ sink_tokens = self.sink_tokens[None, :].repeat(
+ text_hidden_states.shape[0], 1, 1
+ )
+ text_hidden_states = torch.cat(
+ (sink_tokens, text_hidden_states),
+ dim=1,
+ )
+ if crossatt_mask is not None:
+ n_sink, _ = self.sink_tokens.shape
+ crossatt_mask = torch.nn.functional.pad(
+ crossatt_mask,
+ (n_sink, 0),
+ value=True,
+ )
+
+ if cache is None:
+ cache = self.audio_decoder.init_cache(
+ max_seq_len + prefix_embd.shape[1], device
+ )
+
+ stop_status = torch.zeros(batch_size, device=device).bool()
+        stop_idx = torch.ones(batch_size, device=device).long() * max_seq_len
+
+ preds = []
+ pre_prediction = self.audio_decoder.prefill(
+ text_hidden_states,
+ prefix_embd,
+ cache=cache,
+ )
+ prediction = self.prediction_head.predict(
+ pre_prediction[:, [-1]], **sampling_params
+ )
+ prediction_embd = self.audio_embd(prediction)
+
+ for i in tqdm(range(max_seq_len)):
+ pre_prediction = self.audio_decoder.decode_one(
+ text_hidden_states,
+ prediction_embd,
+ cache,
+ crossatt_mask=crossatt_mask,
+ )
+ if unfold_ref:
+ pre_prediction, pre_prediction_ref = pre_prediction.chunk(2)
+ else:
+ pre_prediction_ref = None
+            prediction = self.prediction_head.predict(
+                pre_prediction,
+                pre_prediction_ref=pre_prediction_ref,
+                **sampling_params,
+            )
+ prediction_embd = self.audio_embd(prediction)
+ if unfold_ref:
+ prediction_embd = prediction_embd.repeat(2, 1, 1)
+ if teacher_force is not None:
+ b, n, d = teacher_force.shape
+ if i < n:
+ prediction_embd = self.audio_embd(teacher_force[:, [i]])
+
+ preds.append(prediction)
+ if self.stop_prediction_head is not None:
+                stop_pred = self.stop_prediction_head(pre_prediction)
+                stop_pred = stop_pred.squeeze(-1).squeeze(-1)
+                stop_signal = stop_pred > stop_threshold
+                stop_status |= stop_signal
+                # remember the first step at which each sequence crossed the threshold
+                stop_idx[stop_signal & (stop_idx > i)] = i
+ if stop_status.prod():
+ if self.stop_token_embd is not None:
+ st_embd = self.stop_token_embd(
+ torch.ones(1, 1, device=device).int()
+ )
+ prediction_embd += st_embd
+ if not do_not_stop:
+ break
+ else:
+ print(f"STOP: {i}")
+
+ full_prediction = torch.cat(preds, dim=1)
+        full_prediction = [
+            x[: stop_idx[i]][None] for i, x in enumerate(full_prediction.unbind())
+        ]
+
+ return cache, full_prediction
+
+
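+# The triple-quoted block below is an earlier draft of playhead-guided generation,
+# kept as a string for reference. It refers to names that are not defined in this
+# module (e.g. text_stop_token_embd, PlayHead), so it is not executable as written.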
+"""
+ def generate_with_playhead(
+ self,
+ text_ids: torch.LongTensor,
+ prefix: torch.Tensor,
+ playhead_model: PlayHead,
+ selected_heads_idx: list[tuple[int, int]],
+ text_stop_tokens: torch.LongTensor | None = None,
+ text_mask: torch.Tensor | None = None,
+ text_rel_pos: torch.LongTensor | None = None,
+ teacher_force: torch.Tensor | None = None,
+ max_seq_len: int = 200,
+ device: str = "cuda",
+ sampling_params: dict | None = None,
+ stop_threshold: float = 0.5,
+ do_not_stop: bool = False,
+ width: tuple[int, int] = (5, 1),
+ abs_pos_start: int = 0,
+ stop_end_distance_threshold: int = 5,
+ ):
+ if sampling_params is None:
+ sampling_params = {}
+ if text_ids.ndim == 1:
+ text_ids = text_ids.unsqueeze(0)
+
+ input_text_embd = self.text_embd(text_ids)
+ if self.text_stop_token_embd is not None:
+ if text_stop_tokens is not None:
+ text_stop_tokens_embd = self.text_stop_token_embd(text_stop_tokens)
+ input_text_embd += text_stop_tokens_embd
+ text_hidden_states = self.text_encoder(
+ input_text_embd,
+ text_rel_pos=text_rel_pos,
+ mask=text_mask,
+ )
+ prefix_embd = self.audio_embd(prefix)
+
+ text_len = text_hidden_states.shape[1]
+
+ if self.sink_tokens is not None:
+ sink_tokens = self.sink_tokens[None, :].repeat(
+ text_hidden_states.shape[0], 1, 1
+ )
+ text_hidden_states = torch.cat(
+ (sink_tokens, text_hidden_states),
+ dim=1,
+ )
+
+ cache = self.audio_decoder.init_cache(max_seq_len, device)
+ preds = []
+ pre_prediction = self.audio_decoder.prefill(
+ text_hidden_states,
+ prefix_embd,
+ cache=cache,
+ )
+ text_freqs = None
+ prediction = self.prediction_head.predict(
+ pre_prediction[:, [-1]], **sampling_params
+ )
+ prediction_embd = self.audio_embd(prediction)
+ preds.append(prediction)
+
+ playhead_cache = playhead_model.init_cache()
+ previous_position = torch.zeros(1, device=device)
+ abs_pos = torch.ones(1, 1, device=device).long() * abs_pos_start
+
+ selected_heads_frame = collect_heads(cache, selected_heads_idx, last=False)
+ selected_heads_frame = selected_heads_frame.sum(1).transpose(-1, -2)
+
+ pos_preds = []
+ steps = []
+ expand_crossatt_mask = []
+
+ for i in tqdm(range(selected_heads_frame.shape[2])):
+ pred, step = playhead_model.predict(
+ selected_heads_frame[..., [i]],
+ cache=playhead_cache,
+ previous_position=previous_position,
+ )
+ previous_position = pred
+ abs_pos += step
+
+ pos_preds.append(pred)
+ steps.append(step)
+ exp_ca_mask = mask_from_abs_pos(
+ abs_pos,
+ (text_len // playhead_model.avg_pool_stride) + 1,
+ playhead_model.avg_pool_stride,
+ width=width,
+ )
+ exp_ca_mask = torch.nn.functional.pad(
+ exp_ca_mask, (self.num_sink_tokens, 0), value=True
+ ).bool()[..., : text_len + self.num_sink_tokens]
+ expand_crossatt_mask.append(exp_ca_mask)
+
+ print("starting at: ", abs_pos.item())
+
+ # pos_pred, step = playhead_model.predict(
+ # selected_heads_frame,
+ # cache=playhead_cache,
+ # previous_position=previous_position,
+ # )
+ # previous_position = pos_pred[:, [-1]]
+ # abs_pos += step
+ # exp_ca_mask = mask_from_abs_pos(
+ # abs_pos,
+ # (text_len // playhead_model.avg_pool_stride) + 1,
+ # playhead_model.avg_pool_stride,
+ # width=width,
+ # )
+ # expand_crossatt_mask.append(exp_ca_mask)
+ # steps.append(step)
+ # pos_preds.append(pos_pred)
+
+ progress_bar = tqdm(range(max_seq_len))
+ for i in progress_bar:
+ pre_prediction = self.audio_decoder.decode_one(
+ text_hidden_states,
+ prediction_embd,
+ cache,
+ # text_freqs=text_freqs,
+ crossatt_mask=exp_ca_mask,
+ )
+ prediction = self.prediction_head.predict(pre_prediction, **sampling_params)
+ prediction_embd = self.audio_embd(prediction)
+ if teacher_force is not None:
+ b, n, d = teacher_force.shape
+ if i < n:
+ prediction_embd = self.audio_embd(teacher_force[:, [i]])
+
+ ### PLAYHEAD ========================
+ selected_heads_frame = (
+ collect_heads(cache, selected_heads_idx).sum(1).transpose(-1, -2)
+ )
+ pos_pred, step = playhead_model.predict(
+ selected_heads_frame,
+ cache=playhead_cache,
+ previous_position=previous_position,
+ )
+ previous_position = pos_pred
+ abs_pos += step
+ exp_ca_mask = mask_from_abs_pos(
+ abs_pos,
+ (text_len // playhead_model.avg_pool_stride) + 1,
+ playhead_model.avg_pool_stride,
+ width=width,
+ )
+ exp_ca_mask = torch.nn.functional.pad(
+ exp_ca_mask, (self.num_sink_tokens, 0), value=True
+ ).bool()[..., : text_len + self.num_sink_tokens]
+ expand_crossatt_mask.append(exp_ca_mask)
+ steps.append(step)
+ pos_preds.append(pos_pred)
+ # =================================
+ preds.append(prediction)
+ if self.stop_prediction_head is not None:
+ stop_pred = self.stop_prediction_head(pre_prediction)
+ if stop_pred > stop_threshold:
+ dist = np.abs(
+ abs_pos.cpu().item() * playhead_model.avg_pool_stride - text_len
+ )
+ progress_bar.set_postfix(
+ {"stop": f"pos: {abs_pos.cpu().item()}; dist{dist}"}
+ )
+ if dist < stop_end_distance_threshold and not do_not_stop:
+ break
+
+ progress_bar.set_postfix({"position": abs_pos.cpu().item()})
+
+ full_prediction = torch.cat(preds, dim=1)
+ expand_crossatt_mask = torch.cat(expand_crossatt_mask, dim=1)
+ print(expand_crossatt_mask.shape)
+
+ return cache, full_prediction, expand_crossatt_mask, steps, pos_preds
+"""
diff --git a/tts/tts/groupdataset.py b/tts/tts/groupdataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7a7633636e9c0fc02dfaf501d694315990554bd
--- /dev/null
+++ b/tts/tts/groupdataset.py
@@ -0,0 +1,835 @@
+import math
+import random
+from abc import ABC, abstractmethod
+from bisect import bisect_left
+from functools import reduce
+from itertools import accumulate
+from random import choice, choices, randint, sample
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import pytorch_lightning as ptl
+import torch
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
+from einops import rearrange, repeat
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import BatchSampler, DataLoader, Sampler, SubsetRandomSampler
+from transformers import PreTrainedTokenizerFast
+
+from tools import packmask_2d, pad_2d_sequence, sequence_mask
+
+
+def delay_rvq(
+ code,
+ head_token: int = -2,
+ tail_token: int = -3,
+):
+ q, _ = code.shape
+ extension = torch.ones((q, q + 1)).tril() * head_token
+ extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token
+ extension = torch.flip(extension, (1,))
+ extended_code = torch.cat((code, extension), axis=1)
+ for i in range(q):
+ extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)
+
+ return extended_code.long()
+
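+# Same delay pattern as the delay_rvq helper in the shared tools module; see the
+# round-trip sketch there for expected input/output shapes.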
+
+class BucketSampler(Sampler[List[int]]):
+ def __init__(
+ self,
+ buckets: List[List[int]],
+ batch_sizes: List[int] | int,
+ bucket_sampling_weights: List[Tuple[float]] = None,
+ drop_last: bool = True,
+ distributed: bool = True, # TODO - implement not distributed as well
+ sample_bucket: Optional[int] = None,
+ seed: int = 123,
+ epoch_seed: bool = True,
+ ):
+ if type(batch_sizes) is int:
+ batch_sizes = [batch_sizes] * len(buckets)
+ else:
+ assert len(buckets) == len(batch_sizes)
+
+ if bucket_sampling_weights is not None:
+ assert len(bucket_sampling_weights) == len(batch_sizes)
+ self.bucket_sampling_weights = bucket_sampling_weights
+ self.num_replicas = torch.distributed.get_world_size()
+ self.rank = torch.distributed.get_rank()
+ self.buckets = [
+ b[self.rank : len(b) - len(b) % self.num_replicas : self.num_replicas]
+ for b in buckets
+ ]
+ self.num_samples = [len(b) // self.num_replicas for b in buckets]
+ self.batch_sizes = batch_sizes
+ self.total_sizes = [
+ ns // bs for ns, bs in zip(self.num_samples, self.batch_sizes)
+ ]
+ self.drop_last = drop_last
+ self.seed = seed
+ self.epoch = 0
+ self.sample_bucket = sample_bucket
+ self.epoch_seed = epoch_seed
+ self.batch_size = batch_sizes[0]
+
+ def set_epoch(self, epoch: int):
+ self.epoch = epoch
+
+ def __len__(self):
+ return sum(self.total_sizes)
+
+ def __iter__(self):
+ generator = torch.Generator()
+ generator.manual_seed(self.seed + self.epoch * self.epoch_seed + self.rank)
+ pool = [
+ BatchSampler(
+ SubsetRandomSampler(b, generator=generator),
+ bs,
+ drop_last=self.drop_last,
+ )
+ for b, bs in zip(self.buckets, self.batch_sizes)
+ ]
+ pool = [iter(b) for b in pool]
+ weights = (
+ [w for w in self.bucket_sampling_weights]
+ if self.bucket_sampling_weights is not None
+ else None
+ )
+ while pool: # sample until all buckets are done
+ idx, bucket = choices(list(enumerate(pool)), weights=weights)[0]
+ try:
+ batch = next(bucket)
+ yield batch
+ except StopIteration:
+ pool.pop(idx) # if bucket is done, throw it
+ if weights is not None:
+ weights.pop(idx)
+
+
+class DatasetFactory(ABC):
+ @abstractmethod
+ def build(self):
+ pass
+
+
+def to_segments_dataset(dataset):
+ dataset = dataset.map(
+ lambda x, y: {"segments": [{"start": 0.0, "end": y.shape[-1] / 75, "text": x}]},
+ input_columns=["text", "audio_token"],
+ )
+ return dataset
+
+
+class LibriTTS_SyntheticContinuation(DatasetFactory):
+ def __init__(
+ self, path, min_dur=0.0, max_dur=30.0, max_cer=0.1, llm_long=True, repeat=1
+ ):
+ self.path = path
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.max_cer = max_cer
+ self.llm_long = llm_long
+
+ def build(self):
+ ds = load_from_disk(self.path)
+ ds = ds.filter(lambda x: x < self.max_cer, input_columns="cer")
+ ds = ds.map(
+ lambda x: {"audio_duration": len(x[0][0]) / 75}, input_columns="audio_token"
+ )
+ ds = ds.filter(
+ lambda x: x < self.max_dur and x > self.min_dur,
+ input_columns="audio_duration",
+ ).select_columns(["segments", "audio_token", "audio_duration"])
+ if self.llm_long:
+ llm_long = (
+ LibriLight_Medium_Long_WavTokenizer()
+ .build()["train"]
+ .shuffle()
+ .take(100000)
+ )
+ llm_long = llm_long.map(
+ lambda x: {
+ "segments": [
+ {k: v[0] if type(v) is list else v for k, v in xx.items()}
+ for xx in x
+ ]
+ },
+ input_columns="segments",
+ )
+ # splits = ["train.clean.100", "train.clean.360", "train.other.500"]
+ # ltts = load_dataset("theodorr/lttsr_wavtokenizer", split=splits) + load_dataset("theodorr/ltts_wavtokenizer", split=splits)
+ # ltts = [x.map(lambda x: {"audio_duration": len(x[0][0])/75}, input_columns="audio_token").select_columns(["text_normalized", "audio_token", "audio_duration"]).rename_column("text_normalized", "text").filter(lambda dur: dur > self.min_dur and dur < self.max_dur, input_columns="audio_duration") for x in ltts]
+ # mls10k_nemo = load_dataset("theodorr/mls10k_nemo", split="train").select_columns(["text", "audio_duration"])
+ # mls10k_wavtokenizer = load_dataset("theodorr/mls10k_wavtokenizer", split="train").select_columns(["audio_token"])
+ # mls = concatenate_datasets([mls10k_nemo, mls10k_wavtokenizer], axis=1).shuffle().take(100000)
+ ds = concatenate_datasets([ds, llm_long], axis=0).select_columns(
+ ["segments", "audio_token", "audio_duration"]
+ )
+
+ # side_datasets = concatenate_datasets(ltts, axis=0)
+ # side_datasets = to_segments_dataset(side_datasets)
+
+ ds = DatasetDict({"train": ds})
+
+ return ds
+
+
+class LibriLight_Medium_Long_WavTokenizer(DatasetFactory):
+ def __init__(self, min_dur=40.0, max_dur=60.0, min_segments=2):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.min_segments = min_segments
+
+ def build(self):
+ ds = load_dataset("theodorr/llm_nemo_wavtokenizer_long")
+ ds = ds.filter(
+ lambda x: x < self.max_dur and x > self.min_dur,
+ input_columns="audio_duration",
+ )
+ ds = ds.filter(lambda x: len(x) >= self.min_segments, input_columns="segments")
+ return ds
+
+
+class MLS_LTTS_SelfCheckpoint_WavTokenizer(DatasetFactory):
+ def __init__(self, path, min_dur=3.0, max_dur=26.0):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.text_dur_path = path
+
+ def build(self):
+ mls10k_nemo = load_dataset(
+ "theodorr/mls10k_nemo", split="train"
+ ).select_columns(["text", "audio_duration"])
+ mls10k_wavtokenizer = load_dataset(
+ "theodorr/mls10k_wavtokenizer", split="train"
+ ).select_columns(["audio_token"])
+ mls10k_text_dur = load_from_disk(self.text_dur_path + "/mls10knemo_align_ds")
+ mls = concatenate_datasets(
+ [mls10k_nemo, mls10k_wavtokenizer, mls10k_text_dur], axis=1
+ )
+
+ splits = ["train.clean.100", "train.clean.360", "train.other.500"]
+
+ ltts = load_dataset("theodorr/ltts_wavtokenizer")
+ lttsr = load_dataset("theodorr/lttsr_wavtokenizer")
+
+ ltts_text_dur = load_from_disk(self.text_dur_path + "/ltts_align_ds")
+ lttsr_text_dur = load_from_disk(self.text_dur_path + "/lttsr_align_ds")
+
+ ltts = [
+ concatenate_datasets([ltts[split], ltts_text_dur[split]], axis=1)
+ for split in splits
+ ]
+ lttsr = [
+ concatenate_datasets([lttsr[split], lttsr_text_dur[split]], axis=1)
+ for split in splits
+ ]
+
+ ltts = ltts + lttsr
+
+ ltts = [
+ x.map(
+ lambda x: {"audio_duration": len(x[0][0]) / 75},
+ input_columns="audio_token",
+ )
+ .select_columns(
+ [
+ "text_normalized",
+ "audio_token",
+ "audio_duration",
+ "text_token_duration",
+ ]
+ )
+ .rename_column("text_normalized", "text")
+ .filter(
+ lambda dur: dur > self.min_dur and dur < self.max_dur,
+ input_columns="audio_duration",
+ )
+ for x in ltts
+ ]
+
+ mls = concatenate_datasets([mls] + ltts, axis=0)
+
+ return DatasetDict({"train": mls})
+
+
+class MLS_SelfCheckpoint_WavTokenizer(DatasetFactory):
+ def __init__(self, path, min_dur=3.0, max_dur=26.0):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.text_dur_path = path
+
+ def build(self):
+ mls10k_nemo = load_dataset(
+ "theodorr/mls10k_nemo", split="train"
+ ).select_columns(["text", "audio_duration"])
+ mls10k_wavtokenizer = load_dataset(
+ "theodorr/mls10k_wavtokenizer", split="train"
+ ).select_columns(["audio_token"])
+ mls10k_text_dur = load_from_disk(self.text_dur_path)
+ mls = concatenate_datasets(
+ [mls10k_nemo, mls10k_wavtokenizer, mls10k_text_dur], axis=1
+ )
+ return DatasetDict({"train": mls})
+
+
+class Synthetic_Expresso(DatasetFactory):
+ def __init__(
+ self,
+ path,
+ min_dur=0.0,
+ max_dur=23.0,
+ keep_ltts=False,
+ cfg_scale: list[float] = None,
+ ):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.path = path
+ self.keep_ltts = keep_ltts
+ self.cfg_scale = cfg_scale
+
+ def build(self):
+ ds = load_from_disk(self.path)
+ if self.cfg_scale is not None:
+ self.cfg_scale.sort()
+ print(self.cfg_scale)
+ ds = ds.map(
+ lambda x: {"abs_style_intensity": bisect_left(self.cfg_scale, x)},
+ input_columns="cfg",
+ )
+ ds = ds.filter(
+ lambda x: x < self.max_dur and x > self.min_dur,
+ input_columns="audio_duration",
+ )
+ if self.keep_ltts:
+ splits = ["train.clean.360"]
+ ltts = load_dataset(
+ "theodorr/lttsr_wavtokenizer", split=splits
+ ) + load_dataset("theodorr/ltts_wavtokenizer", split=splits)
+ ltts = [
+ x.map(
+ lambda x: {"audio_duration": len(x[0][0]) / 75},
+ input_columns="audio_token",
+ )
+ .select_columns(["text_normalized", "audio_token", "audio_duration"])
+ .rename_column("text_normalized", "text")
+ .filter(
+ lambda dur: dur > self.min_dur and dur < self.max_dur,
+ input_columns="audio_duration",
+ )
+ for x in ltts
+ ]
+ ds = concatenate_datasets([ds, *ltts], axis=0)
+
+ return {"train": ds}
+
+
+class HiFiTTS2_AudioLatent(DatasetFactory):
+ def __init__(
+ self,
+ path: str = "hifitts2_vae8_dataset",
+ min_dur: float = 3.0,
+ max_dur: float = 20.1,
+ framerate: float = 25.0,
+ ):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+ self.path = path
+ self.framerate = framerate
+
+ def build(self):
+ dataset = load_from_disk(self.path).with_format("torch")
+ dataset = dataset.map(
+ lambda x: {"audio_duration": x.shape[1] / self.framerate},
+ input_columns="audio_latent",
+ ).filter(
+ lambda dur: dur > self.min_dur and dur < self.max_dur,
+ input_columns="audio_duration",
+ )
+ return DatasetDict({"train": dataset})
+
+
+class MLS_LibriTTS_WavTokenizer(DatasetFactory):
+ def __init__(self, min_dur=3.0, max_dur=20.1):
+ self.min_dur = min_dur
+ self.max_dur = max_dur
+
+ def build(self):
+ splits = ["train.clean.100", "train.clean.360", "train.other.500"]
+ ltts = load_dataset("theodorr/lttsr_wavtokenizer", split=splits)
+ lttsr = load_dataset("theodorr/ltts_wavtokenizer", split=splits)
+ ltts += lttsr
+ ltts = [
+ x.map(
+ lambda x: {"audio_duration": len(x[0][0]) / 75},
+ input_columns="audio_token",
+ )
+ .select_columns(["text_normalized", "audio_token", "audio_duration"])
+ .rename_column("text_normalized", "text")
+ .filter(
+ lambda dur: dur > self.min_dur and dur < self.max_dur,
+ input_columns="audio_duration",
+ )
+ for x in ltts
+ ]
+ mls10k_nemo = load_dataset(
+ "theodorr/mls10k_nemo", split="train"
+ ).select_columns(["text", "audio_duration"])
+ mls10k_wavtokenizer = load_dataset(
+ "theodorr/mls10k_wavtokenizer", split="train"
+ ).select_columns(["audio_token"])
+ mls = concatenate_datasets([mls10k_nemo, mls10k_wavtokenizer], axis=1)
+ return DatasetDict({"train": concatenate_datasets([mls] + ltts, axis=0)})
+
+
+# def standalone_collate(batch, tokenizer):
+# audio_token, text = zip(*[(x["audio_token"], x["text"]) for x in batch])
+# #audio_token = [
+# # delay_rvq(
+# # x.squeeze() + 3,
+# # head_token=1,
+# # tail_token=2,
+# # ).T
+# # for x in audio_token
+# # ]
+# text_token = [torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text]
+# xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token))
+# x_mask, y_mask = map(lambda x: sequence_mask(x, device="cpu"), (torch.tensor(xlen), torch.tensor(ylen)))
+#
+# audio_token, text_token = map(lambda x: pad_sequence(x, batch_first=True, padding_value=0), (audio_token, text_token))
+#
+# encoder_mask = (x_mask.unsqueeze(1) * x_mask.unsqueeze(2))
+# crossatt_mask = (x_mask.unsqueeze(1) * y_mask.unsqueeze(2))
+#
+#
+#
+# return {
+# "text_token": text_token,
+# "audio_token": audio_token,
+# "crossatt_mask": crossatt_mask,
+# "encoder_mask": encoder_mask,
+# "y_mask": y_mask,
+# "x_len": xlen,
+# "y_len": ylen,
+# }
+def standalone_collate_latent(batch, tokenizer, abs_style_intensity=False):
+ audio_latent, text = zip(*[(x["audio_latent"], x["text"]) for x in batch])
+ audio_latent = [x.squeeze() for x in audio_latent]
+ text_token = [
+ torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text
+ ]
+ xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_latent))
+ x_mask, y_mask = map(
+ lambda x: sequence_mask(x, device="cpu"),
+ (torch.tensor(xlen), torch.tensor(ylen)),
+ )
+
+ audio_latent, text_token = map(
+ lambda x: pad_sequence(x, batch_first=True, padding_value=0.0),
+ (audio_latent, text_token),
+ )
+
+ encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
+ crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
+
+ if abs_style_intensity:
+ abs_style_intensity = [x["abs_style_intensity"] for x in batch]
+ abs_style_intensity = [
+ torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
+ ]
+ abs_style_intensity = torch.stack(abs_style_intensity)
+ else:
+ abs_style_intensity = None
+
+ return {
+ "text_token": text_token,
+ "audio_token": audio_latent,
+ "crossatt_mask": crossatt_mask,
+ "encoder_mask": encoder_mask,
+ "y_mask": y_mask,
+ "x_len": xlen,
+ "y_len": ylen,
+ "abs_style_intensity": abs_style_intensity,
+ "text_rel_pos": None,
+ "audio_rel_pos": None,
+ "crossatt_pos": None,
+ }
+
+
+def standalone_collate(batch, tokenizer, abs_style_intensity=False):
+ audio_token, text = zip(*[(x["audio_token"], x["text"]) for x in batch])
+
+ audio_token = [
+ torch.cat(
+ (
+ torch.ones(*a_t.shape[:-1], 1),
+ a_t + 3,
+ torch.ones(*a_t.shape[:-1], 1) * 2,
+ ),
+ dim=-1,
+ )
+ .squeeze(0)
+ .transpose(-1, -2)
+ .long()
+ for a_t in audio_token
+ ]
+
+ text_token = [
+ torch.LongTensor(tokenizer.encode("[BOS]" + x + "[EOS]")) for x in text
+ ]
+ xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token))
+ x_mask, y_mask = map(
+ lambda x: sequence_mask(x, device="cpu"),
+ (torch.tensor(xlen), torch.tensor(ylen)),
+ )
+
+ audio_token, text_token = map(
+ lambda x: pad_sequence(x, batch_first=True, padding_value=0),
+ (audio_token, text_token),
+ )
+
+ encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
+ crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
+
+ if abs_style_intensity:
+ abs_style_intensity = [x["abs_style_intensity"] for x in batch]
+ abs_style_intensity = [
+ torch.zeros(1).long()[0] if x.isnan() else x for x in abs_style_intensity
+ ]
+ abs_style_intensity = torch.stack(abs_style_intensity)
+ else:
+ abs_style_intensity = None
+
+ return {
+ "text_token": text_token,
+ "audio_token": audio_token,
+ "crossatt_mask": crossatt_mask,
+ "encoder_mask": encoder_mask,
+ "y_mask": y_mask,
+ "x_len": xlen,
+ "y_len": ylen,
+ "abs_style_intensity": abs_style_intensity,
+ "text_rel_pos": None,
+ "audio_rel_pos": None,
+ "crossatt_pos": None,
+ }
+
+
+def random_log_breakpoints(seq: Sequence, a: int, b: int) -> List[int]:
+ """
+ Generate random breakpoints in a sequence where the gap X between
+ successive breakpoints satisfies log2(X) ~ Uniform[log2(a), log2(b)].
+ Gaps are then rounded to the nearest integer in [a, b].
+
+ Parameters
+ ----------
+ seq : Sequence
+ The input sequence in which to place breakpoints.
+ a : int
+ Minimum gap (>= 1).
+ b : int
+ Maximum gap (>= a).
+
+ Returns
+ -------
+ List[int]
+ Sorted list of breakpoint indices (0 < idx < len(seq)).
+ """
+ if a < 1 or b < a:
+ raise ValueError("Require 1 <= a <= b")
+ n = len(seq)
+ breakpoints: List[int] = []
+ pos = 0
+
+ while True:
+ # sample U ~ Uniform(log2(a), log2(b))
+ u = random.uniform(math.log2(a), math.log2(b))
+ # map back to X = 2^U, then round to nearest integer
+ x = 2**u
+ gap = int(math.floor(x + 0.5))
+
+ # enforce integer bounds exactly
+ gap = max(a, min(b, gap))
+
+ pos += gap
+ if pos >= n:
+ break
+ breakpoints.append(pos)
+
+ return breakpoints
+
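+# Sampling sketch (illustrative): gaps are log-uniform, so each octave of lengths
+# is equally likely (about as many gaps land in [4, 8) as in [192, 384)).
+#
+#   >>> random.seed(0)
+#   >>> bps = random_log_breakpoints(list(range(1000)), 4, 384)
+#   >>> all(4 <= g <= 384 for g in np.diff([0] + bps))
+#   True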
+
+def random_segments_from_text_and_durations(text_dur, max_pivots, codec_rate_hz):
+ text, dur = map(list, zip(*text_dur))
+ # k = choice(range(1, max_pivots))
+ # b = sample(range(2, len(text) - 2), k=max(min(k, len(text) - 5), 1))
+ # b.sort()
+ b = random_log_breakpoints(text, 4, 384)
+ bounds = [0] + b + [len(text)]
+ cum_dur = list(accumulate([0] + dur, int.__add__))
+ segs, durs = [], []
+ for a, b in zip(bounds[:-1], bounds[1:]):
+ segs.append(text[a:b])
+ durs.append(sum(dur[a:b]))
+ bounds = [0] + list(accumulate(durs, int.__add__))
+ segs_dicts = []
+ for t, s, e in zip(segs, bounds[:-1], bounds[1:]):
+ segs_dicts.append(
+ {
+ "start": s,
+ "end": e,
+ "text_token": t,
+ }
+ )
+ segs_dicts[-1]["end"] += 1
+ return segs_dicts
+
+
+def segments_collate(
+ batch,
+ tokenizer,
+ codec_rate_hz=75,
+ block_crossatt_mask=True,
+ text_rel_pos=True,
+ stop_token=2,
+ from_token_durations=True,
+    max_segments=10,
+):
+ # for x in batch:
+ # if not x.has_key("segments"):
+ # if x.has_key("alignments"):
+
+ # pass
+ # else:
+ # x["segments"] = [{"start": 0., "end": x["audio_token"].shape[-1]/codec_rate_hz, "text": x["text"]}]
+
+ segments = "text_token_duration" if from_token_durations else "segments"
+ audio_token, segments = zip(*[(x["audio_token"], x[segments]) for x in batch])
+ if from_token_durations:
+ segments = [
+ random_segments_from_text_and_durations(x.tolist(), 8, codec_rate_hz)
+ for x in segments
+ ]
+
+ bos, eos = map(tokenizer.encode, ("[BOS]", "[EOS]"))
+ audio_segments = []
+ text_segments = []
+ audio_segments_len = []
+ text_segments_len = []
+
+ for aud, seg in zip(audio_token, segments):
+ tt, at, tt_l, at_l = [], [], [], []
+
+ for i, s in enumerate(seg):
+ if from_token_durations:
+ ttoken = bos + s["text_token"] + eos
+ else:
+ ttoken = tokenizer.encode("[BOS]" + s["text"] + "[EOS]")
+ tt.append(ttoken)
+ sr = codec_rate_hz
+ a_s = aud[..., int(s["start"]) : int(s["end"])] + 3
+ # if i < len(seg) - 1:
+ a_s = torch.cat((a_s, torch.ones(*a_s.shape[:-1], 1) * stop_token), dim=-1)
+ at.append(a_s)
+ at_l.append(a_s.shape[-1])
+ tt_l.append(len(ttoken))
+
+ # at_l[0], at_l[-1] = at_l[0] + 1, at_l[-1] + 1
+ at_l[0] = at_l[0] + 1
+
+ audio_segments.append(at)
+ text_segments.append(tt)
+ audio_segments_len.append(at_l)
+ text_segments_len.append(tt_l)
+
+ text_token = [torch.LongTensor(reduce(list.__add__, x)) for x in text_segments]
+ audio_token = [
+ torch.cat(
+ (
+ torch.ones(*a_s.shape[:-1], 1),
+ *a_ss,
+ # torch.ones(*a_s.shape[:-1], 1) * stop_token,
+ ),
+ dim=-1,
+ )
+ .squeeze(0)
+ .transpose(-1, -2)
+ .long()
+ for a_ss in audio_segments
+ ]
+
+ xlen, ylen = map(lambda x: [xx.shape[0] for xx in x], (text_token, audio_token))
+ x_mask, y_mask = map(
+ lambda x: sequence_mask(x, device="cpu"),
+ (torch.tensor(xlen), torch.tensor(ylen)),
+ )
+
+ audio_token, text_token = map(
+ lambda x: pad_sequence(x, batch_first=True, padding_value=0),
+ (audio_token, text_token),
+ )
+
+ encoder_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(2)
+ if block_crossatt_mask:
+ crossatt_mask = [
+ packmask_2d(x, y) for x, y in zip(audio_segments_len, text_segments_len)
+ ]
+ crossatt_mask = pad_2d_sequence(crossatt_mask)
+ pad_mask = rearrange(torch.arange(max(ylen)), "n -> 1 n 1") >= rearrange(
+ torch.tensor(ylen), "n -> n 1 1"
+ )
+ crossatt_mask += pad_mask
+ else:
+ crossatt_mask = x_mask.unsqueeze(1) * y_mask.unsqueeze(2)
+
+ text_rel_pos = pad_sequence(
+ [torch.cat([torch.arange(x) for x in tsl]) for tsl in text_segments_len],
+ batch_first=True,
+ )
+ audio_rel_pos = pad_sequence(
+ [torch.cat([torch.arange(x) for x in asl]) for asl in audio_segments_len],
+ batch_first=True,
+ )
+
+ return {
+ "text_token": text_token,
+ "segments": segments,
+ "audio_token": audio_token,
+ "crossatt_mask": crossatt_mask,
+ "text_rel_pos": text_rel_pos,
+ "abs_style_intensity": None,
+ "audio_rel_pos": audio_rel_pos,
+ "crossatt_pos": None,
+ "encoder_mask": encoder_mask,
+ "y_mask": y_mask,
+ "x_len": xlen,
+ "y_len": ylen,
+ }
+
+
+class LinaDataModule(ptl.LightningDataModule):
+ def __init__(
+ self,
+ path: str | DatasetFactory,
+ quant_layer: list[int],
+ train_batch_size: int = 8,
+ token_by_batch: int | None = None,
+ n_buckets=5,
+ codec_rate_hz: int = 75,
+ num_workers: int = 8,
+ test_size: int = 2000,
+ val_batch_size: int = 8,
+ seed: int = 123,
+ train_test_seed: int = 123,
+ segments: bool = False,
+ tokenizer_file=None,
+ trail_end_frame: int | None = None,
+ split="train",
+ add_columns: str | list[str] | None = None,
+ add_text_tokens: list[str] | None = None,
+ type: str = "latent",
+ ):
+ super().__init__()
+
+ self.path = path
+ self.codec_rate_hz = codec_rate_hz
+ self.num_workers = num_workers
+ self.quant_layer = quant_layer
+ self.seed = seed
+ self.segments = segments
+ self.train_test_seed = train_test_seed
+ self.test_size = test_size
+ self.val_batch_size = val_batch_size
+ self.train_batch_size = train_batch_size
+ self.split = split
+ self.trail_end_frame = trail_end_frame
+ self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
+ if add_text_tokens:
+ self.tokenizer.add_tokens(add_text_tokens)
+ self.add_columns = add_columns
+ self.n_buckets = n_buckets
+ self.token_by_batch = token_by_batch
+ self.type = type
+
+ def setup(self, stage):
+ if isinstance(self.path, DatasetFactory):
+ self.dataset = self.path.build()
+ else:
+ self.dataset = load_dataset(self.path)
+
+ split = self.split
+ columns = [
+ "audio_latent" if self.type == "latent" else "audio_token",
+ "text",
+ "audio_duration",
+ ]
+ if self.add_columns is not None:
+ if type(self.add_columns) is str:
+ self.add_columns = [self.add_columns]
+ columns += self.add_columns
+
+ if self.segments:
+ # columns += ["segments"]
+ columns += ["text_token_duration"]
+ columns.remove("text")
+ self.collate_fn = lambda x: segments_collate(x, self.tokenizer)
+ else:
+ self.collate_fn = lambda x: standalone_collate(
+ x, self.tokenizer, abs_style_intensity="abs_style_intensity" in columns
+ )
+ self.dataset = (
+ self.dataset[split]
+ .train_test_split(test_size=self.test_size, seed=self.train_test_seed)
+ .select_columns(columns)
+ )
+
+ if self.type == "latent":
+ self.collate_fn = lambda x: standalone_collate_latent(x, self.tokenizer)
+
+ def get_buckets_by_quantile(duration, n_quantile):
+ idxdur = list(enumerate(duration))
+ idxdur.sort(key=lambda x: x[1])
+ idx, dur = zip(*idxdur)
+ bucket_size = len(idx) // n_quantile
+ buckets = [list(x) for x in zip(*[iter(idx)] * bucket_size)]
+ return buckets
+
+ if self.token_by_batch is not None:
+ train_buckets = get_buckets_by_quantile(
+ self.dataset["train"]["audio_duration"], self.n_buckets
+ )
+ max_audio_durations = [
+ self.dataset["train"]["audio_duration"][x[-1]] for x in train_buckets
+ ]
+
+ batch_sizes = [
+ int(self.token_by_batch // (self.codec_rate_hz * ad))
+ for ad in max_audio_durations
+ ]
+ self.train_batch_sampler = BucketSampler(train_buckets, batch_sizes)
+
+ def train_dataloader(self):
+ if self.token_by_batch is not None:
+ return DataLoader(
+ self.dataset["train"].with_format("torch"),
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ batch_sampler=self.train_batch_sampler,
+ )
+ else:
+ return DataLoader(
+ self.dataset["train"].with_format("torch"),
+ num_workers=self.num_workers,
+ batch_size=self.train_batch_size,
+ collate_fn=self.collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.dataset["test"].with_format("torch"),
+ batch_size=self.val_batch_size,
+ num_workers=0,
+ collate_fn=self.collate_fn,
+ )
diff --git a/tts/tts/layers/__init__.py b/tts/tts/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/tts/tts/layers/attention.py b/tts/tts/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fdf095bafe0ca44ce825f2b2acfb5ef0596dcff
--- /dev/null
+++ b/tts/tts/layers/attention.py
@@ -0,0 +1,230 @@
+import math
+import os
+import time
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce, repeat
+from torch import nn
+from transformers.cache_utils import Cache
+
+
+def apply_causal_sliding_window(mask: torch.Tensor, window_size: int) -> torch.Tensor:
+ B, H, Q, KV = mask.shape
+ device = mask.device
+
+ q_idx = torch.arange(Q, device=device).unsqueeze(1) # (Q, 1)
+ k_idx = torch.arange(KV, device=device).unsqueeze(0) # (1, KV)
+
+ lower_bound = q_idx - (window_size - 1) # (Q, 1), may be negative
+ allowed_2d = (k_idx <= q_idx) & (k_idx >= lower_bound) # (Q, KV), dtype=torch.bool
+
+ allowed_4d = allowed_2d.unsqueeze(0).unsqueeze(0).expand(B, H, Q, KV)
+
+ orig_dtype = mask.dtype
+ if mask.dtype != torch.bool:
+ mask_bool = mask.to(torch.bool)
+ else:
+ mask_bool = mask
+
+ new_mask = mask_bool & allowed_4d
+
+ if orig_dtype != torch.bool:
+ return new_mask.to(orig_dtype)
+ else:
+ return new_mask
+
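+# Masking sketch (illustrative): with window_size=2, query i may attend to keys
+# {i-1, i}, intersected with whatever the incoming mask already allows.
+#
+#   >>> m = torch.ones(1, 1, 4, 4, dtype=torch.bool)
+#   >>> apply_causal_sliding_window(m, 2)[0, 0].long()
+#   tensor([[1, 0, 0, 0],
+#           [1, 1, 0, 0],
+#           [0, 1, 1, 0],
+#           [0, 0, 1, 1]])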
+
+def precompute_freqs_cis(
+ t: torch.Tensor,
+ n_elem: int,
+ base: float = 10000,
+) -> torch.Tensor:
+ freqs = 1.0 / (
+ base
+ ** (
+ torch.arange(0, n_elem, 2, device=t.device)[: (n_elem // 2)].float()
+ / n_elem
+ )
+ )
+ freqs = torch.outer(t, freqs)
+ cache = repeat(freqs, "... d -> ... (d 2)")
+ return cache
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ out = x * freqs_cis.cos() + rotate_half(x) * freqs_cis.sin()
+ return out
+
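+# RoPE sketch (illustrative): precompute_freqs_cis builds per-position angles with
+# each frequency repeated for adjacent channel pairs, and apply_rotary_emb rotates
+# those pairs by the angles. Assumed shapes for a single head:
+#
+#   >>> pos = torch.arange(16).float()
+#   >>> freqs = precompute_freqs_cis(pos, n_elem=64)   # (16, 64)
+#   >>> q = torch.randn(1, 1, 16, 64)                  # (b, h, n, head_dim)
+#   >>> apply_rotary_emb(q, freqs).shape
+#   torch.Size([1, 1, 16, 64])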
+
+def scaled_dot_product_attention(query, key, value, mask=None):
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ if mask is not None:
+ attn_weight.masked_fill_(~mask, -torch.finfo(attn_weight.dtype).max)
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ return attn_weight @ value, attn_weight
+
+
+class SelfAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ is_causal: bool = False,
+ sliding_window: int | None = None,
+ ):
+ super().__init__()
+ self.qkv = nn.Linear(dim, 3 * dim)
+ assert dim % num_heads == 0
+ self.heads = num_heads
+ self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ self.output_proj = nn.Linear(dim, dim)
+ self.sliding_window = sliding_window
+ if self.sliding_window is not None:
+ self.is_causal = False
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ ):
+ B, T, D = x.shape
+ q, k, v = self.qkv(x).chunk(3, dim=-1)
+ q, k, v = map(
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+ )
+
+ if freqs is not None:
+ q = apply_rotary_emb(q, freqs)
+ k = apply_rotary_emb(k, freqs)
+ if cache is not None:
+ cache.update(attn_state=(k, v), layer_idx=self.layer_idx)
+ k, v = cache[self.layer_idx]["attn_state"]
+
+ if self.sliding_window is not None:
+            mask = torch.ones(B, 1, T, T, device=x.device, dtype=torch.bool)
+ mask = apply_causal_sliding_window(mask, self.sliding_window)
+
+ y = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask, is_causal=self.is_causal and T > 1
+ )
+ y = rearrange(y, "b h n d -> b n (h d)")
+ y = self.output_proj(y)
+ return y
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int | None = None,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ assert dim % num_heads == 0
+ self.pre_norm_q = nn.LayerNorm(dim)
+ self.q = nn.Linear(dim, dim)
+ self.k = nn.Linear(dim, dim)
+ self.v = nn.Linear(dim, dim)
+ self.layer_idx = layer_idx
+ self.heads = num_heads
+ self.dropout_att = dropout
+
+    def _prepare_kv(self, text_hidden_states: torch.Tensor):
+        k = self.k(text_hidden_states)
+        v = self.v(text_hidden_states)
+        return k, v
+
+    def _query(self, x):
+        return self.q(self.pre_norm_q(x))
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor | None = None,
+ v: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ output_attention: bool = False,
+ cache: Cache | None = None,
+ **kwargs,
+ ):
+ if v is None:
+ v = k
+ q = self.q(self.pre_norm_q(q))
+ if cache is not None:
+ if cache[self.layer_idx] is not None:
+ ca_state = cache[self.layer_idx]["crossatt_state"]
+ if ca_state is not None:
+ k, v = ca_state
+ else:
+ v = self.v(v)
+ k = self.k(k)
+ cache.update(crossatt_state=(k, v), layer_idx=self.layer_idx)
+ else:
+ v = self.v(v)
+ k = self.k(k)
+ q, k, v = map(
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.heads), (q, k, v)
+ )
+ if not self.training:
+ x, att = scaled_dot_product_attention(
+ q, k, v, mask=mask[:, None] if mask is not None else None
+ )
+ else:
+            x = nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask[:, None] if mask is not None else None,
+                dropout_p=self.dropout_att,
+            )
+ att = None
+ x = rearrange(x, "b h n d -> b n (h d)")
+ if att is not None:
+ if cache is not None:
+ cache.update(crossatt_weights=att, layer_idx=self.layer_idx)
+ else:
+ self.att = att
+ return x
+
+
+class ConvPos(nn.Module):
+ def __init__(self, dim, max_seq_len=2000, kernel_size=7):
+ super().__init__()
+ self.embed = nn.Embedding(max_seq_len, dim)
+ self.dw_conv = nn.Conv1d(dim, dim, kernel_size, groups=dim, padding="same")
+
+ def forward(self, x, left_shift=0):
+ # left_pad = 31 if left_shift > 0 else 0
+ # x = torch.cat((torch.arange(left_shift - left_pad, left_shift).to(x).unsqueeze(0),x, torch.arange(31).to(x).unsqueeze(0)), dim=1).clamp_min_(0)
+ y = self.embed(x)
+ y = rearrange(y, "b n c -> b c n")
+ y = self.dw_conv(y)
+ y = rearrange(y, "b c n -> b n c") # [:,left_pad:-31]
+ return y
+
+
+class SinPos(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ exp = torch.arange(self.dim // 2, device=x.device)
+ exp = 2 * exp / (self.dim)
+ exp = rearrange(exp, "e -> 1 1 e")
+ x = rearrange(x, "b p -> b p 1")
+ pos = x * torch.pow(10000, -exp)
+ pos = torch.cat((pos, pos + math.pi / 2), dim=2)
+ pos = torch.sin(pos)
+ return pos
diff --git a/tts/tts/layers/ffn.py b/tts/tts/layers/ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5133855ba99e3b0be9257c3383edbceef49514fd
--- /dev/null
+++ b/tts/tts/layers/ffn.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class SwiGLU(nn.Module):
+ def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
+ super().__init__()
+ self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
+ self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)
+
+ def forward(self, x):
+ gate, x = self.p_in(x).chunk(2, dim=-1)
+ return self.p_out(nn.functional.silu(gate) * x)
+
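+# Sizing note (illustrative): with ffn_expansion_factor=4 and d_model=512 the gated
+# hidden width is 4 * 512 // 3 = 682 (two input projections of that width, one gated
+# by SiLU), projected back to 512 on the way out.
+#
+#   >>> ffn = SwiGLU(512)
+#   >>> ffn(torch.randn(2, 10, 512)).shape
+#   torch.Size([2, 10, 512])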
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+class GaussianFourierTimeEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x[:, None] * self.weight[None, :] * 2 * torch.pi
+ x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
+ return x
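+# Usage sketch (illustrative; assumes torch.distributed is already initialised):
+# BucketSampler yields batches of indices drawn bucket by bucket, so each batch
+# holds sequences of similar duration.
+#
+#   >>> buckets = [[0, 1, 2, 3], [4, 5, 6, 7]]   # indices grouped by length
+#   >>> sampler = BucketSampler(buckets, batch_sizes=2)
+#   >>> loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate)
+#
+# `dataset` and `collate` are placeholders; in this module the sampler is built in
+# LinaDataModule.setup when token_by_batch is set.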
+
+
+class AdaLNFinalLayer(nn.Module):
+ def __init__(self, hidden_dim, feature_dim):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_dim, feature_dim, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+
+class AdaLNMLP(nn.Module):
+ def __init__(self, hidden_dim):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+
+ self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
+ self.mlp = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(), nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
+ )
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
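+
+
+if __name__ == "__main__":
+    # Minimal sketch of the adaLN-style conditioning defined above: the
+    # conditioning tensor produces shift/scale/gate for a gated residual MLP,
+    # followed by a final projection. Sizes are illustrative assumptions only.
+    block = AdaLNMLP(hidden_dim=64)
+    head = AdaLNFinalLayer(hidden_dim=64, feature_dim=16)
+    x = torch.randn(2, 10, 64)  # (batch, tokens, hidden)
+    c = torch.randn(2, 10, 64)  # per-token conditioning
+    print(head(block(x, c), c).shape)  # torch.Size([2, 10, 16])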
diff --git a/tts/tts/model/__init__.py b/tts/tts/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c6330d7c704d41d9874637cd36ca123a8e51db
--- /dev/null
+++ b/tts/tts/model/__init__.py
@@ -0,0 +1,2 @@
+from .simple_gla import SimpleGLADecoder
+from .transformer import TransformerDecoder, TransformerEncoder
diff --git a/tts/tts/model/cache.py b/tts/tts/model/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd84372fab1fe59bf02c0a75c55482084558351
--- /dev/null
+++ b/tts/tts/model/cache.py
@@ -0,0 +1,308 @@
+from typing import Any
+
+import torch
+from transformers.cache_utils import Cache, _static_cache_update
+
+
+class StaticCache(Cache):
+ """
+ Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
+
+ Parameters:
+        head_dim (`int`):
+            The per-head dimension used to shape each layer's key/value buffers.
+        num_key_value_heads (`int`):
+            The number of key/value heads for which states are cached.
+        num_hidden_layers (`int`):
+            The number of layers for which key/value buffers are pre-allocated.
+ max_batch_size (`int`):
+ The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
+ smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
+ number of beams if you are running beam search
+ max_cache_len (`int`, *optional*):
+ The maximum sequence length with which the model will be used.
+ device (`torch.device` or `str`, *optional*):
+ The device on which the cache should be initialized. If you're using more than 1 computation device, you
+ should pass the `layer_device_map` argument instead.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ The default `dtype` to use when initializing the layer.
+ layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
+ Mapping between the layers and its device. This is required when you are manually initializing the cache
+ and the model is split between different gpus. You can know which layers mapped to which device by
+ checking the associated device_map: `model.hf_device_map`.
+
+
+ Example:
+
+ ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+ >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
+
+ >>> # Prepare a cache class and pass it to model's forward
+ >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+ >>> max_generated_length = inputs.input_ids.shape[1] + 10
+    >>> past_key_values = StaticCache(
+    ...     max_batch_size=1,
+    ...     head_dim=model.config.hidden_size // model.config.num_attention_heads,
+    ...     num_key_value_heads=model.config.num_key_value_heads,
+    ...     num_hidden_layers=model.config.num_hidden_layers,
+    ...     max_cache_len=max_generated_length,
+    ...     device=model.device,
+    ...     dtype=model.dtype,
+    ... )
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+ >>> outputs.past_key_values # access cache filled with key/values from generation
+ StaticCache()
+ ```
+ """
+
+ is_compileable = True
+
+ def __init__(
+ self,
+ max_batch_size: int,
+ head_dim: int,
+ num_key_value_heads: int,
+ num_hidden_layers: int,
+ max_cache_len: int | None = None,
+ device: torch.device | str | None = None,
+ dtype: torch.dtype = torch.float32,
+ layer_device_map: dict[int, str | torch.device | int] | None = None,
+ ) -> None:
+ super().__init__()
+ self.max_batch_size = max_batch_size
+ self.max_cache_len = max_cache_len
+
+ # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+ self.head_dim = head_dim
+ self._dtype = dtype
+ self.num_key_value_heads = num_key_value_heads
+ self.num_hidden_layers = num_hidden_layers
+ self.key_cache: list[torch.Tensor] = []
+ self.value_cache: list[torch.Tensor] = []
+ # Note: There will be significant perf decrease if switching to use 5D tensors instead.
+ cache_shape = (
+ self.max_batch_size,
+ self.num_key_value_heads,
+ self.max_cache_len,
+ self.head_dim,
+ )
+ device = torch.device(device) if device is not None else None
+ for idx in range(self.num_hidden_layers):
+ if layer_device_map is not None:
+ layer_device = layer_device_map[idx]
+ else:
+ layer_device = device
+ new_layer_key_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device
+ )
+ new_layer_value_cache = torch.zeros(
+ cache_shape, dtype=self._dtype, device=layer_device
+ )
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+ # preventing compiled graph breaks when updating the cache.
+ torch._dynamo.mark_static_address(new_layer_key_cache)
+ torch._dynamo.mark_static_address(new_layer_value_cache)
+ self.key_cache.append(new_layer_key_cache)
+ self.value_cache.append(new_layer_value_cache)
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: dict[str, Any] | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+ It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+ Parameters:
+ key_states (`torch.Tensor`):
+ The new key states to cache.
+ value_states (`torch.Tensor`):
+ The new value states to cache.
+ layer_idx (`int`):
+ The index of the layer to cache the states for.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.
+
+ Return:
+ A tuple containing the updated key and value states.
+ """
+ if cache_kwargs is None:
+ cache_kwargs = {}
+
+ key_states = key_states.to(self.key_cache[layer_idx].dtype)
+ value_states = value_states.to(self.value_cache[layer_idx].dtype)
+ return _static_cache_update(
+ self.key_cache[layer_idx],
+ self.value_cache[layer_idx],
+ key_states,
+ value_states,
+ cache_kwargs.get("cache_position"),
+ )
+
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
+ """Returns the sequence length of the cached states that were seen by the model."""
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+ # limit the check to the first batch member and head dimension.
+ # TODO: deprecate this function in favor of `cache_position`
+ return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+ def get_max_cache_shape(self) -> int | None:
+ return self.max_cache_len
+
+ def reset(self):
+ """Resets the cache values while preserving the objects"""
+ for layer_idx in range(len(self.key_cache)):
+ # In-place ops prevent breaking the static address
+ self.key_cache[layer_idx].zero_()
+ self.value_cache[layer_idx].zero_()
+
+ def get_mask_sizes(
+ self, cache_position: torch.Tensor, layer_idx: int
+ ) -> tuple[int, int]:
+ """
+ Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+ the given layer at `layer_idx`.
+ The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+ for each layer.
+ """
+ kv_length = self.get_max_cache_shape()
+ return kv_length, 0
+
+
+class Cache:
+ """
+ A cache used for storing hidden states produced by flash linear attention models.
+
+ It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
+ """
+
+ is_compileable = True
+
+    def __init__(self, seen_tokens: int = 0) -> None:
+ super().__init__()
+
+ self.states: list[dict[str, Any]] = []
+
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def __getitem__(self, layer_idx: int) -> dict[str, Any]:
+ if layer_idx < len(self):
+ return self.states[layer_idx]
+ else:
+ raise KeyError(
+ f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
+ )
+
+ def __iter__(self):
+ for state in self.states:
+ yield state
+
+ def __len__(self):
+ return len(self.states)
+
+ def update(
+ self,
+ recurrent_state: torch.Tensor | None = None,
+ attn_state: tuple[torch.Tensor, torch.Tensor] | None = None,
+ conv_state: tuple[torch.Tensor] | None = None,
+ ffn_state: torch.Tensor | None = None,
+ layer_idx: int = 0,
+ offset: int | None = 1,
+ cache_kwargs: dict | None = None,
+ ):
+ """
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
+
+ Args:
+ recurrent_state (`torch.Tensor`, `optional`):
+ The new recurrent state to cache.
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
+ The new attention key/value states to cache.
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
+ The new convolution state to cache.
+ layer_idx (`int`, defaults to 0):
+ The index of the layer to cache the states for.
+ offset (`int`, `optional`, defaults to 1):
+ The number of new tokens being processed.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass.
+
+ Return:
+ Dictionary of the updated state.
+ """
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += offset
+
+ if attn_state is not None:
+ input_size = attn_state[0].shape[-2]
+            window_size = (
+                cache_kwargs.get("window_size", None)
+                if cache_kwargs is not None
+                else None
+            )
+            if not isinstance(attn_state, tuple) or len(attn_state) != 2:
+ raise ValueError(
+ "`attn_state` must be a tuple of two tensors for key/value states"
+ )
+ if len(self.states) <= layer_idx:
+ if attn_state is not None:
+ if window_size is not None and input_size > window_size:
+ attn_state = (
+ attn_state[0][..., -window_size:, :].contiguous(),
+ attn_state[1][..., -window_size:, :].contiguous(),
+ )
+ state = dict(
+ recurrent_state=recurrent_state,
+ attn_state=attn_state,
+ conv_state=conv_state,
+ ffn_state=ffn_state,
+ )
+ self.states.append(state)
+ else:
+ state = self.states[layer_idx]
+ if recurrent_state is not None:
+ state["recurrent_state"] = recurrent_state
+ if attn_state is not None:
+ key_state, value_state = state["attn_state"]
+ if window_size is not None and key_state.shape[-2] == window_size:
+ # DO NOT allocate new memory if the cache is full
+ # roll the key/value states to the left by `input_size`
+ key_state = key_state.roll(-input_size, -2)
+ value_state = value_state.roll(-input_size, -2)
+ # replace the last `input_size` tokens with the new key/value states
+ key_state[..., -input_size:, :] = attn_state[0]
+ value_state[..., -input_size:, :] = attn_state[1]
+ attn_state = (key_state, value_state)
+ else:
+ attn_state = (
+ torch.cat([key_state, attn_state[0]], -2),
+ torch.cat([value_state, attn_state[1]], -2),
+ )
+ state["attn_state"] = attn_state
+ if conv_state is not None:
+ state["conv_state"] = conv_state
+ if ffn_state is not None:
+ state["ffn_state"] = ffn_state
+
+ return state
+
+ def get_seq_length(self, layer_idx: int | None = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ if len(self.states) <= layer_idx:
+ return 0
+ return self._seen_tokens
+
+ def get_max_length(self) -> int | None:
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
+ return None
+
+ def to_legacy_cache(self) -> tuple:
+ return tuple(self.states)
+
+ @classmethod
+ @torch.compiler.disable
+ def from_legacy_cache(
+ cls, past_key_values: tuple | None = None, seen_tokens: int = 0
+ ) -> Cache:
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
+
+ cache = cls(seen_tokens)
+ if isinstance(past_key_values, list):
+ for layer_idx in range(len(past_key_values)):
+ cache.states.append(past_key_values[layer_idx])
+ return cache
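+
+
+if __name__ == "__main__":
+    # Minimal sketch of how the layer-state cache above is meant to be used:
+    # the first `update` for a layer appends its state dict, later updates
+    # concatenate (or roll, when `window_size` is passed via `cache_kwargs`)
+    # the key/value states. Shapes are illustrative assumptions only.
+    cache = Cache()
+    k = torch.randn(1, 2, 4, 8)  # (batch, heads, tokens, head_dim)
+    v = torch.randn(1, 2, 4, 8)
+    cache.update(attn_state=(k, v), layer_idx=0, offset=4)
+    cache.update(attn_state=(k, v), layer_idx=0, offset=4)
+    print(cache.get_seq_length(), cache[0]["attn_state"][0].shape[-2])  # 8 8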
diff --git a/tts/tts/model/cache_utils.py b/tts/tts/model/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1aa3d8f27013be13e55b5752673745f2b223c5e
--- /dev/null
+++ b/tts/tts/model/cache_utils.py
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import transformers
+from transformers.cache_utils import DynamicCache
+
+from layers.attention import precompute_freqs_cis
+
+
+class FLACache(transformers.cache_utils.Cache):
+ """
+ A cache used for storing hidden states produced by flash linear attention models.
+
+ It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`.
+ """
+
+ is_compileable = True
+
+ def __init__(
+ self,
+ head_dim: int | None = None,
+ max_seq_len: int | None = None,
+ num_states: int | None = None,
+ device: str = "cuda",
+ seen_tokens: int = 0,
+ ):
+ super().__init__()
+ if head_dim is not None and max_seq_len is not None:
+ self.freqs = precompute_freqs_cis(
+ torch.arange(max_seq_len, device=device), head_dim
+ )
+ self.states: List[Dict[str, Any]] = []
+ if num_states is not None:
+ self.states = [
+ dict(
+ recurrent_state=None,
+ attn_state=None,
+ conv_state=None,
+ short_conv_state=None,
+ ffn_state=None,
+ crossatt_state=None,
+ crossatt_weights=None,
+ )
+ for _ in range(num_states)
+ ]
+
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
+
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
+ if layer_idx < len(self):
+ return self.states[layer_idx]
+ else:
+ raise KeyError(
+ f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}"
+ )
+
+ def __iter__(self):
+ for state in self.states:
+ yield state
+
+ def __len__(self):
+ return len(self.states)
+
+ def update(
+ self,
+ recurrent_state: torch.Tensor = None,
+ attn_state: Tuple[torch.Tensor, torch.Tensor] = None,
+ conv_state: Tuple[torch.Tensor] = None,
+ short_conv_state: Tuple[torch.Tensor] = None,
+ crossatt_state: Tuple[torch.Tensor] = None,
+ crossatt_weights: Tuple[torch.Tensor] = None,
+ ffn_state: torch.Tensor = None,
+ layer_idx: int = 0,
+ offset: Optional[int] = 1,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, Any]:
+ """
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
+
+ Args:
+ recurrent_state (`torch.Tensor`, `optional`):
+ The new recurrent state to cache.
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
+ The new attention key/value states to cache.
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
+ The new convolution state to cache.
+ layer_idx (`int`, defaults to 0):
+ The index of the layer to cache the states for.
+ offset (`int`, `optional`, defaults to 1):
+ The number of new tokens being processed.
+ cache_kwargs (`Dict[str, Any]`, `optional`):
+ Additional arguments for the cache subclass.
+
+ Return:
+ Dictionary of the updated state.
+ """
+
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += offset
+
+ if attn_state is not None:
+ input_size = attn_state[0].shape[-2]
+ window_size = (
+ cache_kwargs.get("window_size", None)
+ if cache_kwargs is not None
+ else None
+ )
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
+ raise ValueError(
+ "`attn_state` must be a tuple of two tensors for key/value states"
+ )
+ if len(self.states) <= layer_idx:
+ if attn_state is not None:
+ if window_size is not None and input_size > window_size:
+ attn_state = (
+ attn_state[0][..., -window_size:, :].contiguous(),
+ attn_state[1][..., -window_size:, :].contiguous(),
+ )
+ state = dict(
+ recurrent_state=recurrent_state,
+ attn_state=attn_state,
+ conv_state=conv_state,
+ short_conv_state=short_conv_state,
+ ffn_state=ffn_state,
+ crossatt_state=crossatt_state,
+ crossatt_weights=crossatt_weights,
+ )
+ self.states.append(state)
+ else:
+ state = self.states[layer_idx]
+ if recurrent_state is not None:
+ state["recurrent_state"] = recurrent_state
+ if crossatt_state is not None:
+ state["crossatt_state"] = crossatt_state
+ if crossatt_weights is not None:
+ if state["crossatt_weights"] is not None:
+ state["crossatt_weights"] = torch.cat(
+ (state["crossatt_weights"], crossatt_weights), dim=-2
+ )
+ else:
+ state["crossatt_weights"] = crossatt_weights
+ if attn_state is not None:
+ key_state, value_state = state["attn_state"]
+ if window_size is not None and key_state.shape[-2] == window_size:
+ # DO NOT allocate new memory if the cache is full
+ # roll the key/value states to the left by `input_size`
+ key_state = key_state.roll(-input_size, -2)
+ value_state = value_state.roll(-input_size, -2)
+ # replace the last `input_size` tokens with the new key/value states
+ key_state[..., -input_size:, :] = attn_state[0]
+ value_state[..., -input_size:, :] = attn_state[1]
+ attn_state = (key_state, value_state)
+ else:
+ attn_state = (
+ torch.cat([key_state, attn_state[0]], -2),
+ torch.cat([value_state, attn_state[1]], -2),
+ )
+ state["attn_state"] = attn_state
+ if conv_state is not None:
+ state["conv_state"] = conv_state
+ if short_conv_state is not None:
+ state["short_conv_state"] = short_conv_state
+ if ffn_state is not None:
+ state["ffn_state"] = ffn_state
+
+ return state
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ if len(self.states) <= layer_idx:
+ return 0
+ return self._seen_tokens
+
+ def get_max_length(self) -> Optional[int]:
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
+ return None
+
+ def to_legacy_cache(self) -> Tuple:
+ return tuple(self.states)
+
+ @classmethod
+ @torch.compiler.disable
+ def from_legacy_cache(
+ cls, past_key_values: Optional[Tuple] = None, seen_tokens: int = 0
+    ) -> FLACache:
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
+
+ cache = cls(seen_tokens)
+ if isinstance(past_key_values, list):
+ for layer_idx in range(len(past_key_values)):
+ cache.states.append(past_key_values[layer_idx])
+ return cache
+
+
+class TransformerDecoderCache(FLACache):
+ def __init__(
+ self,
+ head_dim: int,
+ max_seq_len: int,
+ device: str,
+ ):
+ super().__init__()
+ self.freqs = precompute_freqs_cis(
+ torch.arange(max_seq_len, device=device), head_dim
+ )
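+
+
+if __name__ == "__main__":
+    # Minimal sketch of caching projected cross-attention keys/values in the
+    # FLACache so they are computed once and reused at every decoding step.
+    # Shapes are illustrative assumptions; running this directly assumes the
+    # package root is on PYTHONPATH so `layers.attention` resolves.
+    cache = FLACache(num_states=2)
+    k = torch.randn(1, 1, 7, 16)
+    v = torch.randn(1, 1, 7, 16)
+    cache.update(crossatt_state=(k, v), layer_idx=1, offset=0)
+    print(cache[1]["crossatt_state"][0].shape)  # torch.Size([1, 1, 7, 16])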
diff --git a/tts/tts/model/config.py b/tts/tts/model/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f2d9da0840f5ffb37bc43cc05b0e7b5bd52282a
--- /dev/null
+++ b/tts/tts/model/config.py
@@ -0,0 +1,70 @@
+from dataclasses import dataclass, field
+from typing import Literal
+
+
+@dataclass
+class SimpleGLADecoderConfig:
+ name: str = "simple_gla"
+ dim: int = 512
+    use_short_conv: bool = True
+ expand_k: float = 0.5
+ expand_v: float = 1.0
+ num_heads: int = 4
+ num_layers: int = 12
+ ffn_expansion_factor: int = 4
+ conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
+ blind_crossatt: bool = True
+ listen_read_crossatt: dict[int, list[Literal["listen", "read"]]] | None = None
+ crossatt_num_heads: int = 1
+ crossatt_dropout: float = 0.1
+ crossatt_layer_idx: list[int] = field(default_factory=lambda: [5])
+
+
+@dataclass
+class TransformerEncoderConfig:
+ name: str = "sa_transformer"
+ dim: int = 512
+ num_heads: int = 8
+ num_layers: int = 6
+ ffn_expansion_factor: int = 4
+
+
+@dataclass
+class TransformerDecoderConfig:
+ name: str = "sa_transformer"
+ dim: int = 512
+ num_heads: int = 8
+ num_layers: int = 12
+ conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
+ ffn_expansion_factor: int = 4
+ crossatt_dropout: float = 0.1
+ crossatt_num_heads: int = 1
+ crossatt_layer_idx: list[int] = field(default_factory=lambda: [5, 6])
+
+
+DecoderConfig = TransformerDecoderConfig | SimpleGLADecoderConfig
+EncoderConfig = TransformerEncoderConfig
+
+
+@dataclass
+class TTSConfig:
+ dim: int = 512
+ text_vocab_size: int = 256 + 3
+ audio_vocab_size: int = 4096 + 3
+ audio_embed_size: int = 16
+ audio_input_type: Literal["continuous", "discrete"] = "continuous"
+ diffusion_head_num_layers: int = 3
+ encoder_cfg: EncoderConfig = field(default_factory=TransformerEncoderConfig)
+ decoder_cfg: DecoderConfig = field(default_factory=TransformerDecoderConfig)
+ stop_prediction_head: bool = True
+ continuous_diffusion: bool = True
+
+
+@dataclass
+class QueryVCConfig:
+ dim: int = 512
+ semantic_dim: int = 512
+ num_layers: int = 12
+ lag: int = 4
+ audio_embed_size: int = 8
+ diffusion_head_num_layers: int = 3
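+
+
+if __name__ == "__main__":
+    # Minimal sketch of how these config dataclasses compose; the overrides
+    # below are illustrative assumptions, not recommended settings.
+    cfg = TTSConfig(
+        encoder_cfg=TransformerEncoderConfig(num_layers=4),
+        decoder_cfg=SimpleGLADecoderConfig(num_layers=6, crossatt_layer_idx=[3]),
+    )
+    print(cfg.encoder_cfg.name, cfg.decoder_cfg.name)  # sa_transformer simple_gla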
diff --git a/tts/tts/model/prediction_head.py b/tts/tts/model/prediction_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..273add83d6a5bd417d1acb93479218682cd56f36
--- /dev/null
+++ b/tts/tts/model/prediction_head.py
@@ -0,0 +1,259 @@
+import os
+import sys
+from contextlib import contextmanager
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from torchdyn.core import NeuralODE
+
+from layers.ffn import AdaLNFinalLayer, AdaLNMLP, GaussianFourierTimeEmbedding
+
+
+def net_factory(hidden_dim: int, num_layers: int) -> nn.Module:
+    pass
+
+
+@contextmanager
+def suppress_stdout():
+ original_stdout = sys.stdout
+ try:
+ sys.stdout = open(os.devnull, "w")
+ yield
+ finally:
+ sys.stdout.close()
+ sys.stdout = original_stdout
+
+
+def sample_from_logits(
+ logits: torch.Tensor,
+ temperature: float = 1.0,
+ top_k: int = 0,
+ top_p: float = 0.0,
+) -> torch.Tensor:
+ B, N, C = logits.shape
+ logits = logits / temperature
+
+ # Apply top-k
+ if top_k > 0:
+ top_k = min(top_k, C)
+ topk_values, _ = torch.topk(logits, top_k, dim=-1)
+ kth_value = topk_values[..., -1, None]
+ logits = torch.where(
+ logits < kth_value, torch.full_like(logits, float("-inf")), logits
+ )
+
+ # Apply top-p (nucleus) sampling
+ if top_p > 0.0 and top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+ probs = F.softmax(sorted_logits, dim=-1)
+ cumulative_probs = torch.cumsum(probs, dim=-1)
+
+ # Create mask for tokens to remove
+ cutoff_mask = cumulative_probs > top_p
+ cutoff_mask[..., 0] = 0 # Always keep at least one token
+ sorted_logits[cutoff_mask] = float("-inf")
+
+ # Map back to original logits shape
+ logits = torch.full_like(logits, float("-inf")).scatter(
+ -1, sorted_indices, sorted_logits
+ )
+
+ # Convert logits to probabilities
+ probs = F.softmax(logits, dim=-1)
+
+ # Sample
+ samples = torch.multinomial(probs.view(-1, C), num_samples=1).view(B, N)
+ return samples
+
+
+class LogitsHead(nn.Module):
+ def __init__(self, hidden_dim: int, vocab_size: int):
+ super().__init__()
+ self.logits_proj = nn.Linear(hidden_dim, vocab_size)
+
+ def forward(self, pre_logits):
+ return self.logits_proj(pre_logits)
+
+ def compute_loss(
+ self,
+ pre_logits: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ):
+ logits = self(pre_logits)
+ if mask is not None:
+ flat_logits = logits[mask]
+ flat_target = target[mask]
+ else:
+ flat_logits = rearrange(logits, "b n l -> (b n) l")
+ flat_target = rearrange(target, "b n -> (b n)")
+
+ loss = nn.functional.cross_entropy(
+ flat_logits,
+ flat_target,
+ ignore_index=1,
+ )
+
+ return {"cross_entropy": loss}
+
+ def predict(self, x: torch.Tensor, *args, **kwargs):
+ return sample_from_logits(self(x), *args, **kwargs)
+
+
+class ContinuousHead(nn.Module):
+ def __init__(self, hidden_dim: int, feature_dim: int):
+ super().__init__()
+ self.continuous_head = nn.Linear(hidden_dim, feature_dim)
+
+ def forward(self, x: torch.Tensor):
+ return self.continuous_head(x)
+
+ def predict(self, x: torch.Tensor):
+ return self(x)
+
+ def compute_loss(
+ self, pre_logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
+ ):
+ if mask is not None:
+ pre_logits = pre_logits[mask]
+ target = target[mask]
+ return {"mse": nn.functional.mse_loss(self(pre_logits), target)}
+
+
+class DiffusionHead(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ feature_dim: int,
+ num_layers: int,
+ cond_dim: int | None = None,
+ ):
+ super().__init__()
+ cond_dim = cond_dim if cond_dim is not None else hidden_dim
+ self.feature_embed = nn.Linear(feature_dim, hidden_dim)
+ self.cond_embed = nn.Linear(cond_dim, hidden_dim)
+ self.time_embed = GaussianFourierTimeEmbedding(hidden_dim // 2)
+ self.adaln_mlp = nn.ModuleList(
+ [AdaLNMLP(hidden_dim) for _ in range(num_layers)]
+ )
+ self.adaln_final_layer = AdaLNFinalLayer(hidden_dim, feature_dim)
+ self.feature_dim = feature_dim
+
+ def forward(
+ self,
+ cond: torch.Tensor,
+ x: torch.Tensor,
+ t: torch.Tensor | None = None,
+ cond_drop_mask: torch.BoolTensor | None = None,
+ ):
+ cond = self.cond_embed(cond)
+ if cond_drop_mask is not None:
+ cond[cond_drop_mask] = 0.0
+ cond += self.time_embed(t)[:, None]
+ x = self.feature_embed(x)
+
+ for l in self.adaln_mlp:
+ x = l(x, cond)
+ y = self.adaln_final_layer(x, cond)
+
+ return y
+
+ def compute_loss(
+ self,
+ cond: torch.Tensor,
+ x1: torch.Tensor,
+ mask: torch.Tensor | None,
+ sigma: float = 1e-5,
+ t: torch.Tensor | None = None,
+ cfg_drop_rate: float = 0.1,
+ ):
+ if t is None:
+ t = torch.rand(cond.shape[0], device=cond.device)
+
+ x0 = torch.randn_like(x1, device=x1.device)
+
+ flow_target = x1 - (1 - sigma) * x0
+ alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
+ xt = alpha * x0 + t.view(-1, 1, 1) * x1
+
+ if self.training and cfg_drop_rate > 0.0:
+            cond_drop_mask = torch.rand(cond.shape[:2], device=cond.device) < cfg_drop_rate
+ else:
+ cond_drop_mask = None
+
+ flow_pred = self(cond, xt, t, cond_drop_mask=cond_drop_mask)
+
+ if mask is not None:
+ flow_pred = flow_pred[mask]
+ flow_target = flow_target[mask]
+
+ loss = nn.functional.mse_loss(flow_pred, flow_target)
+
+ return {"diffusion": loss}
+
+ def predict(
+ self,
+ pre_prediction: torch.Tensor,
+ solver: str = "euler",
+ sensitivity: str = "adjoint",
+ num_steps: int = 10,
+ cfg: float = 1.0,
+ **kwargs,
+ ):
+ if cfg == 1.0:
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ return self(pre_prediction, Xt, t.unsqueeze(0))
+ else:
+
+ def solver_fn(t, Xt, *args, **kwargs):
+ cond_uncond = torch.cat(
+ (pre_prediction, torch.zeros_like(pre_prediction)),
+ dim=0,
+ )
+ cond_uncond = self(cond_uncond, Xt.repeat(2, 1, 1), t.unsqueeze(0))
+ cond, uncond = cond_uncond.chunk(2, dim=0)
+ return uncond + cfg * (cond - uncond)
+
+ # get rid of torchdyn warning
+ with suppress_stdout():
+ node_ = NeuralODE(solver_fn, solver=solver, sensitivity=sensitivity)
+ t_span = torch.linspace(0, 1, num_steps + 1, device=pre_prediction.device)
+ traj = node_.trajectory(
+ torch.randn(1, 1, self.feature_dim, device=pre_prediction.device),
+ t_span=t_span,
+ )
+ prediction = traj[-1]
+ return prediction
+
+
+class StopPredictionHead(nn.Module):
+ def __init__(self, dim: int, weight_loss: float = 1.0):
+ super().__init__()
+ self.proj = nn.Linear(dim, 1)
+ self.weight_loss = weight_loss
+
+ def forward(self, pre_prediction: torch.Tensor):
+ return torch.sigmoid(self.proj(pre_prediction))
+
+ def predict(self, pre_prediction: torch.Tensor):
+ return torch.sigmoid(self.proj(pre_prediction))
+
+ def compute_loss(
+ self,
+ pre_prediction: torch.Tensor,
+ target: torch.Tensor,
+ ):
+ logits = self.proj(pre_prediction)
+ bce = nn.functional.binary_cross_entropy_with_logits(
+ logits.squeeze(-1),
+ target,
+ weight=torch.ones(logits.shape[0], device=logits.device) * self.weight_loss,
+ )
+ return {"stop_bce": bce}
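+
+
+if __name__ == "__main__":
+    # Minimal sketch of two of the heads above, assuming this file's
+    # dependencies (einops, torchdyn, layers.ffn) are importable. All shapes
+    # and hyperparameters are illustrative assumptions only.
+    logits = torch.randn(2, 3, 8)  # (batch, tokens, vocab)
+    ids = sample_from_logits(logits, temperature=0.8, top_k=4, top_p=0.9)
+    print(ids.shape)  # torch.Size([2, 3])
+
+    head = DiffusionHead(hidden_dim=64, feature_dim=16, num_layers=2)
+    cond = torch.randn(2, 10, 64)  # decoder hidden states
+    x1 = torch.randn(2, 10, 16)    # target continuous features
+    print(head.compute_loss(cond, x1, mask=None))  # {"diffusion": ...}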
diff --git a/tts/tts/model/registry.py b/tts/tts/model/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a00c29717c2a7888a38cef182b1240068810762
--- /dev/null
+++ b/tts/tts/model/registry.py
@@ -0,0 +1,18 @@
+ENCODER_REGISTRY = {}
+DECODER_REGISTRY = {}
+
+
+def register_encoder(name):
+ def wrapper(cls):
+ ENCODER_REGISTRY[name] = cls
+ return cls
+
+ return wrapper
+
+
+def register_decoder(name):
+ def wrapper(cls):
+ DECODER_REGISTRY[name] = cls
+ return cls
+
+ return wrapper
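+
+
+if __name__ == "__main__":
+    # Minimal sketch of the intended registry usage; `DummyDecoder` is a
+    # purely illustrative placeholder, not a real model class.
+    @register_decoder("dummy")
+    class DummyDecoder:
+        pass
+
+    assert DECODER_REGISTRY["dummy"] is DummyDecoder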
diff --git a/tts/tts/model/shortconv.py b/tts/tts/model/shortconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9ed8b14586fa441a27267ffff09717ede15ac7
--- /dev/null
+++ b/tts/tts/model/shortconv.py
@@ -0,0 +1,39 @@
+import torch
+from fla.modules import ShortConvolution
+from torch import nn
+
+from layers.ffn import SwiGLU
+from model.cache_utils import FLACache
+
+
+class ShortConvBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ kernel_size: int,
+ ffn_expansion_factor: int,
+ layer_idx: int,
+ use_fast_conv1d: bool = True,
+ ):
+ super().__init__()
+ self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d)
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.layer_idx = layer_idx
+
+ def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
+ last_state = None
+ if cache is not None and len(cache) > self.layer_idx:
+ last_state = cache[self.layer_idx]["short_conv_state"]
+ x, short_conv_state = self.tmix(
+ self.norm1(x), cache=last_state, output_final_state=cache is not None
+ )
+ if cache is not None:
+ cache.update(
+ short_conv_state=short_conv_state,
+ layer_idx=self.layer_idx,
+ offset=x.shape[1],
+ )
+ x = self.cmix(self.norm2(x))
+ return x
diff --git a/tts/tts/model/simple_gla.py b/tts/tts/model/simple_gla.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bbdb8dab7a28acccee93bd348791834977bfd4b
--- /dev/null
+++ b/tts/tts/model/simple_gla.py
@@ -0,0 +1,248 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from fla.layers.simple_gla import SimpleGatedLinearAttention
+from fla.models.utils import Cache
+from torch import nn
+
+from layers.attention import (
+ BlindCrossAttention,
+ ConvPos,
+ CrossAttention,
+ ListenReadCrossAttention,
+ precompute_freqs_cis,
+)
+from layers.ffn import SwiGLU
+from model.cache_utils import FLACache
+from model.config import SimpleGLADecoderConfig
+from model.registry import register_decoder
+from model.shortconv import ShortConvBlock
+
+if "GRAD_CKPT" in os.environ:
+
+ def maybe_grad_ckpt(f):
+ def grad_ckpt_f(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(
+ f, *args, **kwargs, use_reentrant=False
+ )
+
+ return grad_ckpt_f
+else:
+
+ def maybe_grad_ckpt(f):
+ return f
+
+
+class SimpleGLABlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ expand_k: float,
+ expand_v: float,
+ use_short_conv: bool,
+ ffn_expansion_factor: int,
+ ):
+ super().__init__()
+        self.tmix = SimpleGatedLinearAttention(
+            hidden_size=dim,
+            num_heads=num_heads,
+            expand_k=expand_k,
+            expand_v=expand_v,
+            use_short_conv=use_short_conv,
+            layer_idx=layer_idx,
+        )
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ text_freqs: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ ):
+ x = (
+ self.tmix(
+ self.norm1(x),
+ freqs=freqs,
+ text_freqs=text_freqs,
+ past_key_values=cache,
+ use_cache=cache is not None,
+ )[0]
+ + x
+ )
+ x = self.cmix(self.norm2(x)) + x
+ return x
+
+
+class DecoderBlockWithOptionalCrossAttention(nn.Module):
+ def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
+ super().__init__()
+
+ self.decoder_block = decoder_block
+ self.crossatt = crossatt
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoder_output: torch.Tensor | None = None,
+ freqs: torch.Tensor | None = None,
+ text_freqs: torch.Tensor | None = None,
+ cache: Cache | None = None,
+ selfatt_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = self.decoder_block(
+ x,
+ freqs=freqs,
+ cache=cache,
+ )
+
+ if self.crossatt is not None:
+ x = x + self.crossatt(
+ x,
+ k=encoder_output,
+ text_freqs=text_freqs,
+ mask=crossatt_mask,
+ cache=cache,
+ )
+
+ return x
+
+
+@register_decoder("simple_gla")
+class SimpleGLADecoder(nn.Module):
+ config = SimpleGLADecoderConfig
+
+ def __init__(self, cfg: SimpleGLADecoderConfig):
+ super().__init__()
+
+ assert cfg.dim % cfg.num_heads == 0
+ assert cfg.blind_crossatt != (cfg.listen_read_crossatt is not None)
+
+ if cfg.listen_read_crossatt is not None:
+ self.text_freqs_embd = ConvPos(cfg.dim)
+ else:
+ self.text_freqs_embd = None
+ self.head_dim = cfg.dim // cfg.num_heads
+
+ def simple_gla_block(i):
+ conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
+ if i in conv_layers:
+ return ShortConvBlock(
+ dim=cfg.dim,
+ kernel_size=4,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ layer_idx=i,
+ use_fast_conv1d=True,
+ )
+
+ else:
+ return SimpleGLABlock(
+ dim=cfg.dim,
+ num_heads=cfg.num_heads,
+ layer_idx=i,
+ expand_k=cfg.expand_k,
+ expand_v=cfg.expand_v,
+ use_short_conv=cfg.use_short_conv,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ )
+
+ def crossatt_block(i):
+ if cfg.blind_crossatt:
+ if i in cfg.crossatt_layer_idx:
+ return BlindCrossAttention(
+ cfg.dim,
+ cfg.dim,
+ cfg.dim,
+ simple_gla_block(cfg.num_layers),
+ pos_dim=cfg.dim,
+ dropout=cfg.crossatt_dropout,
+ pos_type="convolutional",
+ layer_idx=cfg.num_layers,
+ )
+ elif cfg.listen_read_crossatt is not None:
+ if i in cfg.listen_read_crossatt.keys():
+ return ListenReadCrossAttention(
+ cfg.dim,
+ cfg.dim,
+ cfg.dim,
+ cfg.listen_read_crossatt[i],
+ )
+
+ else:
+ if i in cfg.crossatt_layer_idx:
+ return CrossAttention(
+ dim=cfg.dim,
+ num_heads=cfg.crossatt_num_heads,
+ dropout=cfg.crossatt_dropout,
+ layer_idx=i,
+ )
+
+ self.decoder_layers = nn.ModuleList(
+ [
+ DecoderBlockWithOptionalCrossAttention(
+ simple_gla_block(i),
+ crossatt_block(i),
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ crossatt_mask: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ x = decoder_input
+ if self.text_freqs_embd is not None:
+ text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
+ text_freqs = self.text_freqs_embd(text_freqs)
+ else:
+ text_freqs = None
+ for layer in self.decoder_layers:
+ x = maybe_grad_ckpt(layer)(
+ x,
+ encoder_output,
+ text_freqs=text_freqs,
+ cache=cache,
+ crossatt_mask=crossatt_mask,
+ )
+ return x
+
+ def init_cache(self, max_seq_len, device):
+ return FLACache(num_states=len(self.decoder_layers) + 1)
+
+ def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
+ assert self.decoder_layers[layer_idx].crossatt is not None
+ x = audio_inputs
+ for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
+ x = layer(x, None)
+ return self.decoder_layers[layer_idx].crossatt._query(x)
+
+ def prefill(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: FLACache | None = None,
+ ):
+ return self(encoder_output, decoder_input, cache=cache)
+
+ def decode_one(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: Cache,
+ ):
+ x = decoder_input
+ for layer in self.decoder_layers:
+ x = layer(
+ x,
+ encoder_output,
+ cache=cache,
+ )
+ return x
diff --git a/tts/tts/model/transformer.py b/tts/tts/model/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f101a53d23301621a9b0907bf2fc70a133eb9ffb
--- /dev/null
+++ b/tts/tts/model/transformer.py
@@ -0,0 +1,223 @@
+import os
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from transformers.cache_utils import DynamicCache
+
+from layers.attention import CrossAttention, SelfAttention, precompute_freqs_cis
+from layers.ffn import SwiGLU
+from model.cache_utils import FLACache, TransformerDecoderCache
+from model.config import TransformerDecoderConfig, TransformerEncoderConfig
+from model.registry import register_decoder, register_encoder
+from model.shortconv import ShortConvBlock
+
+if "GRAD_CKPT" in os.environ:
+
+ def maybe_grad_ckpt(f):
+ def grad_ckpt_f(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(
+ f, *args, **kwargs, use_reentrant=False
+ )
+
+ return grad_ckpt_f
+else:
+
+ def maybe_grad_ckpt(f):
+ return f
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ layer_idx: int,
+ ffn_expansion_factor: int,
+ is_causal: bool,
+ ):
+ super().__init__()
+ self.tmix = SelfAttention(dim, num_heads, layer_idx, is_causal=is_causal)
+ self.cmix = SwiGLU(dim, ffn_expansion_factor)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ def forward(
+ self,
+ x,
+ freqs: torch.Tensor | None = None,
+ cache: DynamicCache | None = None,
+ mask: torch.Tensor | None = None,
+ ):
+ x = self.tmix(self.norm1(x), freqs=freqs, cache=cache, mask=mask) + x
+ x = self.cmix(self.norm2(x)) + x
+ return x
+
+
+class DecoderBlockWithOptionalCrossAttention(nn.Module):
+ def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
+ super().__init__()
+
+ self.decoder_block = decoder_block
+ self.crossatt = crossatt
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ encoder_output: torch.Tensor | None = None,
+ freqs: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ selfatt_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = self.decoder_block(
+ x,
+ freqs=freqs,
+ cache=cache,
+ )
+
+ if self.crossatt is not None:
+ x = x + self.crossatt(
+ x,
+ k=encoder_output,
+ mask=crossatt_mask,
+ cache=cache,
+ )
+
+ return x
+
+
+@register_decoder("sa_transformer")
+class TransformerDecoder(nn.Module):
+    config = TransformerDecoderConfig
+
+    def __init__(self, cfg: TransformerDecoderConfig):
+ super().__init__()
+
+ assert cfg.dim % cfg.num_heads == 0
+ self.head_dim = cfg.dim // cfg.num_heads
+
+ def transformer_block(i):
+ conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
+ if i in conv_layers:
+ return ShortConvBlock(
+ dim=cfg.dim,
+ kernel_size=4,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ layer_idx=i,
+ use_fast_conv1d=True,
+ )
+ else:
+ return TransformerBlock(
+ dim=cfg.dim,
+ num_heads=cfg.num_heads,
+ layer_idx=i,
+ ffn_expansion_factor=cfg.ffn_expansion_factor,
+ is_causal=True,
+ )
+
+ def crossatt_block(i):
+ return CrossAttention(
+ dim=cfg.dim,
+ num_heads=cfg.crossatt_num_heads,
+ dropout=cfg.crossatt_dropout,
+ layer_idx=i,
+ )
+
+ self.decoder_layers = nn.ModuleList(
+ [
+ DecoderBlockWithOptionalCrossAttention(
+ transformer_block(i),
+ crossatt_block(i) if i in cfg.crossatt_layer_idx else None,
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ crossatt_mask: torch.Tensor | None = None,
+ cache: FLACache | None = None,
+ ):
+ x = decoder_input
+ positions = torch.arange(x.shape[1], device=x.device)
+ freqs = precompute_freqs_cis(positions, self.head_dim)
+ for layer in self.decoder_layers:
+ x = maybe_grad_ckpt(layer)(
+ x,
+ encoder_output,
+ freqs=freqs,
+ crossatt_mask=crossatt_mask,
+ cache=cache,
+ )
+ return x
+
+ def init_cache(self, max_seq_len, device):
+        return FLACache(self.head_dim, max_seq_len * 2, device=device)
+
+ def prefill(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: FLACache | None = None,
+ ):
+ return self(encoder_output, decoder_input, cache=cache)
+
+ def decode_one(
+ self,
+ encoder_output: torch.Tensor,
+ decoder_input: torch.Tensor,
+ cache: FLACache,
+ ):
+ x = decoder_input
+ pos = cache._seen_tokens
+ freq = cache.freqs[[pos]]
+ for layer in self.decoder_layers:
+ x = layer(
+ x,
+ encoder_output,
+ freqs=freq,
+ cache=cache,
+ )
+ return x
+
+
+@register_encoder("sa_transformer")
+class TransformerEncoder(nn.Module):
+ config = TransformerEncoderConfig
+
+ def __init__(self, cfg: TransformerEncoderConfig):
+ super().__init__()
+ assert cfg.dim % cfg.num_heads == 0
+ self.head_dim = cfg.dim // cfg.num_heads
+ self.encoder_layers = nn.ModuleList(
+ [
+ TransformerBlock(
+ cfg.dim,
+ cfg.num_heads,
+ i,
+ cfg.ffn_expansion_factor,
+ is_causal=False,
+ )
+ for i in range(cfg.num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ):
+ positions = torch.arange(x.shape[1], device=x.device)
+ freqs = precompute_freqs_cis(positions, self.head_dim)
+ if mask is not None:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ mask = torch.logical_or(
+ mask,
+ rearrange(torch.eye(mask.shape[-1], device=x.device), "n m ->1 1 n m"),
+ )
+
+ for layer in self.encoder_layers:
+ x = layer(x, freqs=freqs, mask=mask)
+ return x
diff --git a/tts/tts/tools.py b/tts/tts/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbd6784b7ec80424547ac5acb38f411a7b7ddaa
--- /dev/null
+++ b/tts/tts/tools.py
@@ -0,0 +1,77 @@
+from typing import Callable, List, Optional
+from itertools import accumulate
+
+import torch
+
+default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def pad_2d_sequence(seq, padding_value=0):
+ max_x, max_y = map(max, zip(*map(lambda x: x.shape, seq)))
+ pad = lambda x: torch.nn.functional.pad(
+ x,
+ (0, max_y - x.shape[1], 0, max_x - x.shape[0]),
+ value=padding_value,
+ )
+ return torch.stack([pad(x) for x in seq])
+
+def packmask_2d(xlen: list[int], ylen: list[int], offset: int=0) -> torch.Tensor:
+ _, ybound = map(lambda x: [0] + list(accumulate(x, int.__add__)), (xlen, ylen))
+ lb, hb = [], []
+
+ for n, l, h in zip(xlen, ybound[:-1], ybound[1:]):
+ lb += [l]*n
+ hb += [h]*n
+
+ lb, hb = map(torch.tensor, (lb, hb))
+ if offset:
+ lb -= offset
+ hb += offset
+
+ rge = torch.arange(ybound[-1])
+
+ lm = rge.unsqueeze(0) >= lb.unsqueeze(1)
+ hm = rge.unsqueeze(0) < hb.unsqueeze(1)
+
+ return lm * hm
+
+
+def topk_sampling(seq, k=1, temp=1.):
+    logits = seq / temp
+    topk = torch.topk(logits, k, dim=-1)
+    mask = logits < topk.values[..., [-1]]
+    logits[mask] = -float('Inf')
+    probs = torch.softmax(logits, dim=-1)
+    return torch.multinomial(probs, num_samples=1)
+
+def delay_rvq(
+ code,
+ head_token: int = -2,
+ tail_token: int = -3,
+):
+ q, _ = code.shape
+ extension = torch.ones((q, q + 1)).tril() * head_token
+ extension += torch.ones((q + 1, q)).tril(diagonal=-1).T * tail_token
+ extension = torch.flip(extension, (1,))
+ extended_code = torch.cat((code, extension), axis=1)
+ for i in range(q):
+ extended_code[i, :] = torch.roll(extended_code[i, :], i + 1)
+
+ return extended_code.long()
+
+def undelay_rvq(extended_code):
+ q, _, n = extended_code.shape
+ out = []
+ for i in range(q):
+ out.append(torch.roll(extended_code[i], -(i + 1), dims=1))
+ out = torch.stack(out, dim=0)
+ return out[:, :, :-(q+1)]
+
+def sequence_mask(lengths, max_len=None, device=default_device):
+ batch_size = lengths.shape[0]
+ if max_len is None:
+ max_len = torch.max(lengths).item()
+
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
+ mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
+
+ return mask
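+
+
+if __name__ == "__main__":
+    # Minimal round-trip sketch of the delay pattern: `delay_rvq` staggers the
+    # codebooks with head/tail tokens, `undelay_rvq` undoes the shift. The
+    # 2x6 code matrix is an illustrative assumption only.
+    code = torch.arange(12, dtype=torch.float32).reshape(2, 6)
+    delayed = delay_rvq(code)                     # (2, 6 + 2 + 1)
+    restored = undelay_rvq(delayed.unsqueeze(1))  # expects (q, batch, n)
+    assert torch.equal(restored[:, 0], code.long())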
diff --git a/tts/tts/train_tts.py b/tts/tts/train_tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0d9a064c2d76048fdd7f78d55705213121a5f9
--- /dev/null
+++ b/tts/tts/train_tts.py
@@ -0,0 +1,197 @@
+import json
+import math
+from dataclasses import asdict
+from pathlib import Path
+
+import hydra
+import numpy as np
+import pytorch_lightning as ptl
+import torch
+from omegaconf import DictConfig
+from safetensors.torch import save_file
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from transformers import get_cosine_schedule_with_warmup
+
+from model.config import TTSConfig
+from model.prediction_head import DiffusionHead
+from tts import ARTTSModel
+
+
+def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
+ def lr_lambda(step):
+ if step < warmup_steps:
+ return step / max(1, warmup_steps)
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
+
+ return lr_lambda
+
+
+class TrainARTTS(ptl.LightningModule):
+ def __init__(
+ self,
+ config: TTSConfig,
+ quant_layer: list[int],
+ tie_embed: bool = False,
+ learning_rate: float = 5e-4,
+ end_learning_rate: float | None = None,
+ weight_decay: float = 0.1,
+ betas: tuple[float, float] = (0.9, 0.999),
+ n_warmup_steps: int = 500,
+ n_training_steps: int = 300000,
+ mask_text_p: float = 0.0,
+ load_weights: str | None = None,
+ stop_token_weight: float | None = None,
+ stop_loss_factor: float = 0.1,
+ stop_loss_warmup: tuple[int, int] | None = None,
+ ):
+ super(TrainARTTS, self).__init__()
+
+ self.learning_rate = learning_rate
+ self.weight_decay = weight_decay
+ self.betas = betas
+ self.n_warmup_steps = n_warmup_steps
+ self.n_training_steps = n_training_steps
+ self.stop_token_weight = stop_token_weight
+ self.stop_loss_factor = stop_loss_factor
+
+ self.save_hyperparameters()
+
+ self.model = ARTTSModel(config)
+
+ if load_weights is not None:
+ model = torch.load(load_weights)
+ self.load_state_dict(model["state_dict"], strict=False)
+
+ def on_train_epoch_start(self):
+ if hasattr(self.trainer.train_dataloader.batch_sampler, "set_epoch"):
+ self.trainer.train_dataloader.batch_sampler.set_epoch(self.current_epoch)
+
+ def save_model_weights_and_config(
+ self,
+ dir: str | None,
+ model_filename: str = "model.st",
+ config_filename: str = "config.json",
+ ):
+ cfg = self.hparams.config
+ Path(dir).mkdir(exist_ok=True)
+ model_path = Path(dir) / model_filename
+ save_file(self.model.state_dict(), model_path)
+ with open(Path(dir) / config_filename, "w") as f:
+ json.dump(asdict(cfg), f, indent=2)
+
+ def step(self, batch, batch_idx: int, validation: bool = False):
+ text_token = batch["text_token"]
+ audio_token = batch["audio_token"].squeeze(2)
+ crossatt_mask = batch["crossatt_mask"]
+ text_rel_pos = batch["text_rel_pos"]
+ encoder_mask = batch["encoder_mask"]
+ logits_mask = batch["y_mask"]
+
+ pre_logits = self.model(
+ text_token,
+ audio_token,
+ encoder_mask,
+ logits_mask,
+ crossatt_mask,
+ )
+ losses = {}
+ if validation and type(self.model.prediction_head) is DiffusionHead:
+ t = (
+ torch.ones(pre_logits.shape[0], device=pre_logits.device)
+ * batch_idx
+ / self.trainer.num_val_batches[0]
+ )
+ losses |= self.model.prediction_head.compute_loss(
+ pre_logits,
+ audio_token[:, 1:],
+ mask=logits_mask[:, 1:] if logits_mask is not None else None,
+ t=t,
+ )
+ else:
+ losses |= self.model.prediction_head.compute_loss(
+ pre_logits,
+ audio_token[:, 1:],
+ mask=logits_mask[:, 1:] if logits_mask is not None else None,
+ )
+
+ if self.model.stop_prediction_head is not None and logits_mask is not None:
+ target = nn.functional.pad((~logits_mask)[:, 2:].to(pre_logits), (0, 1))
+ mask = logits_mask[:, 1:]
+ losses |= self.model.stop_prediction_head.compute_loss(
+ pre_logits[mask],
+ target[mask],
+ )
+
+ return losses
+
+ def training_step(self, batch, idx):
+ losses = self.step(batch, idx)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"train_{name}", loss, prog_bar=True, sync_dist=True)
+ if "stop" in name:
+ if self.hparams.stop_loss_warmup is not None:
+ alpha, beta = self.hparams.stop_loss_warmup
+ warmup = np.clip((idx - alpha) / beta, a_min=0.0, a_max=1.0)
+ else:
+ warmup = 1.0
+ loss *= self.stop_loss_factor * warmup
+ total_loss += loss
+ self.log("train_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def validation_step(self, batch, idx):
+ losses = self.step(batch, idx, validation=True)
+ total_loss = 0.0
+ for name, loss in losses.items():
+ self.log(f"val_{name}", loss, prog_bar=True, sync_dist=True)
+ total_loss += loss
+ self.log("val_loss", total_loss, prog_bar=True, sync_dist=True)
+ return total_loss
+
+ def configure_optimizers(self):
+ params = [
+ {
+ "params": self.model.parameters(),
+ "weight_decay": self.weight_decay,
+ }
+ ]
+ opt = torch.optim.AdamW(
+ params,
+ lr=self.learning_rate,
+ betas=self.betas,
+ )
+ # scheduler = get_cosine_schedule_with_warmup(
+ # opt,
+ # num_warmup_steps=self.n_warmup_steps,
+ # num_training_steps=self.n_training_steps,
+ # )
+ scheduler = LambdaLR(
+ opt,
+ lr_lambda=cosine_schedule_with_warmup(
+ warmup_steps=self.hparams.n_warmup_steps,
+ total_steps=self.hparams.n_training_steps,
+ start_lr=self.hparams.learning_rate,
+ end_lr=self.hparams.learning_rate * 0.1,
+ ),
+ )
+ return [opt], [{"scheduler": scheduler, "interval": "step"}]
+
+
+@hydra.main(config_path="hydra_configs/", config_name="config", version_base="1.3")
+def main(cfg: DictConfig):
+ ptl.seed_everything(cfg.seed_everything)
+
+ model = hydra.utils.instantiate(cfg.model)
+ cfg.experiment_name = f"ARTTS_{model.hparams.config.decoder_cfg.name}"
+ datamodule = hydra.utils.instantiate(cfg.data)
+ trainer = hydra.utils.instantiate(cfg.trainer)
+
+ trainer.fit(model, datamodule, ckpt_path=cfg.get("ckpt_path"))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tts/tts/tts.py b/tts/tts/tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..425c070b749ff441916cb4f86cf2a89d4137b0f2
--- /dev/null
+++ b/tts/tts/tts.py
@@ -0,0 +1,158 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import torch
+from safetensors import safe_open
+from torch import nn
+from tqdm import tqdm
+
+from model.config import TTSConfig
+from model.prediction_head import (
+ ContinuousHead,
+ DiffusionHead,
+ LogitsHead,
+ StopPredictionHead,
+)
+from model.registry import DECODER_REGISTRY, ENCODER_REGISTRY
+
+
+class ARTTSModel(nn.Module):
+ def __init__(self, cfg: TTSConfig):
+ super().__init__()
+ self.text_embd = nn.Embedding(cfg.text_vocab_size, cfg.dim)
+ if cfg.audio_input_type == "discrete":
+ self.audio_embd = nn.Embedding(cfg.audio_vocab_size, cfg.dim)
+ self.prediction_head = LogitsHead(cfg.decoder_cfg.dim, cfg.audio_vocab_size)
+ elif cfg.audio_input_type == "continuous" and cfg.continuous_diffusion:
+ self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
+ self.prediction_head = DiffusionHead(
+ cfg.decoder_cfg.dim,
+ cfg.audio_embed_size,
+ cfg.diffusion_head_num_layers,
+ )
+ elif cfg.audio_input_type == "continuous":
+ self.audio_embd = nn.Linear(cfg.audio_embed_size, cfg.dim)
+ self.prediction_head = ContinuousHead(
+ cfg.decoder_cfg.dim,
+ cfg.audio_embed_size,
+ )
+
+ self.text_encoder = ENCODER_REGISTRY[cfg.encoder_cfg.name](cfg.encoder_cfg)
+ self.audio_decoder = DECODER_REGISTRY[cfg.decoder_cfg.name](cfg.decoder_cfg)
+ self.stop_prediction_head = (
+ StopPredictionHead(cfg.dim) if cfg.stop_prediction_head else None
+ )
+
+ @classmethod
+ def from_pretrained_local(
+ cls,
+ path: str,
+ config_filename: str = "config.json",
+ model_filename: str = "model.st",
+ device: str = "cpu",
+ ):
+ with open(Path(path) / config_filename, "r") as f:
+ config = json.load(f)
+ for k in config.keys():
+ if k == "decoder_cfg":
+ config[k] = DECODER_REGISTRY[config[k]["name"]].config(**config[k])
+ if k == "encoder_cfg":
+ config[k] = ENCODER_REGISTRY[config[k]["name"]].config(**config[k])
+ config = TTSConfig(**config)
+ model = ARTTSModel(config)
+ state_dict = {}
+ with safe_open(Path(path) / model_filename, framework="pt", device=device) as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+ model.load_state_dict(state_dict)
+ return model
+
+ def _get_query(self, x: torch.Tensor, *args):
+ input_audio_embd = self.audio_embd(x)
+ return self.audio_decoder._get_query(input_audio_embd, *args)
+
+ def forward(
+ self,
+ text_ids: torch.LongTensor,
+ audio_inputs: torch.Tensor,
+ text_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ crossatt_mask: torch.Tensor | None = None,
+ ):
+ input_text_embd = self.text_embd(text_ids)
+ input_audio_embd = self.audio_embd(audio_inputs[:, :-1])
+ text_hidden_states = self.text_encoder(input_text_embd, mask=text_mask)
+ pre_logits = self.audio_decoder(
+ text_hidden_states,
+ input_audio_embd,
+ crossatt_mask=crossatt_mask[:, :-1] if crossatt_mask is not None else None,
+ )
+ """
+ losses = {}
+ losses |= self.prediction_head.compute_loss(
+ pre_logits,
+ audio_inputs[:, 1:],
+ mask=audio_mask[:, 1:] if audio_mask is not None else None,
+ )
+ if self.stop_prediction_head is not None and audio_mask is not None:
+ target = nn.functional.pad((~audio_mask)[:, 2:].to(pre_logits), (0, 1))
+ mask = audio_mask[:, 1:]
+ losses |= self.stop_prediction_head.compute_loss(
+ pre_logits[mask],
+ target[mask],
+ )
+ return losses
+ """
+ return pre_logits
+
+ def generate(
+ self,
+ text_ids: torch.LongTensor,
+ prefix: torch.Tensor,
+ teacher_force: torch.Tensor | None = None,
+ max_seq_len: int = 200,
+ device: str = "cuda",
+ sampling_params: dict | None = None,
+ stop_threshold: float = 0.5,
+ ):
+ if sampling_params is None:
+ sampling_params = {}
+ if text_ids.ndim == 1:
+ text_ids = text_ids.unsqueeze(0)
+
+ input_text_embd = self.text_embd(text_ids)
+ text_hidden_states = self.text_encoder(input_text_embd)
+ prefix_embd = self.audio_embd(prefix)
+
+ cache = self.audio_decoder.init_cache(max_seq_len, device)
+ preds = []
+ pre_prediction = self.audio_decoder.prefill(
+ text_hidden_states,
+ prefix_embd,
+ cache=cache,
+ )
+        prediction = self.prediction_head.predict(pre_prediction[:, [-1]], **sampling_params)
+ prediction_embd = self.audio_embd(prediction)
+ for i in tqdm(range(max_seq_len)):
+ pre_prediction = self.audio_decoder.decode_one(
+                text_hidden_states,
+ prediction_embd,
+ cache,
+ )
+ prediction = self.prediction_head.predict(pre_prediction, **sampling_params)
+ prediction_embd = self.audio_embd(prediction)
+ if teacher_force is not None:
+ b, n, d = teacher_force.shape
+ if i < n:
+ prediction_embd = self.audio_embd(teacher_force[:, [i]])
+
+ preds.append(prediction)
+ if self.stop_prediction_head is not None:
+ stop_pred = self.stop_prediction_head(pre_prediction)
+ if stop_pred > stop_threshold:
+ break
+
+ full_prediction = torch.cat(preds, dim=1)
+
+ return cache, full_prediction