Spaces · Running on Zero

Commit 56cfa73 — Mehdi Lakbar committed
Parent(s): 4905304
Initial demo of Lina-speech (pardi-speech)

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- codec/__init__.py +2 -0
- codec/__pycache__/__init__.cpython-312.pyc +0 -0
- codec/__pycache__/train_patchvae.cpython-312.pyc +0 -0
- codec/__pycache__/train_wavvae.cpython-312.pyc +0 -0
- codec/__pycache__/train_zflowae.cpython-312.pyc +0 -0
- codec/datamodules.py +249 -0
- codec/models/__init__.py +2 -0
- codec/models/__pycache__/__init__.cpython-312.pyc +0 -0
- codec/models/components/__init__.py +0 -0
- codec/models/components/__pycache__/__init__.cpython-312.pyc +0 -0
- codec/models/components/__pycache__/convnext.cpython-312.pyc +0 -0
- codec/models/components/convnext.py +221 -0
- codec/models/components/transformer.py +224 -0
- codec/models/pardi_tokenizer.py +10 -0
- codec/models/patchvae/__pycache__/model.cpython-312.pyc +0 -0
- codec/models/patchvae/__pycache__/modules.cpython-312.pyc +0 -0
- codec/models/patchvae/model.py +262 -0
- codec/models/patchvae/modules.py +396 -0
- codec/models/wavvae/__init__.py +0 -0
- codec/models/wavvae/__pycache__/__init__.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/heads.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/layers.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/loss.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/model.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/modules.cpython-312.pyc +0 -0
- codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc +0 -0
- codec/models/wavvae/dataset.py +84 -0
- codec/models/wavvae/discriminators.py +211 -0
- codec/models/wavvae/experiment.py +3 -0
- codec/models/wavvae/heads.py +194 -0
- codec/models/wavvae/helpers.py +71 -0
- codec/models/wavvae/layers.py +282 -0
- codec/models/wavvae/loss.py +142 -0
- codec/models/wavvae/model.py +140 -0
- codec/models/wavvae/modules.py +213 -0
- codec/models/wavvae/spectral_ops.py +192 -0
- codec/scripts/compare_codecs.py +441 -0
- codec/scripts/compare_wavvae.py +264 -0
- codec/scripts/compare_zcodec.py +312 -0
- codec/scripts/compute_stats.py +76 -0
- codec/scripts/compute_wer.py +48 -0
- codec/scripts/compute_wer_from_refs.py +64 -0
- codec/scripts/download_expresso.py +10 -0
- codec/scripts/download_gigaspeech.py +14 -0
- codec/scripts/download_lj.py +9 -0
- codec/scripts/download_ltts.py +16 -0
- codec/scripts/download_mlseng10k.py +13 -0
- codec/scripts/eval_asr.py +100 -0
- codec/scripts/eval_asr_from_filelist.py +60 -0
codec/__init__.py  ADDED
@@ -0,0 +1,2 @@
from .train_patchvae import TrainPatchVAE
from .train_wavvae import TrainWavVAE
codec/__pycache__/__init__.cpython-312.pyc  ADDED  (binary file, 260 Bytes)
codec/__pycache__/train_patchvae.cpython-312.pyc  ADDED  (binary file, 12.2 kB)
codec/__pycache__/train_wavvae.cpython-312.pyc  ADDED  (binary file, 15 kB)
codec/__pycache__/train_zflowae.cpython-312.pyc  ADDED  (binary file, 12.2 kB)
codec/datamodules.py  ADDED
@@ -0,0 +1,249 @@
import itertools
import random
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path

import numpy as np
import pytorch_lightning as ptl
import torch
import torchaudio
from safetensors.torch import safe_open
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset, load_from_disk


@dataclass
class WavVAEDataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class WavVAEDataModule(ptl.LightningDataModule):
    def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool):
        dataset = WavVAEDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.num_workers,
            shuffle=train,
            pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)


class WavVAEDataset(Dataset):
    def __init__(self, cfg: WavVAEDataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        y, sr = torchaudio.load(audio_path)
        if y.size(0) > 1:
            # mix to mono
            y = y.mean(dim=0, keepdim=True)
        gain = np.random.uniform(-1, -6) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(
            y, sr, [["norm", f"{gain:.2f}"]]
        )
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(
                y, orig_freq=sr, new_freq=self.sampling_rate
            )
        if y.size(-1) < self.num_samples:
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, take always the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]


def pad_tensor_list_raw(
    tensor_list: list[tuple[torch.Tensor, torch.Tensor]], pad_idx: int = 0
) -> dict[str, torch.Tensor | None]:
    audio, hubert_maybe = zip(*tensor_list)
    audio = torch.cat(audio, dim=0)
    if hubert_maybe[0] is not None:
        hubert_maybe = torch.stack(hubert_maybe, dim=0)
    else:
        hubert_maybe = None
    return {"audio_z": audio, "hubert": hubert_maybe}


class SafeTensorDataset(Dataset):
    """
    On __getitem__, opens the safetensor, uses get_slice() to inspect shape,
    then either drops too-short files (return None) or returns a random subsequence slice.
    """

    def __init__(
        self,
        file_paths: list[str],
        key: str,
        hubert_path: str | None = None,
        hubert_key: str = "layer_9",
        min_length: int = 1,
        subseq_length: int | None = None,
    ):
        self.file_paths = file_paths
        self.key = key
        self.min_length = min_length
        self.subseq_length = subseq_length
        self.hubert_path = hubert_path
        self.hubert_key = hubert_key

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> torch.Tensor | None:
        path = self.file_paths[idx]
        # open file, get a slice wrapper for full tensor
        with safe_open(path, framework="pt") as f:
            tensor_slice = f.get_slice(self.key)
            Q, N, D = tensor_slice.get_shape()  # full shape [K, N]

            # drop too-short
            if N < self.min_length:
                return None

            L = self.subseq_length or N
            if L < N:
                # sample random start
                start = torch.randint(0, max(1, N - L - 1), ()).item()
                start -= start % 2
                # this yields a torch.Tensor of shape [K, L]
                seq = tensor_slice[:, start : start + L]
            else:
                # full length
                start = 0
                seq = tensor_slice[:, :]

        if self.hubert_path is not None:
            path = Path(self.hubert_path) / Path(path).name
            with safe_open(path, framework="pt") as f:
                tensor_slice = f.get_slice(self.hubert_key)
                hubert_N, hubert_D = tensor_slice.get_shape()  # full shape [K, N]
                seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2]
            return (seq, seq_hubert)

        return (seq, None)


class SafeTensorDataModule(ptl.LightningDataModule):
    """
    LightningDataModule using raw .safetensors file list + get_slice inside Dataset.
    """

    def __init__(
        self,
        train_file_list: str,
        val_file_list: str | None = None,
        hubert_path: str | None = None,
        key: str = "audio_z",
        hubert_key: str = "layer_9",
        val_split: float = 0.1,
        batch_size: int = 32,
        num_workers: int = 4,
        shuffle: bool = True,
        seed: int = 1234,
        min_length: int = 1,
        subseq_length: int | None = None,
    ):
        super().__init__()
        self.train_file_list = train_file_list
        self.val_file_list = val_file_list
        self.hubert_path = hubert_path
        self.key = key
        self.val_split = val_split
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.seed = seed
        self.min_length = min_length
        self.subseq_length = subseq_length

    def setup(self, stage=None):
        with open(self.train_file_list, "r") as f:
            train_paths = [line.strip() for line in f if line.strip()]
        val_paths = None
        if self.val_file_list is not None:
            with open(self.train_file_list, "r") as f:
                val_paths = [line.strip() for line in f if line.strip()]
        # Split into train/val
        if (
            isinstance(self.val_split, float)
            and 0 < self.val_split < 1
            and val_paths is None
        ):
            train_paths, val_paths = train_test_split(
                train_paths, test_size=self.val_split, random_state=self.seed
            )

        self.train_ds = SafeTensorDataset(
            train_paths,
            key=self.key,
            min_length=self.min_length,
            subseq_length=self.subseq_length,
            hubert_path=self.hubert_path,
        )
        self.val_ds = SafeTensorDataset(
            val_paths,
            key=self.key,
            min_length=self.min_length,
            subseq_length=self.subseq_length,
        )

    def _collate_fn(
        self, batch: list[torch.Tensor | None]
    ) -> tuple[torch.Tensor, torch.BoolTensor]:
        seqs = [s for s in batch if s is not None]
        return pad_tensor_list_raw(seqs, pad_idx=0)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self._collate_fn,
        )
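Editor's note — a minimal usage sketch for the waveform datamodule above (not part of the commit). The file-list paths, sample rate, and batch settings are placeholder values, and it assumes the module is importable as codec.datamodules.

# Hedged usage sketch; paths and hyperparameters are hypothetical.
from codec.datamodules import WavVAEDataConfig, WavVAEDataModule

train_cfg = WavVAEDataConfig(
    filelist_path="filelists/train.txt",  # placeholder: one audio path per line
    sampling_rate=24000,                  # assumed sample rate
    num_samples=72000,                    # e.g. 3-second crops at 24 kHz
    batch_size=16,
    num_workers=4,
)
val_cfg = WavVAEDataConfig(
    filelist_path="filelists/val.txt",    # placeholder
    sampling_rate=24000,
    num_samples=72000,
    batch_size=16,
    num_workers=4,
)

datamodule = WavVAEDataModule(train_params=train_cfg, val_params=val_cfg)
# Each dataset item is a mono waveform of length num_samples, so the default
# collate yields batches of shape (batch_size, num_samples).
train_loader = datamodule.train_dataloader()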
codec/models/__init__.py  ADDED
@@ -0,0 +1,2 @@
from .patchvae.model import PatchVAE, PatchVAEConfig
from .wavvae.model import WavVAE, WavVAEConfig
codec/models/__pycache__/__init__.cpython-312.pyc  ADDED  (binary file, 308 Bytes)
codec/models/components/__init__.py  ADDED  (empty file)
codec/models/components/__pycache__/__init__.cpython-312.pyc  ADDED  (binary file, 174 Bytes)
codec/models/components/__pycache__/convnext.cpython-312.pyc  ADDED  (binary file, 11.3 kB)
codec/models/components/convnext.py  ADDED
@@ -0,0 +1,221 @@
import torch
from torch import nn


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int | None = None,
        layer_scale_init_value: float = 0.0,
        elementwise_affine_ln: bool = True,
        is_causal: bool = False,
    ):
        super().__init__()
        intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(
            dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
        )
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.is_causal = is_causal

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        if self.is_causal:
            x = torch.nn.functional.pad(x, (6, 0))
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * scale[:, None] + shift[:, None]
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        if gate is not None:
            x = gate[:, None] * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class ConvNextNet(nn.Module):
    def __init__(self, n_layers, dim, intermediate_dim: int | None = None):
        super().__init__()
        self.net = nn.Sequential(
            *[
                ConvNeXtBlock(
                    dim,
                    intermediate_dim,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        return self.net(x)


class ConvNextPatchEncoder(nn.Module):
    def __init__(
        self,
        patch_sizes: list[int],
        n_layers_per_patch: int,
        patch_expansion_factor: float = 1.5,
        is_decoder: bool = False,
    ):
        super().__init__()
        patch_to_dim = []
        convnext = []
        for i, patch_size in enumerate(patch_sizes):
            in_dim = int((patch_expansion_factor if i > 0 else 1.0) * patch_size)
            out_dim = int(patch_expansion_factor * patch_size)
            if is_decoder:
                in_dim, out_dim = out_dim, in_dim
            patch_to_dim.append(
                nn.Linear(
                    in_dim,
                    out_dim,
                )
            )
            convnext += [
                nn.Sequential(
                    *[
                        ConvNeXtBlock(int(patch_size * patch_expansion_factor))
                        for _ in range(n_layers_per_patch)
                    ]
                )
            ]
        self.is_decoder = is_decoder
        self.patch_sizes = patch_sizes
        self.patch_expansion_factor = patch_expansion_factor
        self.patch_to_dim = nn.ModuleList(patch_to_dim)
        self.convnext = nn.ModuleList(convnext)

    def forward(self, x):
        if self.is_decoder:
            for i, patch_size in reversed(list(enumerate(self.patch_sizes))):
                B, P, N = x.shape
                patch_expansion_factor_maybe = (
                    self.patch_expansion_factor if i > 0 else 1.0
                )
                x = x.reshape(B, int(patch_size * self.patch_expansion_factor), -1)
                x = self.convnext[i](x)
                x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
        else:
            for i, patch_size in enumerate(self.patch_sizes):
                B, P, N = x.shape
                patch_expansion_factor_maybe = (
                    self.patch_expansion_factor if i > 0 else 1.0
                )
                x = x.reshape(B, int(patch_size * patch_expansion_factor_maybe), -1)
                x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
                x = self.convnext[i](x)
        return x


class ConvNextEncoder(nn.Module):
    def __init__(
        self,
        in_dim: int,
        dim: int,
        n_layers: int,
        intermediate_dim: int | None = None,
        stride: int = 1,
    ):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, dim)
        if stride > 1:
            self.stride = nn.Conv1d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=(stride * 2) + 1,
                stride=stride,
                padding=stride // 2,
            )
        else:
            self.stride = nn.Identity()
        self.net = ConvNextNet(n_layers, dim, intermediate_dim)

    def forward(self, x):
        x = self.in_proj(x.transpose(1, 2)).transpose(1, 2)
        x = self.stride(x)
        return self.net(x)


class ConvNextDecoder(nn.Module):
    def __init__(
        self,
        out_dim: int,
        dim: int,
        n_layers: int,
        intermediate_dim: int | None = None,
        stride: int = 1,
        stride_position: str = "before",
    ):
        super().__init__()
        self.out_proj = nn.Linear(dim, out_dim)
        if stride > 1:
            self.stride = nn.ConvTranspose1d(
                in_channels=dim,
                out_channels=dim,
                kernel_size=(stride * 2) + 1,
                stride=stride,
                padding=stride // 2,
                output_padding=stride // 2,
            )
        else:
            self.stride = nn.Identity()
        self.stride_position = stride_position

        self.net = ConvNextNet(n_layers, dim, intermediate_dim)

    def forward(self, x):
        if self.stride_position == "before":
            x = self.stride(x)
        x = self.net(x)
        if self.stride_position == "after":
            x = self.stride(x)
        return self.out_proj(x.transpose(1, 2)).transpose(1, 2)


class SwiGLU(nn.Module):
    def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
        super().__init__()
        self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
        self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)

    def forward(self, x):
        gate, x = self.p_in(x).chunk(2, dim=-1)
        return self.p_out(nn.functional.silu(gate) * x)
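Editor's note — a small shape-check sketch for the 1D ConvNeXt modules above (not part of the commit). Channel counts, sequence length, and the import path are assumptions.

# Hedged sanity-check sketch: ConvNeXtBlock preserves the (B, C, T) layout.
import torch
from codec.models.components.convnext import ConvNeXtBlock, ConvNextEncoder  # import path assumed

x = torch.randn(2, 128, 200)              # (batch, channels=dim, time)
block = ConvNeXtBlock(dim=128)
y = block(x)                               # residual block, shape preserved: (2, 128, 200)

enc = ConvNextEncoder(in_dim=128, dim=256, n_layers=4, stride=2)
z = enc(x)                                 # (2, 256, T'), time axis downsampled ~2x by the strided conv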
codec/models/components/transformer.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LocalSelfAttention(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
dim: int,
|
| 14 |
+
heads: int,
|
| 15 |
+
window_len: int = 32,
|
| 16 |
+
rotary: bool = True,
|
| 17 |
+
is_causal: bool = False,
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.heads = heads
|
| 21 |
+
assert dim % heads == 0, "dim must be divisible by heads"
|
| 22 |
+
self.qkv = nn.Linear(dim, 3 * dim)
|
| 23 |
+
self.o = nn.Linear(dim, dim)
|
| 24 |
+
self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
|
| 25 |
+
self.is_causal = is_causal
|
| 26 |
+
self.window_len = window_len
|
| 27 |
+
|
| 28 |
+
def forward(
|
| 29 |
+
self,
|
| 30 |
+
x: torch.Tensor,
|
| 31 |
+
mask: Optional[torch.Tensor] = None,
|
| 32 |
+
pos: Optional[torch.Tensor] = None,
|
| 33 |
+
cache: Optional[Dict[int, torch.Tensor]] = None,
|
| 34 |
+
layer_idx: Optional[int] = None,
|
| 35 |
+
time_step: int = 0,
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
# x: (batch, seq_len, dim)
|
| 38 |
+
b, n, dim = x.shape
|
| 39 |
+
b, t_len, hd = x.shape
|
| 40 |
+
pad_len = (self.window_len - t_len % self.window_len) % self.window_len
|
| 41 |
+
padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
|
| 42 |
+
mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
|
| 43 |
+
mask = torch.nn.functional.pad(
|
| 44 |
+
mask, (0, pad_len), value=False
|
| 45 |
+
) # False = masked
|
| 46 |
+
mask = mask.expand(b, -1) # [b, padded_len]
|
| 47 |
+
mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=self.window_len)
|
| 48 |
+
qkv = self.qkv(padded_x).chunk(3, dim=-1)
|
| 49 |
+
q, k, v = [
|
| 50 |
+
rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=self.window_len)
|
| 51 |
+
for t in qkv
|
| 52 |
+
]
|
| 53 |
+
if cache is not None:
|
| 54 |
+
assert layer_idx is not None, "layer_idx must be set when using cache"
|
| 55 |
+
cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
|
| 56 |
+
cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
|
| 57 |
+
k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
|
| 58 |
+
|
| 59 |
+
# apply rotary embeddings
|
| 60 |
+
if self.rotary is not None:
|
| 61 |
+
if pos is not None:
|
| 62 |
+
rot = self.rotary(pos) # (b,1,n,head_dim)
|
| 63 |
+
q = apply_rotary_emb(rot, q)
|
| 64 |
+
k = apply_rotary_emb(rot, k)
|
| 65 |
+
else:
|
| 66 |
+
q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
|
| 67 |
+
k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
|
| 68 |
+
|
| 69 |
+
# scaled dot-product attention
|
| 70 |
+
y = F.scaled_dot_product_attention(
|
| 71 |
+
q,
|
| 72 |
+
k,
|
| 73 |
+
v,
|
| 74 |
+
attn_mask=None if self.is_causal else mask,
|
| 75 |
+
is_causal=self.is_causal,
|
| 76 |
+
)
|
| 77 |
+
y = rearrange(y, "b n h w d -> b (w n) (h d)")
|
| 78 |
+
y = self.o(y)
|
| 79 |
+
y = y[:, :t_len]
|
| 80 |
+
return y
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class SelfAttention(nn.Module):
|
| 84 |
+
def __init__(
|
| 85 |
+
self, dim: int, heads: int, rotary: bool = True, is_causal: bool = False
|
| 86 |
+
):
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.heads = heads
|
| 89 |
+
assert dim % heads == 0, "dim must be divisible by heads"
|
| 90 |
+
self.qkv = nn.Linear(dim, 3 * dim)
|
| 91 |
+
self.o = nn.Linear(dim, dim)
|
| 92 |
+
self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
|
| 93 |
+
self.is_causal = is_causal
|
| 94 |
+
|
| 95 |
+
def forward(
|
| 96 |
+
self,
|
| 97 |
+
x: torch.Tensor,
|
| 98 |
+
mask: Optional[torch.Tensor] = None,
|
| 99 |
+
pos: Optional[torch.Tensor] = None,
|
| 100 |
+
cache: Optional[Dict[int, torch.Tensor]] = None,
|
| 101 |
+
layer_idx: Optional[int] = None,
|
| 102 |
+
time_step: int = 0,
|
| 103 |
+
) -> torch.Tensor:
|
| 104 |
+
# x: (batch, seq_len, dim)
|
| 105 |
+
b, n, dim = x.shape
|
| 106 |
+
b, t_len, hd = x.shape
|
| 107 |
+
pad_len = (32 - t_len % 32) % 32
|
| 108 |
+
padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
|
| 109 |
+
mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
|
| 110 |
+
mask = torch.nn.functional.pad(
|
| 111 |
+
mask, (0, pad_len), value=False
|
| 112 |
+
) # False = masked
|
| 113 |
+
mask = mask.expand(b, -1) # [b, padded_len]
|
| 114 |
+
mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=32)
|
| 115 |
+
qkv = self.qkv(padded_x).chunk(3, dim=-1)
|
| 116 |
+
q, k, v = [
|
| 117 |
+
rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=32) for t in qkv
|
| 118 |
+
]
|
| 119 |
+
# caching for fast autoregressive
|
| 120 |
+
if cache is not None:
|
| 121 |
+
assert layer_idx is not None, "layer_idx must be set when using cache"
|
| 122 |
+
# append new keys/values
|
| 123 |
+
cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
|
| 124 |
+
cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
|
| 125 |
+
k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
|
| 126 |
+
|
| 127 |
+
# apply rotary embeddings
|
| 128 |
+
if self.rotary is not None:
|
| 129 |
+
if pos is not None:
|
| 130 |
+
rot = self.rotary(pos) # .unsqueeze(1) # (b,1,n,head_dim)
|
| 131 |
+
q = apply_rotary_emb(rot, q)
|
| 132 |
+
k = apply_rotary_emb(rot, k)
|
| 133 |
+
else:
|
| 134 |
+
q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
|
| 135 |
+
k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
|
| 136 |
+
|
| 137 |
+
# scaled dot-product attention
|
| 138 |
+
y = F.scaled_dot_product_attention(
|
| 139 |
+
q,
|
| 140 |
+
k,
|
| 141 |
+
v,
|
| 142 |
+
attn_mask=None if self.is_causal else mask,
|
| 143 |
+
is_causal=self.is_causal,
|
| 144 |
+
)
|
| 145 |
+
y = rearrange(y, "b n h w d -> b (w n) (h d)")
|
| 146 |
+
y = self.o(y)
|
| 147 |
+
y = y[:, :t_len]
|
| 148 |
+
return y
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class SwiGLU(nn.Module):
|
| 152 |
+
def __init__(self, d_model: int):
|
| 153 |
+
super().__init__()
|
| 154 |
+
hidden = d_model * 4 // 3
|
| 155 |
+
self.p_in = nn.Linear(d_model, hidden * 2)
|
| 156 |
+
self.p_out = nn.Linear(hidden, d_model)
|
| 157 |
+
|
| 158 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 159 |
+
gate, data = self.p_in(x).chunk(2, dim=-1)
|
| 160 |
+
return self.p_out(F.silu(gate) * data)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class TransformerBlock(nn.Module):
|
| 164 |
+
"""
|
| 165 |
+
Transformer block using custom SelfAttention and SwiGLU FFN.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
dim: embedding dimension
|
| 169 |
+
heads: number of attention heads
|
| 170 |
+
rotary: whether to use rotary embeddings
|
| 171 |
+
is_causal: whether to apply causal masking
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
dim: int,
|
| 177 |
+
head_size: int,
|
| 178 |
+
rotary: bool = True,
|
| 179 |
+
is_causal: bool = False,
|
| 180 |
+
elementwise_affine_ln: bool = True,
|
| 181 |
+
):
|
| 182 |
+
super().__init__()
|
| 183 |
+
assert dim % head_size == 0
|
| 184 |
+
heads = dim // head_size
|
| 185 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
|
| 186 |
+
self.attn = LocalSelfAttention(dim, heads, rotary=rotary, is_causal=is_causal)
|
| 187 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
|
| 188 |
+
self.ffn = SwiGLU(dim)
|
| 189 |
+
|
| 190 |
+
def forward(
|
| 191 |
+
self,
|
| 192 |
+
x: torch.Tensor,
|
| 193 |
+
mask: Optional[torch.Tensor] = None,
|
| 194 |
+
pos: Optional[torch.Tensor] = None,
|
| 195 |
+
cache: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
|
| 196 |
+
layer_idx: Optional[int] = None,
|
| 197 |
+
scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 198 |
+
gate: torch.Tensor = None,
|
| 199 |
+
time_step: int = 0,
|
| 200 |
+
) -> torch.Tensor:
|
| 201 |
+
# Self-attention block
|
| 202 |
+
norm1_x = self.norm1(x)
|
| 203 |
+
if scale_shift is not None:
|
| 204 |
+
scale, shift = scale_shift
|
| 205 |
+
norm1_x = norm1_x * scale[:, None] + shift[:, None]
|
| 206 |
+
|
| 207 |
+
attn_out = self.attn(
|
| 208 |
+
norm1_x,
|
| 209 |
+
mask=mask,
|
| 210 |
+
pos=pos,
|
| 211 |
+
cache=cache,
|
| 212 |
+
layer_idx=layer_idx,
|
| 213 |
+
time_step=time_step,
|
| 214 |
+
)
|
| 215 |
+
x = x + attn_out
|
| 216 |
+
|
| 217 |
+
norm2_x = self.norm2(x)
|
| 218 |
+
if gate is not None:
|
| 219 |
+
norm2_x = gate[:, None] * norm2_x
|
| 220 |
+
|
| 221 |
+
# Feedforward block
|
| 222 |
+
ffn_out = self.ffn(norm2_x)
|
| 223 |
+
x = x + ffn_out
|
| 224 |
+
return x
|
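Editor's note — a brief shape-check sketch for TransformerBlock (not part of the commit). It assumes einops and rotary-embedding-torch are installed and that the module is importable under the path shown; the dimensions are arbitrary.

# Hedged shape-check sketch.
import torch
from codec.models.components.transformer import TransformerBlock  # import path assumed

block = TransformerBlock(dim=256, head_size=64)   # 4 heads of size 64, windowed local attention
x = torch.randn(2, 100, 256)                      # (batch, seq_len, dim)
y = block(x)                                      # same shape; the sequence is padded internally to a multiple of 32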
codec/models/pardi_tokenizer.py  ADDED
@@ -0,0 +1,10 @@
import torch

from zcodec.models import WavVAE, ZFlowAutoEncoder
from zcodec.models.wavvae.model import WavVAEConfig
from zcodec.models.zflowae.model import ZFlowAutoEncoderConfig


class PardiTokenizer(nn.Module):
    def __init__(self, wavvae_cfg: WavVAEConfig, zflowae_cfg: ZFlowAutoEncoderConfig):
codec/models/patchvae/__pycache__/model.cpython-312.pyc  ADDED  (binary file, 13 kB)
codec/models/patchvae/__pycache__/modules.cpython-312.pyc  ADDED  (binary file, 20.8 kB)
codec/models/patchvae/model.py  ADDED
@@ -0,0 +1,262 @@
import json
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch import nn
from torchdyn.core import NeuralODE

from .modules import AdaLNFlowPredictor, AutoEncoder


@contextmanager
def suppress_stdout():
    original_stdout = sys.stdout
    try:
        sys.stdout = open(os.devnull, "w")
        yield
    finally:
        sys.stdout.close()
        sys.stdout = original_stdout


def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr

    return lr_lambda


@dataclass
class PatchVAEConfig:
    latent_dim: int
    hidden_dim: int
    latent_scaling: tuple[list[float], list[float]] | None
    flow_factory: str
    num_flow_layers: int
    autoencoder_factory: str
    num_autoencoder_layers: int
    convnextformer_num_conv_per_transformer: int = 3
    wavvae_path: str | None = None
    fsq_levels: list[int] | None = None
    bottleneck_size: int | None = None
    latent_stride: int = 2
    vae: bool = False
    causal_transformer: bool = False
    cond_dim: int | None = None
    is_causal: bool = False


class PatchVAE(nn.Module):
    def __init__(self, cfg: PatchVAEConfig):
        super().__init__()
        self.flow_net = AdaLNFlowPredictor(
            feat_dim=cfg.latent_dim * cfg.latent_stride,
            dim=cfg.hidden_dim,
            n_layer=cfg.num_flow_layers,
            layer_factory=cfg.flow_factory,
            cond_dim=cfg.cond_dim,
            is_causal=cfg.is_causal,
        )
        self.autoencoder = AutoEncoder(
            cfg.latent_dim * cfg.latent_stride,
            cfg.hidden_dim,
            cfg.num_autoencoder_layers,
            cfg.autoencoder_factory,
            out_dim=cfg.cond_dim,
            vae=cfg.vae,
            bottleneck_size=cfg.bottleneck_size,
            convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
            is_causal=cfg.is_causal,
        )
        if cfg.latent_scaling is not None:
            mean, std = cfg.latent_scaling
            self.register_buffer("mean_latent_scaling", torch.tensor(mean))
            self.register_buffer("std_latent_scaling", torch.tensor(std))
        else:
            self.mean_latent_scaling = None
            self.std_latent_scaling = None

        self.latent_stride = cfg.latent_stride
        self.latent_dim = cfg.latent_dim
        self.wavvae = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)

        with open(Path(path) / "config.json", "r") as f:
            config = json.load(f)
        config = PatchVAEConfig(**config)
        model = cls(config).to(map_location)
        state_dict = load_file(
            Path(path) / "model.st",
            device=map_location,
        )
        model.load_state_dict(state_dict, assign=True)
        if config.wavvae_path is not None:
            from .. import WavVAE

            model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
        else:
            model.wavvae = None

        return model

    def wavvae_from_pretrained(
        self,
        pretrained_model_name_or_path: str,
        *args,
        **kwargs,
    ):
        from .. import WavVAE

        self.wavvae = WavVAE.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **kwargs,
        )

    def encode(self, wav: torch.Tensor):
        assert self.wavvae is not None, (
            "please provide WavVAE model to encode from waveform"
        )
        z = self.wavvae.encode(wav)
        zz = self.encode_patch(z)
        return zz

    def decode(self, patchvae_latent: torch.Tensor, **kwargs):
        assert self.wavvae is not None, (
            "please provide WavVAE model to decode to waveform"
        )
        z = self.decode_patch(patchvae_latent, **kwargs)
        wav = self.wavvae.decode(z)
        return wav

    def normalize_z(self, z: torch.Tensor):
        if self.mean_latent_scaling is not None:
            z = (z - self.mean_latent_scaling) / self.std_latent_scaling
        return z

    def denormalize_z(self, z: torch.Tensor):
        if self.std_latent_scaling is not None:
            z = z * self.std_latent_scaling + self.mean_latent_scaling
        return z

    def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
        B, T, D = z.shape
        z = self.normalize_z(z)
        if self.latent_stride > 1:
            z = z[:, : T - T % self.latent_stride]
            z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
        return self.autoencoder.encode(z, deterministic=deterministic)

    def decode_patch(
        self,
        latent: torch.Tensor,
        cfg: float = 2.0,
        num_steps: int = 15,
        solver: str = "euler",
        sensitivity: str = "adjoint",
        temperature: float = 1.0,
        **kwargs,
    ):
        with torch.no_grad():
            z_cond = self.autoencoder.decode(latent).transpose(1, 2)
            if cfg == 1.0:

                def solver_fn(t, Xt, *args, **kwargs):
                    flow = self.flow_net(Xt, z_cond, t.unsqueeze(0))
                    return flow
            else:
                z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)

                def solver_fn(t, Xt, *args, **kwargs):
                    flow = self.flow_net(
                        Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0)
                    )
                    cond, uncond = flow.chunk(2, dim=0)

                    return uncond + cfg * (cond - uncond)

            with suppress_stdout():
                node_ = NeuralODE(
                    solver_fn,
                    solver=solver,
                    sensitivity=sensitivity,
                    **kwargs,
                )
            t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
            patch_dim = self.latent_dim * self.latent_stride
            x0 = torch.randn(
                z_cond.shape[0],
                patch_dim,
                z_cond.shape[2],
                device=z_cond.device,
            )
            traj = node_.trajectory(
                x0 * temperature,
                t_span=t_span,
            )

            y_hat = traj[-1]
            y_hat = y_hat.transpose(1, 2)
            B, T, D = y_hat.shape
            y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride)
            y_hat = self.denormalize_z(y_hat)
            return y_hat

    def forward(
        self,
        z: torch.Tensor,
        t: torch.Tensor,
        drop_cond_rate: float = 0.0,
        drop_vae_rate: float = 0.0,
        sigma: float = 1e-4,
    ):
        z = self.normalize_z(z)
        B, T, D = z.shape
        if self.latent_stride > 1:
            z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)

        prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)

        if drop_cond_rate > 0.0:
            to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
            prior[to_drop] = 0.0

        x0 = torch.randn_like(z)
        x1 = z

        flow_target = x1 - (1 - sigma) * x0

        alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
        xt = alpha * x0 + t.view(-1, 1, 1) * x1

        pred = self.flow_net(
            xt.transpose(1, 2),
            prior.transpose(1, 2),
            t,
        )

        flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)

        return flow_loss, ae_loss, prior
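Editor's note — a rough latent-level inference sketch for PatchVAE (not part of the commit). The checkpoint path is hypothetical, the latent shapes are stand-ins for WavVAE outputs, and it assumes AutoEncoder.encode/decode (defined in modules.py below, partly outside this view) round-trip the bottleneck as their names suggest.

# Hedged inference sketch; checkpoint path and shapes are hypothetical.
import torch
from codec.models import PatchVAE

model = PatchVAE.from_pretrained("path/to/patchvae_checkpoint").eval()  # local dir or HF repo id

z = torch.randn(1, 64, model.latent_dim)                      # stand-in for WavVAE latents, shape (B, T, D)
latent = model.encode_patch(z, deterministic=True)            # groups latent_stride frames per patch, then bottlenecks
z_hat = model.decode_patch(latent, cfg=2.0, num_steps=15)     # flow-matching ODE decode back to (B, T, D)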
codec/models/patchvae/modules.py  ADDED
@@ -0,0 +1,396 @@
from random import random
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
from vector_quantize_pytorch import FSQ

from zcodec.models.components.transformer import TransformerBlock


class AdaLayerNormScale(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim * 3)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x, c):
        x = self.norm(x)
        scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
        shape = x.shape[0] + [1] * (x.dim() - 2) + x.shape[-1]
        scale, bias, gate = map(lambda x: x.view(*shape), (scale, bias, gate))
        x = x * (1 + scale) + bias
        return x, gate


class GaussianFourierTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x[:, None] * self.weight[None, :] * 2 * torch.pi
        x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
        return x


LAYER_FACTORIES = {}


def register_flow_layer_factory(name):
    def decorator(fn):
        LAYER_FACTORIES[name] = fn
        return fn

    return decorator


@register_flow_layer_factory("convnext")
def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)


@register_flow_layer_factory("mlp")
def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return AdaLNMLP(dim)


@register_flow_layer_factory("sa_transformer")
def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
    return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)


def init_weights(m: nn.Module):
    if isinstance(m, (nn.Conv1d, nn.Linear)):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def init_adaln_weights(m: nn.Module):
    nn.init.trunc_normal_(m.weight, std=0.02)
    nn.init.zeros_(m.bias)


def modulate(x, scale, shift):
    return x * (1 + scale[:, None]) + shift[:, None]


class AdaLNFlowPredictor(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        dim: int,
        n_layer: int,
        layer_factory: str,
        cond_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()

        layer_factory = LAYER_FACTORIES[layer_factory]
        self.layers = nn.ModuleList(
            [
                layer_factory(dim, i, n_layer, is_causal=is_causal)
                for i in range(n_layer)
            ]
        )
        if cond_dim is None:
            cond_dim = feat_dim
        self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
        self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
        self.final_adaln_proj = nn.Linear(dim, dim * 2)
        self.out_proj = nn.Linear(dim, feat_dim)
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.time_emb = GaussianFourierTimeEmbedding(dim // 2)

        self.apply(init_weights)
        for l in self.adaln_proj:
            init_adaln_weights(l)
        init_adaln_weights(self.final_adaln_proj)

    def forward(
        self,
        x_t: torch.Tensor,
        x_mu: torch.Tensor,
        t: torch.Tensor,
    ):
        x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
        x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)

        t_emb = self.time_emb(t)

        for i, (l, adaln) in enumerate(zip(self.layers, self.adaln_proj)):
            scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
            x = l(x, scale_shift=(scale, shift), gate=gate)

        scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
        x = self.final_norm(x.transpose(1, 2))
        x = modulate(x, scale, shift)

        x = self.out_proj(x).transpose(1, 2)

        return x


class AdaLNMLP(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
        )

    def forward(self, x, scale_shift, gate):
        x = x.transpose(-1, -2)
        h = modulate(self.in_ln(x), *scale_shift)
        h = self.mlp(h)
        return (x + gate[:, None] * h).transpose(-1, -2)


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int | None = None,
        layer_scale_init_value: float = 0.0,
        elementwise_affine_ln: bool = True,
        is_causal: bool = False,
    ):
        super().__init__()
        intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(
            dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
        )
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.is_causal = is_causal

    def forward(
        self,
        x: torch.Tensor,
        scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
        gate: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        if self.is_causal:
            x = torch.nn.functional.pad(x, (6, 0))
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = self.norm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * scale[:, None] + shift[:, None]
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        if gate is not None:
            x = gate[:, None] * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class ConvNextNet(nn.Module):
    def __init__(
        self,
        dim: int,
        n_layers: int,
        intermediate_dim: int | None = None,
        is_causal: bool = False,
    ):
        super().__init__()
        self.net = nn.Sequential(
            *[
                ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        return self.net(x.transpose(1, 2)).transpose(1, 2)


def convnext_factory(dim, n_layers, is_causal=False):
    return ConvNextNet(dim, n_layers, is_causal=is_causal)


def convnextformer_factory(
    dim, n_layers, n_convnext_per_transformer_block, is_causal=False
):
    layers = []
    for i in range(0, n_layers, n_convnext_per_transformer_block + 1):
        layers.append(
            ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
        )
        layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
    return nn.Sequential(*layers)


class AutoEncoder(nn.Module):
    def __init__(
        self,
        feat_dim: int,
        hidden_dim: int,
        num_layers: int,
        net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
        out_dim: int | None = None,
        convnextformer_num_conv_per_transformer: int = 3,
        causal_transformer: bool = False,
        bottleneck_size: int | None = None,
        vae: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()

        self.embed = nn.Linear(feat_dim, hidden_dim)
        if out_dim is None:
            out_dim = feat_dim
        self.unembed = nn.Linear(hidden_dim, out_dim)

        if net_factory == "convnext":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
        elif net_factory == "convnextformer_decoder":
            self.encoder_net = convnext_factory(
                hidden_dim, num_layers, is_causal=is_causal
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
        elif net_factory == "convnextformer":
            self.encoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )
            self.decoder_net = convnextformer_factory(
                hidden_dim,
                num_layers,
                convnextformer_num_conv_per_transformer,
                is_causal=is_causal,
            )

        self.bottleneck = (
            nn.Linear(hidden_dim, bottleneck_size * (1 + vae))
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.unbottleneck = (
            nn.Linear(bottleneck_size, hidden_dim)
            if bottleneck_size is not None
            else nn.Identity()
        )
        self.vae = vae

    def reparameterize(
        self,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        deterministic: bool = False,
        drop_vae_rate: float = 0.0,
    ) -> torch.Tensor:
logvar = torch.clamp(logvar, -30.0, 20.0)
|
| 335 |
+
std = torch.exp(0.5 * logvar)
|
| 336 |
+
if drop_vae_rate > 0.0:
|
| 337 |
+
to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
|
| 338 |
+
eps = torch.randn_like(std)
|
| 339 |
+
eps[to_drop] = 0.0
|
| 340 |
+
else:
|
| 341 |
+
if deterministic:
|
| 342 |
+
eps = torch.zeros_like(std)
|
| 343 |
+
else:
|
| 344 |
+
eps = torch.randn_like(std)
|
| 345 |
+
return mu + eps * std
|
| 346 |
+
|
| 347 |
+
def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
|
| 348 |
+
kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
| 349 |
+
return kl.sum(dim=-1).mean()
|
| 350 |
+
|
| 351 |
+
def forward(self, x: torch.Tensor, drop_vae_rate: float = 0.0) -> tuple[torch.Tensor, dict]:
|
| 352 |
+
# Encode
|
| 353 |
+
x = self.embed(x)
|
| 354 |
+
x = self.encoder_net(x)
|
| 355 |
+
x = self.bottleneck(x)
|
| 356 |
+
if self.vae:
|
| 357 |
+
mu, logvar = x.chunk(2, dim=-1)
|
| 358 |
+
loss = {
|
| 359 |
+
"kl_div": self.kl_divergence(mu, logvar),
|
| 360 |
+
"_mu_mean": mu.mean(),
|
| 361 |
+
"_mu_std": mu.std(),
|
| 362 |
+
"_logvar_mean": logvar.mean(),
|
| 363 |
+
"_logvar_std": logvar.std(),
|
| 364 |
+
}
|
| 365 |
+
x = self.reparameterize(
|
| 366 |
+
mu,
|
| 367 |
+
logvar,
|
| 368 |
+
drop_vae_rate=drop_vae_rate,
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
loss = {}
|
| 372 |
+
|
| 373 |
+
# Decode
|
| 374 |
+
x = self.unbottleneck(x)
|
| 375 |
+
x = self.decoder_net(x)
|
| 376 |
+
x = self.unembed(x)
|
| 377 |
+
|
| 378 |
+
return x, loss
|
| 379 |
+
|
| 380 |
+
def encode(self, x: torch.Tensor, deterministic: bool = False):
|
| 381 |
+
x = self.embed(x)
|
| 382 |
+
x = self.encoder_net(x)
|
| 383 |
+
x = self.bottleneck(x)
|
| 384 |
+
|
| 385 |
+
if self.vae:
|
| 386 |
+
x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
|
| 387 |
+
return x
|
| 388 |
+
|
| 389 |
+
def decode(
|
| 390 |
+
self,
|
| 391 |
+
latent: torch.Tensor | None = None,
|
| 392 |
+
):
|
| 393 |
+
x = self.unbottleneck(latent)
|
| 394 |
+
x = self.decoder_net(x)
|
| 395 |
+
x = self.unembed(x)
|
| 396 |
+
return x
|
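For orientation, a minimal usage sketch of the AutoEncoder defined above (the feature size, hidden size and bottleneck below are illustrative assumptions, not values taken from the training configs):

import torch

ae = AutoEncoder(
    feat_dim=80,            # per-frame feature size (assumed)
    hidden_dim=256,
    num_layers=4,
    net_factory="convnext",
    bottleneck_size=32,
    vae=True,
)
feats = torch.randn(2, 120, 80)            # (batch, frames, feat_dim)
recon, losses = ae(feats)                  # recon: (2, 120, 80); losses["kl_div"] is a scalar
z = ae.encode(feats, deterministic=True)   # (2, 120, 32) latent sequence
recon_from_z = ae.decode(z)                # (2, 120, 80)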
codec/models/wavvae/__init__.py
ADDED
|
File without changes
|
codec/models/wavvae/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (170 Bytes)
codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc
ADDED
Binary file (13.1 kB)
codec/models/wavvae/__pycache__/heads.cpython-312.pyc
ADDED
Binary file (10.4 kB)
codec/models/wavvae/__pycache__/layers.cpython-312.pyc
ADDED
Binary file (13.1 kB)
codec/models/wavvae/__pycache__/loss.cpython-312.pyc
ADDED
Binary file (7.01 kB)
codec/models/wavvae/__pycache__/model.cpython-312.pyc
ADDED
Binary file (8.79 kB)
codec/models/wavvae/__pycache__/modules.cpython-312.pyc
ADDED
Binary file (11.1 kB)
codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc
ADDED
Binary file (12.4 kB)
codec/models/wavvae/dataset.py
ADDED
|
@@ -0,0 +1,84 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torchaudio
|
| 6 |
+
from pytorch_lightning import LightningDataModule
|
| 7 |
+
from torch.utils.data import DataLoader, Dataset
|
| 8 |
+
|
| 9 |
+
torch.set_num_threads(1)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class WavVAEDataConfig:
|
| 14 |
+
filelist_path: str
|
| 15 |
+
sampling_rate: int
|
| 16 |
+
num_samples: int
|
| 17 |
+
batch_size: int
|
| 18 |
+
num_workers: int
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class WavVAEDataModule(LightningDataModule):
|
| 22 |
+
def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.train_config = train_params
|
| 25 |
+
self.val_config = val_params
|
| 26 |
+
|
| 27 |
+
def _get_dataloader(self, cfg: WavVAEDataConfig, train: bool):
|
| 28 |
+
dataset = WavVAEDataset(cfg, train=train)
|
| 29 |
+
dataloader = DataLoader(
|
| 30 |
+
dataset,
|
| 31 |
+
batch_size=cfg.batch_size,
|
| 32 |
+
num_workers=cfg.num_workers,
|
| 33 |
+
shuffle=train,
|
| 34 |
+
pin_memory=True,
|
| 35 |
+
)
|
| 36 |
+
return dataloader
|
| 37 |
+
|
| 38 |
+
def train_dataloader(self) -> DataLoader:
|
| 39 |
+
return self._get_dataloader(self.train_config, train=True)
|
| 40 |
+
|
| 41 |
+
def val_dataloader(self) -> DataLoader:
|
| 42 |
+
return self._get_dataloader(self.val_config, train=False)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class WavVAEDataset(Dataset):
|
| 46 |
+
def __init__(self, cfg: WavVAEDataConfig, train: bool):
|
| 47 |
+
with open(cfg.filelist_path) as f:
|
| 48 |
+
self.filelist = f.read().splitlines()
|
| 49 |
+
self.sampling_rate = cfg.sampling_rate
|
| 50 |
+
self.num_samples = cfg.num_samples
|
| 51 |
+
self.train = train
|
| 52 |
+
|
| 53 |
+
def __len__(self) -> int:
|
| 54 |
+
return len(self.filelist)
|
| 55 |
+
|
| 56 |
+
def __getitem__(self, index: int) -> torch.Tensor:
|
| 57 |
+
audio_path = self.filelist[index]
|
| 58 |
+
y, sr = torchaudio.load(audio_path)
|
| 59 |
+
if y.size(0) > 1:
|
| 60 |
+
# mix to mono
|
| 61 |
+
y = y.mean(dim=0, keepdim=True)
|
| 62 |
+
gain = np.random.uniform(-6, -1) if self.train else -3  # random peak level (dB) in training, fixed -3 dB for validation
|
| 63 |
+
y, _ = torchaudio.sox_effects.apply_effects_tensor(
|
| 64 |
+
y, sr, [["norm", f"{gain:.2f}"]]
|
| 65 |
+
)
|
| 66 |
+
try:
|
| 67 |
+
if sr != self.sampling_rate:
|
| 68 |
+
y = torchaudio.functional.resample(
|
| 69 |
+
y, orig_freq=sr, new_freq=self.sampling_rate
|
| 70 |
+
)
|
| 71 |
+
except Exception:
|
| 72 |
+
print(audio_path, y.shape)
|
| 73 |
+
if y.size(-1) < self.num_samples:
|
| 74 |
+
pad_length = self.num_samples - y.size(-1)
|
| 75 |
+
padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
|
| 76 |
+
y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
|
| 77 |
+
elif self.train:
|
| 78 |
+
start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
|
| 79 |
+
y = y[:, start : start + self.num_samples]
|
| 80 |
+
else:
|
| 81 |
+
# During validation, take always the first segment for determinism
|
| 82 |
+
y = y[:, : self.num_samples]
|
| 83 |
+
|
| 84 |
+
return y[0]
|
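A hedged wiring example for the data module above (file paths and sizes are placeholders):

train_cfg = WavVAEDataConfig(
    filelist_path="filelists/train.txt",   # one audio path per line (assumed layout)
    sampling_rate=24000,
    num_samples=2 * 24000,                 # 2-second crops
    batch_size=16,
    num_workers=4,
)
val_cfg = WavVAEDataConfig(
    filelist_path="filelists/val.txt",
    sampling_rate=24000,
    num_samples=2 * 24000,
    batch_size=16,
    num_workers=4,
)
dm = WavVAEDataModule(train_params=train_cfg, val_params=val_cfg)
batch = next(iter(dm.train_dataloader()))  # (batch_size, num_samples) mono waveforms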
codec/models/wavvae/discriminators.py
ADDED
|
@@ -0,0 +1,211 @@
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import Conv2d
|
| 7 |
+
from torch.nn.utils import weight_norm
|
| 8 |
+
from torchaudio.transforms import Spectrogram
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MultiPeriodDiscriminator(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
|
| 14 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
periods (tuple[int]): Tuple of periods for each discriminator.
|
| 18 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
| 19 |
+
Defaults to None.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
|
| 25 |
+
|
| 26 |
+
def forward(
|
| 27 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
|
| 28 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
| 29 |
+
y_d_rs = []
|
| 30 |
+
y_d_gs = []
|
| 31 |
+
fmap_rs = []
|
| 32 |
+
fmap_gs = []
|
| 33 |
+
for d in self.discriminators:
|
| 34 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
| 35 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
| 36 |
+
y_d_rs.append(y_d_r)
|
| 37 |
+
fmap_rs.append(fmap_r)
|
| 38 |
+
y_d_gs.append(y_d_g)
|
| 39 |
+
fmap_gs.append(fmap_g)
|
| 40 |
+
|
| 41 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DiscriminatorP(nn.Module):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
period: int,
|
| 48 |
+
in_channels: int = 1,
|
| 49 |
+
kernel_size: int = 5,
|
| 50 |
+
stride: int = 3,
|
| 51 |
+
lrelu_slope: float = 0.1,
|
| 52 |
+
num_embeddings: Optional[int] = None,
|
| 53 |
+
):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.period = period
|
| 56 |
+
self.convs = nn.ModuleList(
|
| 57 |
+
[
|
| 58 |
+
weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 59 |
+
weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 60 |
+
weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 61 |
+
weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 62 |
+
weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
|
| 63 |
+
]
|
| 64 |
+
)
|
| 65 |
+
if num_embeddings is not None:
|
| 66 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
|
| 67 |
+
torch.nn.init.zeros_(self.emb.weight)
|
| 68 |
+
|
| 69 |
+
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 70 |
+
self.lrelu_slope = lrelu_slope
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
| 74 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 75 |
+
x = x.unsqueeze(1)
|
| 76 |
+
fmap = []
|
| 77 |
+
# 1d to 2d
|
| 78 |
+
b, c, t = x.shape
|
| 79 |
+
if t % self.period != 0: # pad first
|
| 80 |
+
n_pad = self.period - (t % self.period)
|
| 81 |
+
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
|
| 82 |
+
t = t + n_pad
|
| 83 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 84 |
+
|
| 85 |
+
for i, l in enumerate(self.convs):
|
| 86 |
+
x = l(x)
|
| 87 |
+
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
|
| 88 |
+
if i > 0:
|
| 89 |
+
fmap.append(x)
|
| 90 |
+
if cond_embedding_id is not None:
|
| 91 |
+
emb = self.emb(cond_embedding_id)
|
| 92 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
| 93 |
+
else:
|
| 94 |
+
h = 0
|
| 95 |
+
x = self.conv_post(x)
|
| 96 |
+
fmap.append(x)
|
| 97 |
+
x += h
|
| 98 |
+
x = torch.flatten(x, 1, -1)
|
| 99 |
+
|
| 100 |
+
return x, fmap
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MultiResolutionDiscriminator(nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
| 107 |
+
num_embeddings: Optional[int] = None,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
| 111 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
| 115 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
| 116 |
+
Defaults to None.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.discriminators = nn.ModuleList(
|
| 121 |
+
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
| 126 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
| 127 |
+
y_d_rs = []
|
| 128 |
+
y_d_gs = []
|
| 129 |
+
fmap_rs = []
|
| 130 |
+
fmap_gs = []
|
| 131 |
+
|
| 132 |
+
for d in self.discriminators:
|
| 133 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
| 134 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
| 135 |
+
y_d_rs.append(y_d_r)
|
| 136 |
+
fmap_rs.append(fmap_r)
|
| 137 |
+
y_d_gs.append(y_d_g)
|
| 138 |
+
fmap_gs.append(fmap_g)
|
| 139 |
+
|
| 140 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class DiscriminatorR(nn.Module):
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
window_length: int,
|
| 147 |
+
num_embeddings: Optional[int] = None,
|
| 148 |
+
channels: int = 32,
|
| 149 |
+
hop_factor: float = 0.25,
|
| 150 |
+
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
| 151 |
+
):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.window_length = window_length
|
| 154 |
+
self.hop_factor = hop_factor
|
| 155 |
+
self.spec_fn = Spectrogram(
|
| 156 |
+
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
| 157 |
+
)
|
| 158 |
+
n_fft = window_length // 2 + 1
|
| 159 |
+
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
| 160 |
+
self.bands = bands
|
| 161 |
+
convs = lambda: nn.ModuleList(
|
| 162 |
+
[
|
| 163 |
+
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
| 164 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 165 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 166 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 167 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
| 171 |
+
|
| 172 |
+
if num_embeddings is not None:
|
| 173 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
| 174 |
+
torch.nn.init.zeros_(self.emb.weight)
|
| 175 |
+
|
| 176 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
| 177 |
+
|
| 178 |
+
def spectrogram(self, x):
|
| 179 |
+
# Remove DC offset
|
| 180 |
+
x = x - x.mean(dim=-1, keepdims=True)
|
| 181 |
+
# Peak normalize the volume of input audio
|
| 182 |
+
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
| 183 |
+
x = self.spec_fn(x)
|
| 184 |
+
x = torch.view_as_real(x)
|
| 185 |
+
x = rearrange(x, "b f t c -> b c t f")
|
| 186 |
+
# Split into bands
|
| 187 |
+
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
| 188 |
+
return x_bands
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
| 191 |
+
x_bands = self.spectrogram(x)
|
| 192 |
+
fmap = []
|
| 193 |
+
x = []
|
| 194 |
+
for band, stack in zip(x_bands, self.band_convs):
|
| 195 |
+
for i, layer in enumerate(stack):
|
| 196 |
+
band = layer(band)
|
| 197 |
+
band = torch.nn.functional.leaky_relu(band, 0.1)
|
| 198 |
+
if i > 0:
|
| 199 |
+
fmap.append(band)
|
| 200 |
+
x.append(band)
|
| 201 |
+
x = torch.cat(x, dim=-1)
|
| 202 |
+
if cond_embedding_id is not None:
|
| 203 |
+
emb = self.emb(cond_embedding_id)
|
| 204 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
| 205 |
+
else:
|
| 206 |
+
h = 0
|
| 207 |
+
x = self.conv_post(x)
|
| 208 |
+
fmap.append(x)
|
| 209 |
+
x += h
|
| 210 |
+
|
| 211 |
+
return x, fmap
|
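A brief sketch of how these discriminators are driven during adversarial training (waveform shapes are assumptions; the matching hinge-style losses are defined in loss.py further down):

import torch

mpd = MultiPeriodDiscriminator()
mrd = MultiResolutionDiscriminator()
y = torch.randn(4, 24000)        # real waveforms (B, T)
y_hat = torch.randn(4, 24000)    # generated waveforms (B, T)
real_scores, fake_scores, fmap_real, fmap_fake = mpd(y, y_hat)
# each element of real_scores / fake_scores is one sub-discriminator's flattened logits,
# and each element of fmap_real / fmap_fake is that sub-discriminator's intermediate feature maps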
codec/models/wavvae/experiment.py
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
codec/models/wavvae/heads.py
ADDED
|
@@ -0,0 +1,194 @@
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
| 7 |
+
|
| 8 |
+
from .modules import symexp
|
| 9 |
+
from .spectral_ops import IMDCT, ISTFT
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FourierHead(nn.Module):
|
| 13 |
+
"""Base class for inverse fourier modules."""
|
| 14 |
+
|
| 15 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 16 |
+
"""
|
| 17 |
+
Args:
|
| 18 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 19 |
+
L is the sequence length, and H denotes the model dimension.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 23 |
+
"""
|
| 24 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LinearNoBiasHead(FourierHead):
|
| 28 |
+
def __init__(self, dim: int, hop_length: int, n_fft: int):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.pre_head = nn.Linear(dim, n_fft + 2)
|
| 31 |
+
self.head = nn.Linear(n_fft + 2, hop_length, bias=False)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
y = self.pre_head(x)
|
| 35 |
+
y = self.head(y).clamp(min=-1.0, max=1.0)
|
| 36 |
+
B, _, _ = y.shape
|
| 37 |
+
return y.reshape(B, -1)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ISTFTHead(FourierHead):
|
| 41 |
+
"""
|
| 42 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
dim (int): Hidden dimension of the model.
|
| 46 |
+
n_fft (int): Size of Fourier transform.
|
| 47 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
| 48 |
+
the resolution of the input features.
|
| 49 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
| 53 |
+
super().__init__()
|
| 54 |
+
out_dim = n_fft + 2
|
| 55 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
| 56 |
+
self.hop_length = hop_length
|
| 57 |
+
self.istft = ISTFT(
|
| 58 |
+
n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
"""
|
| 63 |
+
Forward pass of the ISTFTHead module.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 67 |
+
L is the sequence length, and H denotes the model dimension.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 71 |
+
"""
|
| 72 |
+
x = self.out(x).transpose(1, 2)
|
| 73 |
+
mag, p = x.chunk(2, dim=1)
|
| 74 |
+
mag = torch.exp(mag)
|
| 75 |
+
mag = torch.clip(
|
| 76 |
+
mag, max=1e2
|
| 77 |
+
) # safeguard to prevent excessively large magnitudes
|
| 78 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
| 79 |
+
x = torch.cos(p)
|
| 80 |
+
y = torch.sin(p)
|
| 81 |
+
# recalculating phase here does not produce anything new
|
| 82 |
+
# only costs time
|
| 83 |
+
# phase = torch.atan2(y, x)
|
| 84 |
+
# S = mag * torch.exp(phase * 1j)
|
| 85 |
+
# better directly produce the complex value
|
| 86 |
+
S = mag * (x + 1j * y)
|
| 87 |
+
audio = self.istft(S)
|
| 88 |
+
audio = nn.functional.pad(audio, (self.hop_length // 2, self.hop_length // 2))
|
| 89 |
+
return audio
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class IMDCTSymExpHead(FourierHead):
|
| 93 |
+
"""
|
| 94 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
dim (int): Hidden dimension of the model.
|
| 98 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
| 99 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 100 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
| 101 |
+
based on perceptual scaling. Defaults to None.
|
| 102 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
dim: int,
|
| 108 |
+
mdct_frame_len: int,
|
| 109 |
+
padding: str = "same",
|
| 110 |
+
sample_rate: Optional[int] = None,
|
| 111 |
+
clip_audio: bool = False,
|
| 112 |
+
):
|
| 113 |
+
super().__init__()
|
| 114 |
+
out_dim = mdct_frame_len // 2
|
| 115 |
+
self.out = nn.Linear(dim, out_dim)
|
| 116 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
| 117 |
+
self.clip_audio = clip_audio
|
| 118 |
+
|
| 119 |
+
if sample_rate is not None:
|
| 120 |
+
# optionally init the last layer following mel-scale
|
| 121 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
| 122 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
| 123 |
+
f_pts = _mel_to_hz(m_pts)
|
| 124 |
+
scale = 1 - (f_pts / f_pts.max())
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
| 128 |
+
|
| 129 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 130 |
+
"""
|
| 131 |
+
Forward pass of the IMDCTSymExpHead module.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 135 |
+
L is the sequence length, and H denotes the model dimension.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 139 |
+
"""
|
| 140 |
+
x = self.out(x)
|
| 141 |
+
x = symexp(x)
|
| 142 |
+
x = torch.clip(
|
| 143 |
+
x, min=-1e2, max=1e2
|
| 144 |
+
) # safeguard to prevent excessively large magnitudes
|
| 145 |
+
audio = self.imdct(x)
|
| 146 |
+
if self.clip_audio:
|
| 147 |
+
audio = torch.clip(audio, min=-1.0, max=1.0)
|
| 148 |
+
|
| 149 |
+
return audio
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class IMDCTCosHead(FourierHead):
|
| 153 |
+
"""
|
| 154 |
+
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
dim (int): Hidden dimension of the model.
|
| 158 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
| 159 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 160 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(
|
| 164 |
+
self,
|
| 165 |
+
dim: int,
|
| 166 |
+
mdct_frame_len: int,
|
| 167 |
+
padding: str = "same",
|
| 168 |
+
clip_audio: bool = False,
|
| 169 |
+
):
|
| 170 |
+
super().__init__()
|
| 171 |
+
self.clip_audio = clip_audio
|
| 172 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
| 173 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
| 174 |
+
|
| 175 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 176 |
+
"""
|
| 177 |
+
Forward pass of the IMDCTCosHead module.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 181 |
+
L is the sequence length, and H denotes the model dimension.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 185 |
+
"""
|
| 186 |
+
x = self.out(x)
|
| 187 |
+
m, p = x.chunk(2, dim=2)
|
| 188 |
+
m = torch.exp(m).clip(
|
| 189 |
+
max=1e2
|
| 190 |
+
) # safeguard to prevent excessively large magnitudes
|
| 191 |
+
audio = self.imdct(m * torch.cos(p))
|
| 192 |
+
if self.clip_audio:
|
| 193 |
+
audio = torch.clip(audio, min=-1.0, max=1.0)
|
| 194 |
+
return audio
|
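All heads map per-frame hidden vectors to a waveform; a minimal shape sketch for ISTFTHead (sizes are illustrative):

import torch

head = ISTFTHead(dim=768, n_fft=1024, hop_length=256, padding="same")
feats = torch.randn(2, 100, 768)   # (B, frames, hidden)
audio = head(feats)                # (B, ~frames * hop_length) time-domain signal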
codec/models/wavvae/helpers.py
ADDED
|
@@ -0,0 +1,71 @@
| 1 |
+
import matplotlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from matplotlib import pyplot as plt
|
| 5 |
+
from pytorch_lightning import Callback
|
| 6 |
+
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
|
| 11 |
+
"""
|
| 12 |
+
Save a matplotlib figure to a numpy array.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
fig (Figure): Matplotlib figure object.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
ndarray: Numpy array representing the figure.
|
| 19 |
+
"""
|
| 20 |
+
data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
|
| 21 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
|
| 22 |
+
return data
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
|
| 26 |
+
"""
|
| 27 |
+
Plot a spectrogram and convert it to a numpy array.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
spectrogram (ndarray): Spectrogram data.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
ndarray: Numpy array representing the plotted spectrogram.
|
| 34 |
+
"""
|
| 35 |
+
spectrogram = spectrogram.astype(np.float32)
|
| 36 |
+
fig, ax = plt.subplots(figsize=(12, 3))
|
| 37 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
| 38 |
+
plt.colorbar(im, ax=ax)
|
| 39 |
+
plt.xlabel("Frames")
|
| 40 |
+
plt.ylabel("Channels")
|
| 41 |
+
plt.tight_layout()
|
| 42 |
+
|
| 43 |
+
fig.canvas.draw()
|
| 44 |
+
data = save_figure_to_numpy(fig)
|
| 45 |
+
plt.close()
|
| 46 |
+
return data
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GradNormCallback(Callback):
|
| 50 |
+
"""
|
| 51 |
+
Callback to log the gradient norm.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def on_after_backward(self, trainer, model):
|
| 55 |
+
model.log("grad_norm", gradient_norm(model))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Compute the gradient norm.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
model (Module): PyTorch model.
|
| 64 |
+
norm_type (float, optional): Type of the norm. Defaults to 2.0.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Tensor: Gradient norm.
|
| 68 |
+
"""
|
| 69 |
+
grads = [p.grad for p in model.parameters() if p.grad is not None]
|
| 70 |
+
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
|
| 71 |
+
return total_norm
|
codec/models/wavvae/layers.py
ADDED
|
@@ -0,0 +1,282 @@
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 6 |
+
|
| 7 |
+
class VocosDecoder(nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
dim: int,
|
| 11 |
+
intermediate_dim: int,
|
| 12 |
+
num_layers: int,
|
| 13 |
+
):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 16 |
+
self.convnext = nn.ModuleList(
|
| 17 |
+
[
|
| 18 |
+
ConvNeXtBlock(
|
| 19 |
+
dim=dim,
|
| 20 |
+
intermediate_dim=intermediate_dim,
|
| 21 |
+
layer_scale_init_value=0.0,
|
| 22 |
+
)
|
| 23 |
+
for _ in range(num_layers)
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
| 27 |
+
|
| 28 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
x = self.norm(x)
|
| 30 |
+
x = x.transpose(1, 2)
|
| 31 |
+
for conv_block in self.convnext:
|
| 32 |
+
x = conv_block(x)
|
| 33 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
| 34 |
+
return x
|
| 35 |
+
|
| 36 |
+
class ConvNeXtBlock(nn.Module):
|
| 37 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
dim (int): Number of input channels.
|
| 41 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
| 42 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 43 |
+
Defaults to None.
|
| 44 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 45 |
+
None means non-conditional LayerNorm. Defaults to None.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
dim: int,
|
| 51 |
+
intermediate_dim: int | None = None,
|
| 52 |
+
layer_scale_init_value: float = 0.0,
|
| 53 |
+
elementwise_affine_ln: bool = True,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
|
| 57 |
+
self.dwconv = nn.Conv1d(
|
| 58 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
| 59 |
+
) # depthwise conv
|
| 60 |
+
self.norm = nn.LayerNorm(
|
| 61 |
+
dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
|
| 62 |
+
)
|
| 63 |
+
self.pwconv1 = nn.Linear(
|
| 64 |
+
dim, intermediate_dim
|
| 65 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 66 |
+
self.act = nn.GELU()
|
| 67 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
| 68 |
+
self.gamma = (
|
| 69 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
| 70 |
+
if layer_scale_init_value > 0
|
| 71 |
+
else None
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def forward(
|
| 75 |
+
self,
|
| 76 |
+
x: torch.Tensor,
|
| 77 |
+
scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 78 |
+
gate: torch.Tensor | None = None,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
residual = x
|
| 81 |
+
x = self.dwconv(x)
|
| 82 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
| 83 |
+
x = self.norm(x)
|
| 84 |
+
if scale_shift is not None:
|
| 85 |
+
scale, shift = scale_shift
|
| 86 |
+
x = x * scale[:, None] + shift[:, None]
|
| 87 |
+
x = self.pwconv1(x)
|
| 88 |
+
x = self.act(x)
|
| 89 |
+
x = self.pwconv2(x)
|
| 90 |
+
if self.gamma is not None:
|
| 91 |
+
x = self.gamma * x
|
| 92 |
+
if gate is not None:
|
| 93 |
+
x = gate[:, None] * x
|
| 94 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
| 95 |
+
|
| 96 |
+
x = residual + x
|
| 97 |
+
return x
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class Encoder(nn.Module):
|
| 101 |
+
def __init__(
|
| 102 |
+
self,
|
| 103 |
+
d_model=32,
|
| 104 |
+
strides=[2, 4, 4, 8],
|
| 105 |
+
depthwise=False,
|
| 106 |
+
):
|
| 107 |
+
super().__init__()
|
| 108 |
+
layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
| 109 |
+
for stride in strides:
|
| 110 |
+
d_model *= 2
|
| 111 |
+
groups = d_model // 2 if depthwise else 1
|
| 112 |
+
layers += [EncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
| 113 |
+
groups = d_model if depthwise else 1
|
| 114 |
+
layers += [
|
| 115 |
+
WNConv1d(d_model, d_model, kernel_size=7, padding=3, groups=groups),
|
| 116 |
+
]
|
| 117 |
+
self.block = nn.Sequential(*layers)
|
| 118 |
+
|
| 119 |
+
def forward(self, x):
|
| 120 |
+
return self.block(x)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Decoder(nn.Module):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
input_channel,
|
| 127 |
+
channels,
|
| 128 |
+
rates,
|
| 129 |
+
noise=False,
|
| 130 |
+
depthwise=False,
|
| 131 |
+
d_out=1,
|
| 132 |
+
):
|
| 133 |
+
super().__init__()
|
| 134 |
+
if depthwise:
|
| 135 |
+
layers = [
|
| 136 |
+
WNConv1d(
|
| 137 |
+
input_channel,
|
| 138 |
+
input_channel,
|
| 139 |
+
kernel_size=7,
|
| 140 |
+
padding=3,
|
| 141 |
+
groups=input_channel,
|
| 142 |
+
),
|
| 143 |
+
WNConv1d(input_channel, channels, kernel_size=1),
|
| 144 |
+
]
|
| 145 |
+
else:
|
| 146 |
+
layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
| 147 |
+
|
| 148 |
+
for i, stride in enumerate(rates):
|
| 149 |
+
input_dim = channels // 2**i
|
| 150 |
+
output_dim = channels // 2 ** (i + 1)
|
| 151 |
+
groups = output_dim if depthwise else 1
|
| 152 |
+
layers.append(
|
| 153 |
+
DecoderBlock(input_dim, output_dim, stride, noise, groups=groups)
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
layers += [
|
| 157 |
+
Snake1d(output_dim),
|
| 158 |
+
WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
| 159 |
+
nn.Tanh(),
|
| 160 |
+
]
|
| 161 |
+
self.model = nn.Sequential(*layers)
|
| 162 |
+
|
| 163 |
+
def forward(self, x):
|
| 164 |
+
x = self.model(x)
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ResidualUnit(nn.Module):
|
| 169 |
+
def __init__(self, dim=16, dilation=1, kernel=7, groups=1):
|
| 170 |
+
super().__init__()
|
| 171 |
+
pad = ((kernel - 1) * dilation) // 2
|
| 172 |
+
self.block = nn.Sequential(
|
| 173 |
+
Snake1d(dim),
|
| 174 |
+
WNConv1d(
|
| 175 |
+
dim,
|
| 176 |
+
dim,
|
| 177 |
+
kernel_size=kernel,
|
| 178 |
+
dilation=dilation,
|
| 179 |
+
padding=pad,
|
| 180 |
+
groups=groups,
|
| 181 |
+
),
|
| 182 |
+
Snake1d(dim),
|
| 183 |
+
WNConv1d(dim, dim, kernel_size=1),
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def forward(self, x):
|
| 187 |
+
y = self.block(x)
|
| 188 |
+
pad = (x.shape[-1] - y.shape[-1]) // 2
|
| 189 |
+
if pad > 0:
|
| 190 |
+
x = x[..., pad:-pad]
|
| 191 |
+
return x + y
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class EncoderBlock(nn.Module):
|
| 195 |
+
def __init__(self, output_dim=16, input_dim=None, stride=1, groups=1):
|
| 196 |
+
super().__init__()
|
| 197 |
+
input_dim = input_dim or output_dim // 2
|
| 198 |
+
self.block = nn.Sequential(
|
| 199 |
+
ResidualUnit(input_dim, dilation=1, groups=groups),
|
| 200 |
+
ResidualUnit(input_dim, dilation=3, groups=groups),
|
| 201 |
+
ResidualUnit(input_dim, dilation=9, groups=groups),
|
| 202 |
+
Snake1d(input_dim),
|
| 203 |
+
WNConv1d(
|
| 204 |
+
input_dim,
|
| 205 |
+
output_dim,
|
| 206 |
+
kernel_size=2 * stride,
|
| 207 |
+
stride=stride,
|
| 208 |
+
padding=math.ceil(stride / 2),
|
| 209 |
+
),
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
|
| 213 |
+
return self.block(x)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class NoiseBlock(nn.Module):
|
| 217 |
+
def __init__(self, dim):
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.linear = WNConv1d(dim, dim, kernel_size=1, bias=False)
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
B, C, T = x.shape
|
| 223 |
+
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
| 224 |
+
h = self.linear(x)
|
| 225 |
+
n = noise * h
|
| 226 |
+
x = x + n
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class DecoderBlock(nn.Module):
|
| 231 |
+
def __init__(self, input_dim=16, output_dim=8, stride=1, noise=False, groups=1):
|
| 232 |
+
super().__init__()
|
| 233 |
+
layers = [
|
| 234 |
+
Snake1d(input_dim),
|
| 235 |
+
WNConvTranspose1d(
|
| 236 |
+
input_dim,
|
| 237 |
+
output_dim,
|
| 238 |
+
kernel_size=2 * stride,
|
| 239 |
+
stride=stride,
|
| 240 |
+
padding=math.ceil(stride / 2),
|
| 241 |
+
output_padding=stride % 2,
|
| 242 |
+
),
|
| 243 |
+
]
|
| 244 |
+
if noise:
|
| 245 |
+
layers.append(NoiseBlock(output_dim))
|
| 246 |
+
layers.extend(
|
| 247 |
+
[
|
| 248 |
+
ResidualUnit(output_dim, dilation=1, groups=groups),
|
| 249 |
+
ResidualUnit(output_dim, dilation=3, groups=groups),
|
| 250 |
+
ResidualUnit(output_dim, dilation=9, groups=groups),
|
| 251 |
+
]
|
| 252 |
+
)
|
| 253 |
+
self.block = nn.Sequential(*layers)
|
| 254 |
+
|
| 255 |
+
def forward(self, x):
|
| 256 |
+
return self.block(x)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def WNConv1d(*args, **kwargs):
|
| 260 |
+
return weight_norm(nn.Conv1d(*args, **kwargs))
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def WNConvTranspose1d(*args, **kwargs):
|
| 264 |
+
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
@torch.jit.script
|
| 268 |
+
def snake(x, alpha):
|
| 269 |
+
shape = x.shape
|
| 270 |
+
x = x.reshape(shape[0], shape[1], -1)
|
| 271 |
+
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
| 272 |
+
x = x.reshape(shape)
|
| 273 |
+
return x
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class Snake1d(nn.Module):
|
| 277 |
+
def __init__(self, channels):
|
| 278 |
+
super().__init__()
|
| 279 |
+
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
| 280 |
+
|
| 281 |
+
def forward(self, x):
|
| 282 |
+
return snake(x, self.alpha)
|
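Note on the DAC-style Encoder above: each EncoderBlock doubles the channel count and downsamples by its stride, so the overall hop is the product of the strides. A quick shape sketch (input length is illustrative):

import torch

enc = Encoder(d_model=32, strides=[2, 4, 4, 8], depthwise=True)
wav = torch.randn(1, 1, 24000)    # (B, 1, T) mono audio
feats = enc(wav)                  # (1, 32 * 2**4, ~T / 256) = (1, 512, ~T / 256) latent frames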
codec/models/wavvae/loss.py
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from .modules import safe_log
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MelSpecReconstructionLoss(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
sample_rate: int = 24000,
|
| 18 |
+
n_fft: int | None = None,
|
| 19 |
+
hop_length: int = 256,
|
| 20 |
+
n_mels: int = 100,
|
| 21 |
+
f_min: int = 0,
|
| 22 |
+
f_max: Optional[int] = None,
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.mel_spec = torchaudio.transforms.MelSpectrogram(
|
| 26 |
+
sample_rate=sample_rate,
|
| 27 |
+
n_fft=hop_length * 4 if n_fft is None else n_fft,
|
| 28 |
+
hop_length=hop_length,
|
| 29 |
+
n_mels=n_mels,
|
| 30 |
+
center=True,
|
| 31 |
+
power=1,
|
| 32 |
+
f_min=f_min,
|
| 33 |
+
f_max=f_max,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
def forward(self, y_hat, y) -> torch.Tensor:
|
| 37 |
+
"""
|
| 38 |
+
Args:
|
| 39 |
+
y_hat (Tensor): Predicted audio waveform.
|
| 40 |
+
y (Tensor): Ground truth audio waveform.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
|
| 44 |
+
"""
|
| 45 |
+
# B, C, Th = y_hat.shape
|
| 46 |
+
# B, C, T = y.shape
|
| 47 |
+
# crop = (Th - T) // 2
|
| 48 |
+
mel_hat = safe_log(self.mel_spec(y_hat))
|
| 49 |
+
# mel_hat = safe_log(self.mel_spec(y_hat[..., crop:-crop]))
|
| 50 |
+
# mel = safe_log(self.mel_spec(y[..., crop:-crop]))
|
| 51 |
+
mel = safe_log(self.mel_spec(y))
|
| 52 |
+
|
| 53 |
+
loss = torch.nn.functional.l1_loss(mel, mel_hat)
|
| 54 |
+
|
| 55 |
+
return loss
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class GeneratorLoss(nn.Module):
|
| 59 |
+
"""
|
| 60 |
+
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def forward(
|
| 64 |
+
self, disc_outputs: List[torch.Tensor]
|
| 65 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 66 |
+
"""
|
| 67 |
+
Args:
|
| 68 |
+
disc_outputs (List[Tensor]): List of discriminator outputs.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
|
| 72 |
+
the sub-discriminators
|
| 73 |
+
"""
|
| 74 |
+
loss = torch.zeros(
|
| 75 |
+
1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
|
| 76 |
+
)
|
| 77 |
+
gen_losses = []
|
| 78 |
+
for dg in disc_outputs:
|
| 79 |
+
l = torch.mean(torch.clamp(1 - dg, min=0))
|
| 80 |
+
gen_losses.append(l)
|
| 81 |
+
loss += l
|
| 82 |
+
|
| 83 |
+
return loss, gen_losses
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class DiscriminatorLoss(nn.Module):
|
| 87 |
+
"""
|
| 88 |
+
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def forward(
|
| 92 |
+
self,
|
| 93 |
+
disc_real_outputs: List[torch.Tensor],
|
| 94 |
+
disc_generated_outputs: List[torch.Tensor],
|
| 95 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
| 96 |
+
"""
|
| 97 |
+
Args:
|
| 98 |
+
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
|
| 99 |
+
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
|
| 103 |
+
the sub-discriminators for real outputs, and a list of
|
| 104 |
+
loss values for generated outputs.
|
| 105 |
+
"""
|
| 106 |
+
loss = torch.zeros(
|
| 107 |
+
1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
|
| 108 |
+
)
|
| 109 |
+
r_losses = []
|
| 110 |
+
g_losses = []
|
| 111 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
| 112 |
+
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
| 113 |
+
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
| 114 |
+
loss += r_loss + g_loss
|
| 115 |
+
r_losses.append(r_loss)
|
| 116 |
+
g_losses.append(g_loss)
|
| 117 |
+
|
| 118 |
+
return loss, r_losses, g_losses
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FeatureMatchingLoss(nn.Module):
|
| 122 |
+
"""
|
| 123 |
+
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def forward(
|
| 127 |
+
self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
| 128 |
+
) -> torch.Tensor:
|
| 129 |
+
"""
|
| 130 |
+
Args:
|
| 131 |
+
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
|
| 132 |
+
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
Tensor: The calculated feature matching loss.
|
| 136 |
+
"""
|
| 137 |
+
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
|
| 138 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
| 139 |
+
for rl, gl in zip(dr, dg):
|
| 140 |
+
loss += torch.mean(torch.abs(rl - gl))
|
| 141 |
+
|
| 142 |
+
return loss
|
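A hedged sketch of how these losses combine in one GAN training step (the mel-loss weight and shapes are assumptions, not the project's configuration; MultiPeriodDiscriminator comes from discriminators.py above):

import torch

mpd = MultiPeriodDiscriminator()
mel_loss_fn = MelSpecReconstructionLoss(sample_rate=24000, hop_length=256)
gen_loss_fn, disc_loss_fn, fm_loss_fn = GeneratorLoss(), DiscriminatorLoss(), FeatureMatchingLoss()

y = torch.randn(4, 24000)        # real audio (B, T)
y_hat = torch.randn(4, 24000)    # generator output (B, T)

# discriminator step: hinge loss on real vs. detached fake scores
real_s, fake_s, _, _ = mpd(y, y_hat.detach())
d_loss, _, _ = disc_loss_fn(real_s, fake_s)

# generator step: adversarial + feature-matching + mel reconstruction terms
real_s, fake_s, fmap_r, fmap_g = mpd(y, y_hat)
g_adv, _ = gen_loss_fn(fake_s)
g_loss = g_adv + fm_loss_fn(fmap_r, fmap_g) + 45.0 * mel_loss_fn(y_hat, y)  # 45.0 is an assumed weight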
codec/models/wavvae/model.py
ADDED
|
@@ -0,0 +1,140 @@
| 1 |
+
import json
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from safetensors.torch import load_file
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from .heads import ISTFTHead, LinearNoBiasHead
|
| 11 |
+
from .layers import Encoder, VocosDecoder
|
| 12 |
+
from .modules import ConvNeXtBlock
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class WavVAEConfig:
|
| 17 |
+
conv_dim: int = 48
|
| 18 |
+
latent_dim: int = 32
|
| 19 |
+
decoder_hidden_dim: int = 768
|
| 20 |
+
decoder_intermediate_dim: int = 1536
|
| 21 |
+
decoder_num_layers: int = 8
|
| 22 |
+
n_fft: int = 1024
|
| 23 |
+
hop_length: int = 256
|
| 24 |
+
padding: str = "center"
|
| 25 |
+
head_type: Literal["istft", "linear"] = "istft"
|
| 26 |
+
strides: list[int] = field(default_factory=lambda: [2, 4, 4, 8])
|
| 27 |
+
learnable_pre_norm: bool = False
|
| 28 |
+
sampling_rate: int = 24000
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class WavVAE(nn.Module):
|
| 32 |
+
def __init__(self, cfg: WavVAEConfig):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.conv_encoder = Encoder(cfg.conv_dim, strides=cfg.strides, depthwise=True)
|
| 35 |
+
conv_final_dim = cfg.conv_dim * 2 ** len(cfg.strides)
|
| 36 |
+
self.bottleneck = nn.Linear(conv_final_dim, cfg.latent_dim * 2)
|
| 37 |
+
self.unbottleneck = nn.Linear(cfg.latent_dim, cfg.decoder_hidden_dim)
|
| 38 |
+
self.latent_norm = nn.LayerNorm(conv_final_dim)
|
| 39 |
+
self.vocos_decoder = VocosDecoder(
|
| 40 |
+
cfg.decoder_hidden_dim,
|
| 41 |
+
cfg.decoder_intermediate_dim,
|
| 42 |
+
cfg.decoder_num_layers,
|
| 43 |
+
)
|
| 44 |
+
if cfg.head_type == "istft":
|
| 45 |
+
self.head = ISTFTHead(
|
| 46 |
+
cfg.decoder_hidden_dim,
|
| 47 |
+
cfg.n_fft,
|
| 48 |
+
cfg.hop_length,
|
| 49 |
+
padding=cfg.padding,
|
| 50 |
+
)
|
| 51 |
+
elif cfg.head_type == "linear":
|
| 52 |
+
self.head = LinearNoBiasHead(
|
| 53 |
+
cfg.decoder_hidden_dim,
|
| 54 |
+
cfg.hop_length,
|
| 55 |
+
cfg.n_fft,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
self._sampling_rate = cfg.sampling_rate
|
| 59 |
+
self._strides = cfg.strides
|
| 60 |
+
self.apply(self._init_weights)
|
| 61 |
+
|
| 62 |
+
def _init_weights(self, m):
|
| 63 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 64 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 65 |
+
if m.bias is not None:
|
| 66 |
+
nn.init.constant_(m.bias, 0)
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def sampling_rate(self) -> int:
|
| 70 |
+
return self._sampling_rate
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def hop_length(self) -> int:
|
| 74 |
+
hop_length = 1
|
| 75 |
+
for s in self._strides:
|
| 76 |
+
hop_length *= s
|
| 77 |
+
return hop_length
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def frame_rate(self) -> float:
|
| 81 |
+
return self.sampling_rate / self.hop_length
|
| 82 |
+
|
| 83 |
+
@classmethod
|
| 84 |
+
def from_pretrained(
|
| 85 |
+
cls,
|
| 86 |
+
pretrained_model_name_or_path: str,
|
| 87 |
+
device: str = "cpu",
|
| 88 |
+
):
|
| 89 |
+
if Path(pretrained_model_name_or_path).exists():
|
| 90 |
+
path = pretrained_model_name_or_path
|
| 91 |
+
else:
|
| 92 |
+
from huggingface_hub import snapshot_download
|
| 93 |
+
|
| 94 |
+
path = snapshot_download(pretrained_model_name_or_path)
|
| 95 |
+
|
| 96 |
+
with open(Path(path) / "config.json", "r") as f:
|
| 97 |
+
config = json.load(f)
|
| 98 |
+
config = WavVAEConfig(**config)
|
| 99 |
+
model = cls(config)
|
| 100 |
+
state_dict = load_file(
|
| 101 |
+
Path(path) / "model.st",
|
| 102 |
+
device=device,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
model.load_state_dict(state_dict, assign=True)
|
| 106 |
+
return model
|
| 107 |
+
|
| 108 |
+
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
|
| 109 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
| 110 |
+
std = torch.exp(0.5 * logvar)
|
| 111 |
+
eps = torch.randn_like(std)
|
| 112 |
+
return mu + eps * std
|
| 113 |
+
|
| 114 |
+
def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
| 116 |
+
return kl.sum(dim=-1).mean()
|
| 117 |
+
|
| 118 |
+
def encode(self, audio: torch.Tensor) -> torch.Tensor:
|
| 119 |
+
y = self.conv_encoder(audio.unsqueeze(1)).transpose(1, 2)
|
| 120 |
+
y = self.latent_norm(y)
|
| 121 |
+
mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
|
| 122 |
+
z = self.reparameterize(mu, logvar)
|
| 123 |
+
return z
|
| 124 |
+
|
| 125 |
+
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
y = self.unbottleneck(z)
|
| 127 |
+
y = self.vocos_decoder(y)
|
| 128 |
+
return self.head(y)
|
| 129 |
+
|
| 130 |
+
def forward(self, audio_input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 131 |
+
y = self.conv_encoder(audio_input.unsqueeze(1)).transpose(1, 2)
|
| 132 |
+
y = self.latent_norm(y)
|
| 133 |
+
mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
|
| 134 |
+
kl_div = self.kl_divergence(mu, logvar)
|
| 135 |
+
z = self.reparameterize(mu, logvar)
|
| 136 |
+
y = self.unbottleneck(z)
|
| 137 |
+
y = self.vocos_decoder(y)
|
| 138 |
+
audio_output = self.head(y)
|
| 139 |
+
|
| 140 |
+
return audio_output, kl_div
|
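Putting the pieces together, a hedged round-trip sketch for the WavVAE (the checkpoint path is a placeholder; shapes assume the default config):

import torch

model = WavVAE(WavVAEConfig())            # or WavVAE.from_pretrained("<path-or-hub-id>")
audio = torch.randn(1, 24000)             # (B, T) at cfg.sampling_rate
z = model.encode(audio)                   # (B, ~T / 256, latent_dim) with the default strides
recon = model.decode(z)                   # (B, ~T) waveform
recon2, kl_div = model(audio)             # forward also returns the KL divergence term
print(model.frame_rate)                   # sampling_rate / prod(strides) = 93.75 Hz by default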
codec/models/wavvae/modules.py
ADDED
|
@@ -0,0 +1,213 @@
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: float,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: Optional[float] = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
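Note: modules.py provides the encoder/decoder building blocks used by the WavVAE above. A small shape sketch (the dimensions are illustrative, and the import path assumes this repository layout):

import torch
from codec.models.wavvae.modules import ConvNeXtBlock, ResBlock1, symlog, symexp

x = torch.randn(2, 256, 100)                                    # (B, C, T)
block = ConvNeXtBlock(dim=256, intermediate_dim=768, layer_scale_init_value=1e-6)
y = block(x)                                                    # residual ConvNeXt update, same (B, C, T) shape
r = ResBlock1(dim=256)(x)                                       # dilated HiFi-GAN-style residual stack, same shape
t = torch.randn(5)
assert torch.allclose(symexp(symlog(t)), t, atol=1e-6)          # symlog and symexp are inverses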
codec/models/wavvae/spectral_ops.py
ADDED
|
@@ -0,0 +1,192 @@
import numpy as np
import scipy
import torch
from torch import nn, view_as_real, view_as_complex


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y


class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)


class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
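Note: ISTFT above exists to support "same" padding in neural vocoding, while MDCT/IMDCT form an analysis/synthesis pair over half-overlapping frames. A rough round-trip sketch (frame length and signal length are arbitrary; the import path assumes this repository layout):

import torch
from codec.models.wavvae.spectral_ops import MDCT, IMDCT, ISTFT

mdct, imdct = MDCT(frame_len=512), IMDCT(frame_len=512)
audio = torch.randn(1, 4096)          # (B, T)
coeffs = mdct(audio)                  # (B, L, N) = (1, 16, 256) with "same" padding
audio_hat = imdct(coeffs)             # overlap-add back to (1, 4096)

istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="center")
spec = torch.stft(audio, 1024, 256, 1024, torch.hann_window(1024), return_complex=True)
recon = istft(spec)                   # "center" padding falls back to torch.istft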
codec/scripts/compare_codecs.py
ADDED
|
@@ -0,0 +1,441 @@
#!/usr/bin/env python3
import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple

import torch
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample
import torchaudio

# Your libs
from zcodec.models import WavVAE, ZFlowAutoEncoder


# -------------------------
# Data structures
# -------------------------


@dataclass
class DecodeParams:
    num_steps: int = 10
    cfg: float = 2.0


@dataclass
class ModelPairSpec:
    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams


# -------------------------
# Utilities
# -------------------------


def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    if path.is_file():
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    """
    Try to read config.json (or a few common fallbacks) from a checkpoint dir.
    Returns {} if nothing could be parsed.
    """
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir)
        / "config.yaml",  # won't parse yaml here, we only display path
        Path(checkpoint_dir) / "model_config.json",
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                # For YAML or unknown, just show filename rather than failing
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    """
    wav: (channels, samples)
    returns mono float32 in [-1,1], resampled to target_sr
    """
    if wav.ndim != 2:
        raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
    # to mono
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    # resample if needed
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    # (C, T)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # Clamp to [-1,1]
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


# -------------------------
# Core inference
# -------------------------


@torch.inference_mode()
def reconstruct_full_pipeline(
    wav_mono: torch.Tensor,
    sr: int,
    wavvae: WavVAE,
    zflowae: ZFlowAutoEncoder,
    decode_params: DecodeParams,
    device: str,
) -> torch.Tensor:
    """
    Full path: audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat
    """
    wav_mono = wav_mono.to(device)
    # WavVAE expects (B, C, T); assume C=1
    x = wav_mono.unsqueeze(0)  # (1, 1, T)
    # Encode to high-framerate latents
    z = wavvae.encode(x)
    # Compress latents
    y = zflowae.encode(z)
    # Decompress
    z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg)
    # Decode to waveform
    wav_hat = wavvae.decode(z_hat)  # (1, 1, T)
    # Return mono 1D
    return wav_hat.squeeze(0).squeeze(0).detach()


def load_model_pair(spec: ModelPairSpec, device: str):
    wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device)
    zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device)
    # try to get sampling rate from WavVAE
    target_sr = getattr(wavvae, "sampling_rate", None)
    if target_sr is None:
        # reasonable fallback
        target_sr = 24000
    return wavvae, zflowae, int(target_sr)


def parse_manifest(path: str) -> List[ModelPairSpec]:
    """
    Manifest format (JSON list):
    [
      {
        "name": "zdim32x8",
        "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
        "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    ]
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    out: List[ModelPairSpec] = []
    for item in raw:
        name = item["name"]
        wavvae_dir = item["wavvae"]
        zflowae_dir = item["zflowae"]
        d = item.get("decode", {})
        out.append(
            ModelPairSpec(
                name=name,
                wavvae_dir=wavvae_dir,
                zflowae_dir=zflowae_dir,
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)),
                    cfg=float(d.get("cfg", 2.0)),
                ),
            )
        )
    return out


# -------------------------
# HTML generation
# -------------------------


def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    models: List[ModelPairSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    """
    Build a static HTML page with a table:
      Row = input audio file
      Col 1 = Original
      Col 2..N = each model reconstruction
    Also shows minimal model config info above the table.
    """

    def player(src_rel: str, controls: bool = True) -> str:
        return f'<audio {"controls" if controls else ""} preload="none" src="{html_escape(src_rel)}"></audio>'

    # Model cards
    model_cards = []
    for spec in models:
        wcfg = wavvae_cfg.get(spec.name, {})
        zcfg = zflow_cfg.get(spec.name, {})
        w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
        <div class="model-card">
          <h3>{html_escape(spec.name)}</h3>
          <p><b>Sample rate</b>: {sr_by_model.get(spec.name, "N/A")} Hz</p>
          <details>
            <summary>WavVAE config</summary>
            <pre>{html_escape(w_short)}</pre>
          </details>
          <details>
            <summary>ZFlowAE config</summary>
            <pre>{html_escape(z_short)}</pre>
          </details>
          <p><b>Decode</b>: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</p>
        </div>
        """
        model_cards.append(card)

    # Table header
    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(m.name)}</th>" for m in models
    )

    # Rows
    rows = []
    for af in audio_files:
        base = af.stem
        orig_rel = f"original/{html_escape(af.name)}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for m in models:
            rec_rel = f"recon/{html_escape(m.name)}/{html_escape(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")

    # Simple CSS to keep it clean
    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    h1 { margin-bottom: 0.2rem; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    html = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8"/>
<title>Codec Comparison</title>
<style>{css}</style>
</head>
<body>
<h1>Codec Comparison</h1>
<p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>

<h2>Models</h2>
<div class="cards">
{"".join(model_cards)}
</div>

<h2>Audio</h2>
<table>
<thead><tr>{th}</tr></thead>
<tbody>
{"".join(rows)}
</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(html, encoding="utf-8")
    return str(out)


# -------------------------
# Main
# -------------------------


def main():
    p = argparse.ArgumentParser(
        description="Compare Z-Codec configurations and generate a static HTML page."
    )
    p.add_argument(
        "--manifest",
        type=str,
        required=True,
        help="JSON file listing model pairs. See docstring in parse_manifest().",
    )
    p.add_argument(
        "--audio", type=str, nargs="+", required=True, help="List of input audio files."
    )
    p.add_argument(
        "--out",
        type=str,
        default="codec_compare_out",
        help="Output directory for reconstructions and HTML.",
    )
    p.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run inference on (cuda or cpu).",
    )
    p.add_argument(
        "--force",
        action="store_true",
        help="Recompute even if target wav already exists.",
    )
    args = p.parse_args()

    device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
    out_dir = Path(args.out)
    orig_dir = out_dir / "original"
    recon_dir = out_dir / "recon"
    orig_dir.mkdir(parents=True, exist_ok=True)
    recon_dir.mkdir(parents=True, exist_ok=True)

    # Parse models
    specs = parse_manifest(args.manifest)
    if not specs:
        print("No models in manifest.", file=sys.stderr)
        sys.exit(1)

    # Load models
    loaded: Dict[str, Dict[str, Any]] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    zflow_cfg: Dict[str, Dict[str, Any]] = {}

    for spec in specs:
        print(f"[Load] {spec.name}")
        wavvae, zflowae, target_sr = load_model_pair(spec, device)
        loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr}
        sr_by_model[spec.name] = target_sr
        wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir)
        zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir)

    # Process audio files
    audio_files = [Path(a) for a in args.audio]
    for af in audio_files:
        if not af.exists():
            print(f"[Skip] Missing: {af}", file=sys.stderr)
            continue

        # copy original (resampled per model? We'll store original as-is)
        # Just place the original file for direct playback
        # If it's not wav, we still copy a WAV version for compatibility.
        # But simplest: if not wav, we re-save as wav 16-bit for the page.
        out_orig = orig_dir / af.name
        if args.force or not out_orig.exists():
            # Load and resave as wav to ensure browser-compat
            wav, sr = ta_load(str(af))
            # make it mono for fair listening
            wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
            save_wav(out_orig.with_suffix(".wav"), wav_mono, sr)
            # keep the name consistent in the HTML (use .wav)
            af = af.with_suffix(".wav")
        # rename saved file to matched name
        if out_orig.suffix != ".wav":
            # Clean: ensure HTML references the .wav filename
            out_orig = out_orig.with_suffix(".wav")

        # For each model, run full pipeline and save
        base = af.stem
        # Re-load from disk to ensure consistent start-point (original .wav in out folder)
        wav0, sr0 = ta_load(str(out_orig if out_orig.exists() else orig_dir / af.name))
        # Make mono only once; resample per-model to each target SR
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)

        for spec in specs:
            mname = spec.name
            target_sr = sr_by_model[mname]
            # resample to model's SR
            if sr0 != target_sr:
                wav_mono = ta_resample(wav0, sr0, target_sr)
            else:
                wav_mono = wav0

            # reconstruct
            out_path = recon_dir / mname / f"{sanitize_name(base)}.wav"
            if args.force or not out_path.exists():
                print(f"[Reconstruct] {mname} ← {base}")
                wavvae = loaded[mname]["wavvae"]
                zflowae = loaded[mname]["zflowae"]
                wav_hat = reconstruct_full_pipeline(
                    wav_mono, target_sr, wavvae, zflowae, spec.decode, device
                )
                save_wav(out_path, wav_hat.unsqueeze(0), target_sr)

    # Build HTML
    # Rebuild the list of files actually present in original/ (use .wav names)
    actual_audio = sorted([p for p in (orig_dir).glob("*.wav")])
    html_path = make_html(
        out_dir,
        actual_audio,
        specs,
        sr_by_model,
        wavvae_cfg,
        zflow_cfg,
    )
    print(f"\nDone. Open: {html_path}")


if __name__ == "__main__":
    main()
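Note: compare_codecs.py chains each WavVAE with its ZFlowAutoEncoder and emits a browsable report. A typical invocation (paths are placeholders), with models.json following the JSON list documented in parse_manifest():

python codec/scripts/compare_codecs.py \
    --manifest models.json \
    --audio sample1.wav sample2.wav \
    --out codec_compare_out \
    --device cuda

The script saves mono 16-bit WAVs under <out>/original/ and <out>/recon/<model name>/, then writes <out>/index.html with one row per input and one column per model.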
codec/scripts/compare_wavvae.py
ADDED
|
@@ -0,0 +1,264 @@
#!/usr/bin/env python3
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample
from zcodec.models import WavVAE

# -------------------------
# Data structures
# -------------------------


@dataclass
class WavVaeSpec:
    name: str
    wavvae_dir: str


# -------------------------
# Utilities
# -------------------------


def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
    if path.is_file():
        try:
            return json.load(path.open("r", encoding="utf-8"))
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "model_config.json",
        Path(checkpoint_dir) / "config.yaml",  # shown as path only
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
    files = [
        Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
    ]
    return files


def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[WavVaeSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
) -> str:
    def player(src_rel: str) -> str:
        return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'

    # cards
    cards = []
    for s in specs:
        cfg = wavvae_cfg.get(s.name, {})
        cfg_short = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
        <div class="model-card">
          <h3>{html_escape(s.name)}</h3>
          <p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
          <details><summary>WavVAE config</summary><pre>{html_escape(cfg_short)}</pre></details>
        </div>
        """
        cards.append(card)

    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(s.name)}</th>" for s in specs
    )
    rows = []
    for af in audio_files:
        base = af.stem
        orig_rel = f"original/{html_escape(af.name)}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for s in specs:
            rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")

    html = f"""<!doctype html>
<html>
<head><meta charset="utf-8"/><title>WavVAE Comparison</title><style>{css}</style></head>
<body>
<h1>WavVAE Comparison</h1>
<div class="cards">{"".join(cards)}</div>
<table>
<thead><tr>{th}</tr></thead>
<tbody>{"".join(rows)}</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(html, encoding="utf-8")
    return str(out)


# -------------------------
# Core
# -------------------------


@torch.inference_mode()
def reconstruct_wavvae(
    wav_mono: torch.Tensor, wavvae: WavVAE, device: str
) -> torch.Tensor:
    x = wav_mono.to(device)  # (1,T)
    z = wavvae.encode(x)
    wav_hat = wavvae.decode(z)  # (1,1,T)
    return wav_hat.squeeze(0).squeeze(0).detach()


def parse_models_manifest(path: str) -> List[WavVaeSpec]:
    """
    JSON list of:
      {"name": "...", "wavvae": "/path/to/WavVAE_dir"}
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    specs = []
    for it in raw:
        specs.append(WavVaeSpec(name=it["name"], wavvae_dir=it["wavvae"]))
    return specs


def main():
    ap = argparse.ArgumentParser(
        description="Compare WavVAE checkpoints and generate a static HTML page."
    )
    ap.add_argument("--models", required=True, help="JSON manifest of WavVAE models.")
    ap.add_argument(
        "--audio_manifest", required=True, help="TXT file: one audio path per line."
    )
    ap.add_argument("--out", default="compare_wavvae_out")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--force", action="store_true")
    args = ap.parse_args()

    device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
    out_dir = Path(args.out)
    (out_dir / "original").mkdir(parents=True, exist_ok=True)
    recon_dir = out_dir / "recon"
    recon_dir.mkdir(parents=True, exist_ok=True)

    specs = parse_models_manifest(args.models)
    if not specs:
        print("No models.", file=sys.stderr)
        sys.exit(1)

    # load models
    wavvae_by_name: Dict[str, WavVAE] = {}
    sr_by_model: Dict[str, int] = {}
    wavvae_cfg: Dict[str, Dict[str, Any]] = {}
    for s in specs:
        print(f"[Load] {s.name}")
        w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
        wavvae_by_name[s.name] = w
        sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
        wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)

    audio_paths = read_audio_manifest(args.audio_manifest)
    # normalize originals to wav+mono (browser-friendly); keep native sr for original column
    actual_audio = []
    for ap in audio_paths:
        if not ap.exists():
            print(f"[Skip missing] {ap}", file=sys.stderr)
            continue
        wav, sr = ta_load(str(ap))
        wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
        out_orig = out_dir / "original" / (ap.stem + ".wav")
        if args.force or not out_orig.exists():
            save_wav(out_orig, wav_mono, sr)
        actual_audio.append(out_orig)

    # recon per model
    for out_orig in actual_audio:
        wav0, sr0 = ta_load(str(out_orig))
        if wav0.size(0) > 1:
            wav0 = wav0.mean(dim=0, keepdim=True)
        for s in specs:
            target_sr = sr_by_model[s.name]
            wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
            out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
            if args.force or not out_path.exists():
                print(f"[Reconstruct] {s.name} ← {out_orig.name}")
                wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device)
                save_wav(out_path, wav_hat, target_sr)

    html_path = make_html(out_dir, actual_audio, specs, sr_by_model, wavvae_cfg)
    print(f"Done. Open: {html_path}")


if __name__ == "__main__":
    main()
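Note: compare_wavvae.py is the single-model variant: it skips the ZFlowAutoEncoder stage and only round-trips audio through each WavVAE checkpoint. A sketch of the inputs it expects (file names and paths are placeholders):

# wavvae_models.json — JSON list as documented in parse_models_manifest():
#   [{"name": "zdim32", "wavvae": "/path/to/WavVAE_dir"}]
# audio_list.txt — one audio path per line; lines starting with '#' are skipped
python codec/scripts/compare_wavvae.py \
    --models wavvae_models.json \
    --audio_manifest audio_list.txt \
    --out compare_wavvae_out \
    --device cuda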
codec/scripts/compare_zcodec.py
ADDED
|
@@ -0,0 +1,312 @@
#!/usr/bin/env python3
import argparse
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torchaudio
from torchaudio import load as ta_load
from torchaudio.functional import resample as ta_resample
from zcodec.models import WavVAE, ZFlowAutoEncoder

# -------------------------
# Data structures
# -------------------------


@dataclass
class DecodeParams:
    num_steps: int = 10
    cfg: float = 2.0


@dataclass
class StackSpec:
    name: str
    wavvae_dir: str
    zflowae_dir: str
    decode: DecodeParams


# -------------------------
# Utilities (same helpers)
# -------------------------


def load_json_if_exists(path: Path):
    if path.is_file():
        try:
            return json.load(path.open("r", encoding="utf-8"))
        except Exception:
            return None
    return None


def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
    cand = [
        Path(checkpoint_dir) / "config.json",
        Path(checkpoint_dir) / "model_config.json",
        Path(checkpoint_dir) / "config.yaml",
    ]
    for p in cand:
        if p.exists():
            if p.suffix == ".json":
                j = load_json_if_exists(p)
                if j is not None:
                    return j
            else:
                return {"_config_file": str(p)}
    return {}


def sanitize_name(s: str) -> str:
    return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)


def ensure_mono_and_resample(
    wav: torch.Tensor, sr: int, target_sr: int
) -> Tuple[torch.Tensor, int]:
    if wav.ndim != 2:
        raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != target_sr:
        wav = ta_resample(wav, sr, target_sr)
        sr = target_sr
    return wav.to(torch.float32), sr


def save_wav(path: Path, wav: torch.Tensor, sr: int):
    path.parent.mkdir(parents=True, exist_ok=True)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    wav = wav.clamp(-1, 1).contiguous().cpu()
    torchaudio.save(
        str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
    )


def read_audio_manifest(txt_path: str) -> List[Path]:
    lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
    return [
        Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
    ]


def html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&#39;")
    )


def make_html(
    output_dir: Path,
    audio_files: List[Path],
    specs: List[StackSpec],
    sr_by_model: Dict[str, int],
    wavvae_cfg: Dict[str, Dict[str, Any]],
    zflow_cfg: Dict[str, Dict[str, Any]],
) -> str:
    def player(src_rel: str) -> str:
        return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'

    cards = []
    for s in specs:
        wcfg = wavvae_cfg.get(s.name, {})
        zcfg = zflow_cfg.get(s.name, {})
        w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
            :1200
        ]
        card = f"""
        <div class="model-card">
          <h3>{html_escape(s.name)}</h3>
          <p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
          <p><b>Decode</b>: steps={s.decode.num_steps}, cfg={s.decode.cfg}</p>
          <details><summary>WavVAE config</summary><pre>{html_escape(w_short)}</pre></details>
          <details><summary>ZFlowAE config</summary><pre>{html_escape(z_short)}</pre></details>
        </div>
        """
        cards.append(card)

    css = """
    body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
    .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
    .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
    table { border-collapse: collapse; width: 100%; }
    th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
    th { background: #fafafa; position: sticky; top: 0; }
    audio { width: 260px; }
    """

    th = "<th>Input</th><th>Original</th>" + "".join(
        f"<th>{html_escape(s.name)}</th>" for s in specs
    )
    rows = []
    for af in audio_files:
        base = af.stem
        orig_rel = f"original/{html_escape(af.name)}"
        tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
        for s in specs:
            rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
            tds.append(f"<td>{player(rec_rel)}</td>")
        rows.append("<tr>" + "".join(tds) + "</tr>")

    html = f"""<!doctype html>
<html>
<head><meta charset="utf-8"/><title>Stacked Codec Comparison</title><style>{css}</style></head>
<body>
<h1>WavVAE + ZFlowAE Comparison</h1>
<div class="cards">{"".join(cards)}</div>
<table>
<thead><tr>{th}</tr></thead>
<tbody>{"".join(rows)}</tbody>
</table>
</body>
</html>
"""
    out = output_dir / "index.html"
    out.write_text(html, encoding="utf-8")
    return str(out)


# -------------------------
# Core
# -------------------------


@torch.inference_mode()
def reconstruct_stack(
    wav_mono: torch.Tensor,
    wavvae: WavVAE,
    zflow: ZFlowAutoEncoder,
    steps: int,
    cfg: float,
    device: str,
) -> torch.Tensor:
    x = wav_mono.to(device)  # (1,T)
    z = wavvae.encode(x)  # high-framerate latents
    y, _ = zflow.encode(z)  # compressed latents
    z_hat = zflow.decode(y, num_steps=steps, cfg=cfg)
    wav_hat = wavvae.decode(z_hat)  # (1,1,T)
    return wav_hat.squeeze(0).squeeze(0).detach()


def parse_models_manifest(path: str) -> List[StackSpec]:
    """
    JSON list of:
      {
        "name": "...",
        "wavvae": "/path/to/WavVAE_dir",
        "zflowae": "/path/to/ZFlowAE_dir",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    """
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    specs = []
    for it in raw:
        d = it.get("decode", {})
        specs.append(
            StackSpec(
                name=it["name"],
                wavvae_dir=it["wavvae"],
                zflowae_dir=it["zflowae"],
                decode=DecodeParams(
                    num_steps=int(d.get("num_steps", 10)), cfg=float(d.get("cfg", 2.0))
                ),
            )
        )
    return specs


def main():
    ap = argparse.ArgumentParser(
        description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page."
    )
    ap.add_argument("--models", required=True, help="JSON manifest of stacks.")
| 236 |
+
ap.add_argument(
|
| 237 |
+
"--audio_manifest", required=True, help="TXT file: one audio path per line."
|
| 238 |
+
)
|
| 239 |
+
ap.add_argument("--out", default="compare_stack_out")
|
| 240 |
+
ap.add_argument("--device", default="cuda")
|
| 241 |
+
ap.add_argument("--force", action="store_true")
|
| 242 |
+
args = ap.parse_args()
|
| 243 |
+
|
| 244 |
+
device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
|
| 245 |
+
out_dir = Path(args.out)
|
| 246 |
+
(out_dir / "original").mkdir(parents=True, exist_ok=True)
|
| 247 |
+
recon_dir = out_dir / "recon"
|
| 248 |
+
recon_dir.mkdir(parents=True, exist_ok=True)
|
| 249 |
+
|
| 250 |
+
specs = parse_models_manifest(args.models)
|
| 251 |
+
if not specs:
|
| 252 |
+
print("No models.", file=sys.stderr)
|
| 253 |
+
sys.exit(1)
|
| 254 |
+
|
| 255 |
+
# load models
|
| 256 |
+
wavvae_by_name: Dict[str, WavVAE] = {}
|
| 257 |
+
zflow_by_name: Dict[str, ZFlowAutoEncoder] = {}
|
| 258 |
+
sr_by_model: Dict[str, int] = {}
|
| 259 |
+
wavvae_cfg: Dict[str, Dict[str, Any]] = {}
|
| 260 |
+
zflow_cfg: Dict[str, Dict[str, Any]] = {}
|
| 261 |
+
for s in specs:
|
| 262 |
+
print(f"[Load] {s.name}")
|
| 263 |
+
w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
|
| 264 |
+
z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device)
|
| 265 |
+
wavvae_by_name[s.name] = w
|
| 266 |
+
zflow_by_name[s.name] = z
|
| 267 |
+
sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
|
| 268 |
+
wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
|
| 269 |
+
zflow_cfg[s.name] = read_config_any(s.zflowae_dir)
|
| 270 |
+
|
| 271 |
+
audio_paths = read_audio_manifest(args.audio_manifest)
|
| 272 |
+
|
| 273 |
+
actual_audio = []
|
| 274 |
+
for ap in audio_paths:
|
| 275 |
+
if not ap.exists():
|
| 276 |
+
print(f"[Skip missing] {ap}", file=sys.stderr)
|
| 277 |
+
continue
|
| 278 |
+
wav, sr = ta_load(str(ap))
|
| 279 |
+
wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
|
| 280 |
+
out_orig = out_dir / "original" / (ap.stem + ".wav")
|
| 281 |
+
if args.force or not out_orig.exists():
|
| 282 |
+
save_wav(out_orig, wav_mono, sr)
|
| 283 |
+
actual_audio.append(out_orig)
|
| 284 |
+
|
| 285 |
+
for out_orig in actual_audio:
|
| 286 |
+
wav0, sr0 = ta_load(str(out_orig))
|
| 287 |
+
if wav0.size(0) > 1:
|
| 288 |
+
wav0 = wav0.mean(dim=0, keepdim=True)
|
| 289 |
+
for s in specs:
|
| 290 |
+
target_sr = sr_by_model[s.name]
|
| 291 |
+
wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
|
| 292 |
+
out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
|
| 293 |
+
if args.force or not out_path.exists():
|
| 294 |
+
print(f"[Reconstruct] {s.name} ← {out_orig.name}")
|
| 295 |
+
wav_hat = reconstruct_stack(
|
| 296 |
+
wav_in,
|
| 297 |
+
wavvae_by_name[s.name],
|
| 298 |
+
zflow_by_name[s.name],
|
| 299 |
+
s.decode.num_steps,
|
| 300 |
+
s.decode.cfg,
|
| 301 |
+
device,
|
| 302 |
+
)
|
| 303 |
+
save_wav(out_path, wav_hat, target_sr)
|
| 304 |
+
|
| 305 |
+
html_path = make_html(
|
| 306 |
+
out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg
|
| 307 |
+
)
|
| 308 |
+
print(f"Done. Open: {html_path}")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
if __name__ == "__main__":
|
| 312 |
+
main()
|
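
Usage sketch for the script above (stack names and checkpoint paths are illustrative, not taken from this repo): --models points to a JSON manifest shaped like the parse_models_manifest docstring, and --audio_manifest is a plain-text list of audio paths, one per line.

    models.json (illustrative):
    [
      {
        "name": "wavvae24k_zflow",
        "wavvae": "/checkpoints/wavvae",
        "zflowae": "/checkpoints/zflowae",
        "decode": {"num_steps": 10, "cfg": 2.0}
      }
    ]

    python codec/scripts/compare_zcodec.py --models models.json --audio_manifest audio_list.txt --out compare_stack_out --device cuda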
codec/scripts/compute_stats.py
ADDED
@@ -0,0 +1,76 @@
import argparse
import random

import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm


def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(key)


def compute_global_stats(file_list, key="embedding", length_weighted=True):
    sum_all = None
    sum_sq_all = None
    count_all = 0

    for path in tqdm(file_list, desc="Computing stats"):
        tensor = load_tensor(path, key)  # shape: [B, T, D]
        flat = tensor.reshape(-1, tensor.shape[-1])  # [B*T, D]

        sum_ = flat.sum(dim=0)  # [D]
        sum_sq = (flat**2).sum(dim=0)  # [D]
        count = flat.shape[0]  # B*T

        if sum_all is None:
            sum_all = sum_
            sum_sq_all = sum_sq
        else:
            sum_all += sum_
            sum_sq_all += sum_sq

        count_all += count

    mean = sum_all / count_all
    var = sum_sq_all / count_all - mean**2
    std = torch.sqrt(torch.clamp(var, min=1e-8))

    return mean, std


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "filelist", type=str, help="Text file with list of safetensors paths"
    )
    parser.add_argument("output", type=str, help="Path to output stats.safetensors")
    parser.add_argument(
        "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
    )
    parser.add_argument(
        "--max-files", type=int, default=None, help="Max number of files to process"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for shuffling"
    )

    args = parser.parse_args()

    with open(args.filelist) as f:
        files = [line.strip() for line in f if line.strip()]

    if args.max_files:
        random.seed(args.seed)
        files = random.sample(files, k=min(args.max_files, len(files)))

    mean, std = compute_global_stats(files, key=args.key)

    save_file({"mean": mean, "std": std}, args.output)
    print(f"✅ Saved to {args.output}")
    print("Example mean/std:", mean[:5], std[:5])


if __name__ == "__main__":
    main()
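
A possible invocation of the script above, assuming a text file listing .safetensors latent files (file names are illustrative); the output holds per-dimension mean and std tensors:

    python codec/scripts/compute_stats.py latents_filelist.txt stats.safetensors --key audio_z --max-files 5000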
codec/scripts/compute_wer.py
ADDED
@@ -0,0 +1,48 @@
import argparse
import json
import string

from jiwer import wer


def normalize_text(text: str) -> str:
    """
    Lowercase and remove punctuation from a string.

    Args:
        text (str): Input string

    Returns:
        str: Normalized string
    """
    # Lowercase
    text = text.lower()
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    return text


def load_transcripts(jsonl_path):
    originals = []
    reconstructions = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            originals.append(data["original_text"])
            reconstructions.append(data["reconstructed_text"])
    return originals, reconstructions


def main(args):
    originals, reconstructions = load_transcripts(args.jsonl)
    # Normalize each transcript (lowercase, no punctuation) before scoring.
    originals = [normalize_text(t) for t in originals]
    reconstructions = [normalize_text(t) for t in reconstructions]
    score = wer(originals, reconstructions)
    print(f"WER: {score:.3%}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--jsonl", type=str, required=True, help="Path to the transcript JSONL file"
    )
    args = parser.parse_args()
    main(args)
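
The script above expects the JSONL written by eval_asr.py further down, one object per line with original_text and reconstructed_text fields; an illustrative line and invocation (file names are hypothetical):

    {"file": "sample_0001.wav", "original_text": "hello world", "reconstructed_text": "hello world"}

    python codec/scripts/compute_wer.py --jsonl transcripts.jsonl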
codec/scripts/compute_wer_from_refs.py
ADDED
@@ -0,0 +1,64 @@
import argparse
import json
import string
from pathlib import Path

from jiwer import cer, wer


def normalize_text(text: str) -> str:
    """
    Lowercase and remove punctuation from a string.

    Args:
        text (str): Input string

    Returns:
        str: Normalized string
    """
    # Lowercase
    text = text.lower()
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    return text


def load_jsonl_dict(path):
    transcripts = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            transcripts[Path(data["file"]).name] = data["transcript"]
    return transcripts


def main(args):
    ref_dict = load_jsonl_dict(args.reference)
    hyp_dict = load_jsonl_dict(args.hypothesis)

    common_files = set(ref_dict.keys()) & set(hyp_dict.keys())

    if not common_files:
        print("No common files between reference and hypothesis.")
        return

    refs = [normalize_text(ref_dict[f]) for f in sorted(common_files)]
    hyps = [normalize_text(hyp_dict[f]) for f in sorted(common_files)]

    cer_score = cer(refs, hyps)
    wer_score = wer(refs, hyps)
    print(f"CER: {cer_score:.3%}")
    print(f"WER: {wer_score:.3%}")
    print(f"Evaluated on {len(common_files)} files.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--reference", type=str, required=True, help="Path to reference JSONL"
    )
    parser.add_argument(
        "--hypothesis", type=str, required=True, help="Path to hypothesis JSONL"
    )
    args = parser.parse_args()
    main(args)
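
Both inputs to the script above follow the {"file": ..., "transcript": ...} JSONL format produced by eval_asr_from_filelist.py below, matched on basename; an illustrative invocation (file names are hypothetical):

    python codec/scripts/compute_wer_from_refs.py --reference ref_transcripts.jsonl --hypothesis recon_transcripts.jsonl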
codec/scripts/download_expresso.py
ADDED
@@ -0,0 +1,10 @@
import soundfile as sf

from datasets import load_dataset

dataset = load_dataset("ylacombe/expresso", split="train")
print(dataset)
for i, x in enumerate(dataset):
    audio = x["audio"]
    wav, sr = audio["array"], audio["sampling_rate"]
    sf.write(f"expresso/org/{i}.wav", wav, sr)
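
Note that this script (and the sibling download scripts below, apart from download_ltts.py, which calls mkdir itself) writes into a directory that must already exist; a minimal guard, if one wanted it, would be:

    from pathlib import Path
    # Create the output directory before the sf.write loop (illustrative path).
    Path("expresso/org").mkdir(parents=True, exist_ok=True)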
codec/scripts/download_gigaspeech.py
ADDED
@@ -0,0 +1,14 @@
from random import sample

import soundfile as sf
from datasets import load_dataset

# dataset = load_dataset("keithito/lj_speech", split="train")
# dataset = load_dataset("parler-tts/mls_eng", split="train")
dataset = load_dataset("speechcolab/gigaspeech", "xl", split="train", token=True)
Is = sample(list(range(len(dataset))), k=100000)
print(dataset)
for i, I in enumerate(Is):
    audio = dataset[I]["audio"]
    wav, sr = audio["array"], audio["sampling_rate"]
    sf.write(f"gigaspeech/{I}.wav", wav, sr)
codec/scripts/download_lj.py
ADDED
@@ -0,0 +1,9 @@
import soundfile as sf
from datasets import load_dataset

dataset = load_dataset("keithito/lj_speech", split="train")
print(dataset)
for i, x in enumerate(dataset):
    audio = x["audio"]
    wav, sr = audio["array"], audio["sampling_rate"]
    sf.write(f"ljspeech/{i}.wav", wav, sr)
codec/scripts/download_ltts.py
ADDED
@@ -0,0 +1,16 @@
from pathlib import Path

import soundfile as sf

from datasets import load_dataset

dataset = load_dataset("mythicinfinity/libritts", "clean")
for split in dataset.keys():
    Path(f"libritts/{split}").mkdir(exist_ok=True)
    for i, x in enumerate(dataset[split]):
        # audio = x["audio"]
        text = x["text_normalized"]
        # wav, sr = audio["array"], audio["sampling_rate"]
        # sf.write(f"libritts/{split}/{i}.wav", wav, sr)
        with open(f"libritts/{split}/{i}.txt", "w") as f:
            f.write(text)
codec/scripts/download_mlseng10k.py
ADDED
@@ -0,0 +1,13 @@
from random import sample

import soundfile as sf
from datasets import load_dataset

# dataset = load_dataset("keithito/lj_speech", split="train")
dataset = load_dataset("parler-tts/mls_eng", split="train")
Is = sample(list(range(len(dataset))), k=100000)
print(dataset)
for i, I in enumerate(Is):
    audio = dataset[I]["audio"]
    wav, sr = audio["array"], audio["sampling_rate"]
    sf.write(f"mls10keng/{i}.wav", wav, sr)
codec/scripts/eval_asr.py
ADDED
@@ -0,0 +1,100 @@
import argparse
import json
from pathlib import Path

import nemo.collections.asr as nemo_asr
import torch
import yaml
from jiwer import wer
from torchaudio import load
from torchaudio.functional import resample
from tqdm import tqdm

from zcodec.models import WavVAE, ZFlowAutoEncoder


def load_config(config_path):
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


def transcribe(audio: torch.Tensor, asr_model) -> str:
    audio = audio.cpu().numpy(force=True)
    with torch.inference_mode():
        return asr_model.transcribe([audio[0]])[0].text


def main(args):
    config = load_config(args.config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load models
    wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
    zflowae = (
        ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
    )

    # Load ASR model
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )

    # Read file list
    with open(config["file_list"], "r") as f:
        wav_files = [line.strip() for line in f if line.strip()]

    results = []

    for wav_path in tqdm(wav_files, desc="Processing files"):
        wav, sr = load(wav_path)
        wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)

        with torch.inference_mode():
            # Transcribe original
            original_text = transcribe(wav, asr_model)

            # Compress and decompress
            z = wavvae.encode(wav)
            zz, _ = zflowae.encode(z)
            z_hat = zflowae.decode(
                zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0)
            )
            wav_hat = wavvae.decode(z_hat)

            # Transcribe reconstructed
            reconstructed_text = transcribe(wav_hat, asr_model)

        results.append(
            {
                "file": wav_path,
                "original_text": original_text,
                "reconstructed_text": reconstructed_text,
            }
        )

    # Save output
    out_path = Path(config.get("output_jsonl", "transcripts.jsonl"))
    with out_path.open("w") as f:
        for entry in results:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    print(f"\nSaved {len(results)} transcript pairs to {out_path}")

    # Optionally compute WER
    if args.compute_wer:
        original_texts = [r["original_text"] for r in results]
        reconstructed_texts = [r["reconstructed_text"] for r in results]
        score = wer(original_texts, reconstructed_texts)
        print(f"WER: {score:.3%}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument(
        "--compute_wer", action="store_true", help="Compute WER after decoding"
    )
    args = parser.parse_args()

    main(args)
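
A sketch of the YAML config the script above reads, using only the keys referenced in the code (checkpoint paths and file names are illustrative):

    wavvae_ckpt: /checkpoints/wavvae
    zflowae_ckpt: /checkpoints/zflowae
    asr_model: nvidia/parakeet-tdt-0.6b-v2
    file_list: eval_wavs.txt
    num_steps: 10
    cfg: 2.0
    output_jsonl: transcripts.jsonl

    python codec/scripts/eval_asr.py --config eval_asr.yaml --compute_wer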
codec/scripts/eval_asr_from_filelist.py
ADDED
@@ -0,0 +1,60 @@
import argparse
import json
from pathlib import Path

import nemo.collections.asr as nemo_asr
import torch
import yaml
from torchaudio import load
from torchaudio.functional import resample
from tqdm import tqdm


def load_config(config_path):
    with open(config_path, "r") as f:
        return yaml.safe_load(f)


def transcribe(audio: torch.Tensor, asr_model) -> str:
    audio = audio.cpu().numpy(force=True)
    with torch.inference_mode():
        return asr_model.transcribe([audio[0]])[0].text


def main(args):
    config = load_config(args.config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load ASR model
    asr_model = nemo_asr.models.ASRModel.from_pretrained(
        model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
    )

    # Read file list
    with open(config["file_list"], "r") as f:
        wav_files = [line.strip() for line in f if line.strip()]

    results = []

    for wav_path in tqdm(wav_files, desc="Transcribing"):
        wav, sr = load(wav_path)
        wav = resample(wav, orig_freq=sr, new_freq=16000).to(device)

        transcript = transcribe(wav, asr_model)
        results.append({"file": wav_path, "transcript": transcript})

    # Save output
    out_path = Path(config.get("output_jsonl", "asr_transcripts.jsonl"))
    with out_path.open("w") as f:
        for entry in results:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    print(f"\nSaved {len(results)} transcripts to {out_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    args = parser.parse_args()
    main(args)
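
This variant needs only a subset of the config keys above (values are illustrative) and writes the {"file", "transcript"} JSONL consumed by compute_wer_from_refs.py:

    asr_model: nvidia/parakeet-tdt-0.6b-v2
    file_list: recon_wavs.txt
    output_jsonl: recon_transcripts.jsonl

    python codec/scripts/eval_asr_from_filelist.py --config asr_filelist.yaml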