Mehdi Lakbar committed
Commit 56cfa73 · 1 Parent(s): 4905304

Initial demo of Lina-speech (pardi-speech)

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50)
  1. codec/__init__.py +2 -0
  2. codec/__pycache__/__init__.cpython-312.pyc +0 -0
  3. codec/__pycache__/train_patchvae.cpython-312.pyc +0 -0
  4. codec/__pycache__/train_wavvae.cpython-312.pyc +0 -0
  5. codec/__pycache__/train_zflowae.cpython-312.pyc +0 -0
  6. codec/datamodules.py +249 -0
  7. codec/models/__init__.py +2 -0
  8. codec/models/__pycache__/__init__.cpython-312.pyc +0 -0
  9. codec/models/components/__init__.py +0 -0
  10. codec/models/components/__pycache__/__init__.cpython-312.pyc +0 -0
  11. codec/models/components/__pycache__/convnext.cpython-312.pyc +0 -0
  12. codec/models/components/convnext.py +221 -0
  13. codec/models/components/transformer.py +224 -0
  14. codec/models/pardi_tokenizer.py +10 -0
  15. codec/models/patchvae/__pycache__/model.cpython-312.pyc +0 -0
  16. codec/models/patchvae/__pycache__/modules.cpython-312.pyc +0 -0
  17. codec/models/patchvae/model.py +262 -0
  18. codec/models/patchvae/modules.py +396 -0
  19. codec/models/wavvae/__init__.py +0 -0
  20. codec/models/wavvae/__pycache__/__init__.cpython-312.pyc +0 -0
  21. codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc +0 -0
  22. codec/models/wavvae/__pycache__/heads.cpython-312.pyc +0 -0
  23. codec/models/wavvae/__pycache__/layers.cpython-312.pyc +0 -0
  24. codec/models/wavvae/__pycache__/loss.cpython-312.pyc +0 -0
  25. codec/models/wavvae/__pycache__/model.cpython-312.pyc +0 -0
  26. codec/models/wavvae/__pycache__/modules.cpython-312.pyc +0 -0
  27. codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc +0 -0
  28. codec/models/wavvae/dataset.py +84 -0
  29. codec/models/wavvae/discriminators.py +211 -0
  30. codec/models/wavvae/experiment.py +3 -0
  31. codec/models/wavvae/heads.py +194 -0
  32. codec/models/wavvae/helpers.py +71 -0
  33. codec/models/wavvae/layers.py +282 -0
  34. codec/models/wavvae/loss.py +142 -0
  35. codec/models/wavvae/model.py +140 -0
  36. codec/models/wavvae/modules.py +213 -0
  37. codec/models/wavvae/spectral_ops.py +192 -0
  38. codec/scripts/compare_codecs.py +441 -0
  39. codec/scripts/compare_wavvae.py +264 -0
  40. codec/scripts/compare_zcodec.py +312 -0
  41. codec/scripts/compute_stats.py +76 -0
  42. codec/scripts/compute_wer.py +48 -0
  43. codec/scripts/compute_wer_from_refs.py +64 -0
  44. codec/scripts/download_expresso.py +10 -0
  45. codec/scripts/download_gigaspeech.py +14 -0
  46. codec/scripts/download_lj.py +9 -0
  47. codec/scripts/download_ltts.py +16 -0
  48. codec/scripts/download_mlseng10k.py +13 -0
  49. codec/scripts/eval_asr.py +100 -0
  50. codec/scripts/eval_asr_from_filelist.py +60 -0
codec/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .train_patchvae import TrainPatchVAE
2
+ from .train_wavvae import TrainWavVAE
codec/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (260 Bytes).
 
codec/__pycache__/train_patchvae.cpython-312.pyc ADDED
Binary file (12.2 kB).
 
codec/__pycache__/train_wavvae.cpython-312.pyc ADDED
Binary file (15 kB).
 
codec/__pycache__/train_zflowae.cpython-312.pyc ADDED
Binary file (12.2 kB).
 
codec/datamodules.py ADDED
@@ -0,0 +1,249 @@
1
+ import itertools
2
+ import random
3
+ import time
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import pytorch_lightning as ptl
10
+ import torch
11
+ import torchaudio
12
+ from safetensors.torch import safe_open
13
+ from sklearn.model_selection import train_test_split
14
+ from torch.nn.utils.rnn import pad_sequence
15
+ from torch.utils.data import DataLoader, Dataset
16
+
17
+ from datasets import load_dataset, load_from_disk
18
+
19
+
20
+ @dataclass
21
+ class WavVAEDataConfig:
22
+ filelist_path: str
23
+ sampling_rate: int
24
+ num_samples: int
25
+ batch_size: int
26
+ num_workers: int
27
+
28
+
29
+ class WavVAEDataModule(ptl.LightningDataModule):
30
+ def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
31
+ super().__init__()
32
+ self.train_config = train_params
33
+ self.val_config = val_params
34
+
35
+ def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool):
36
+ dataset = WavVAEDataset(cfg, train=train)
37
+ dataloader = DataLoader(
38
+ dataset,
39
+ batch_size=cfg.batch_size,
40
+ num_workers=cfg.num_workers,
41
+ shuffle=train,
42
+ pin_memory=True,
43
+ )
44
+ return dataloader
45
+
46
+ def train_dataloader(self) -> DataLoader:
47
+ return self._get_dataloder(self.train_config, train=True)
48
+
49
+ def val_dataloader(self) -> DataLoader:
50
+ return self._get_dataloder(self.val_config, train=False)
51
+
52
+
53
+ class WavVAEDataset(Dataset):
54
+ def __init__(self, cfg: WavVAEDataConfig, train: bool):
55
+ with open(cfg.filelist_path) as f:
56
+ self.filelist = f.read().splitlines()
57
+ self.sampling_rate = cfg.sampling_rate
58
+ self.num_samples = cfg.num_samples
59
+ self.train = train
60
+
61
+ def __len__(self) -> int:
62
+ return len(self.filelist)
63
+
64
+ def __getitem__(self, index: int) -> torch.Tensor:
65
+ audio_path = self.filelist[index]
66
+ y, sr = torchaudio.load(audio_path)
67
+ if y.size(0) > 1:
68
+ # mix to mono
69
+ y = y.mean(dim=0, keepdim=True)
70
+ gain = np.random.uniform(-1, -6) if self.train else -3
71
+ y, _ = torchaudio.sox_effects.apply_effects_tensor(
72
+ y, sr, [["norm", f"{gain:.2f}"]]
73
+ )
74
+ if sr != self.sampling_rate:
75
+ y = torchaudio.functional.resample(
76
+ y, orig_freq=sr, new_freq=self.sampling_rate
77
+ )
78
+ if y.size(-1) < self.num_samples:
79
+ pad_length = self.num_samples - y.size(-1)
80
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
81
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
82
+ elif self.train:
83
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
84
+ y = y[:, start : start + self.num_samples]
85
+ else:
86
+ # During validation, take always the first segment for determinism
87
+ y = y[:, : self.num_samples]
88
+
89
+ return y[0]
90
+
91
+
92
+ def pad_tensor_list_raw(
93
+ tensor_list: list[tuple[torch.Tensor, torch.Tensor]], pad_idx: int = 0
94
+ ) -> dict[str, torch.Tensor | None]:
95
+ audio, hubert_maybe = zip(*tensor_list)
96
+ audio = torch.cat(audio, dim=0)
97
+ if hubert_maybe[0] is not None:
98
+ hubert_maybe = torch.stack(hubert_maybe, dim=0)
99
+ else:
100
+ hubert_maybe = None
101
+ return {"audio_z": audio, "hubert": hubert_maybe}
102
+
103
+
104
+ class SafeTensorDataset(Dataset):
105
+ """
106
+ On __getitem__, opens the safetensor, uses get_slice() to inspect shape,
107
+ then either drops too-short files (return None) or returns a random subsequence slice.
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ file_paths: list[str],
113
+ key: str,
114
+ hubert_path: str | None = None,
115
+ hubert_key: str = "layer_9",
116
+ min_length: int = 1,
117
+ subseq_length: int | None = None,
118
+ ):
119
+ self.file_paths = file_paths
120
+ self.key = key
121
+ self.min_length = min_length
122
+ self.subseq_length = subseq_length
123
+ self.hubert_path = hubert_path
124
+ self.hubert_key = hubert_key
125
+
126
+ def __len__(self):
127
+ return len(self.file_paths)
128
+
129
+ def __getitem__(self, idx: int) -> torch.Tensor | None:
130
+ path = self.file_paths[idx]
131
+ # open file, get a slice wrapper for full tensor
132
+ with safe_open(path, framework="pt") as f:
133
+ tensor_slice = f.get_slice(self.key)
134
+ Q, N, D = tensor_slice.get_shape() # full shape [Q, N, D]
135
+
136
+ # drop too-short
137
+ if N < self.min_length:
138
+ return None
139
+
140
+ L = self.subseq_length or N
141
+ if L < N:
142
+ # sample random start
143
+ start = torch.randint(0, max(1, N - L - 1), ()).item()
144
+ start -= start % 2
145
+ # this yields a torch.Tensor of shape [K, L]
146
+ seq = tensor_slice[:, start : start + L]
147
+ else:
148
+ # full length
149
+ start = 0
150
+ seq = tensor_slice[:, :]
151
+
152
+ if self.hubert_path is not None:
153
+ path = Path(self.hubert_path) / Path(path).name
154
+ with safe_open(path, framework="pt") as f:
155
+ tensor_slice = f.get_slice(self.hubert_key)
156
+ hubert_N, hubert_D = tensor_slice.get_shape() # full shape [N, D]
157
+ seq_hubert = tensor_slice[start // 2 : start // 2 + L // 2]
158
+ return (seq, seq_hubert)
159
+
160
+ return (seq, None)
161
+
162
+
163
+ class SafeTensorDataModule(ptl.LightningDataModule):
164
+ """
165
+ LightningDataModule using raw .safetensors file list + get_slice inside Dataset.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ train_file_list: str,
171
+ val_file_list: str | None = None,
172
+ hubert_path: str | None = None,
173
+ key: str = "audio_z",
174
+ hubert_key: str = "layer_9",
175
+ val_split: float = 0.1,
176
+ batch_size: int = 32,
177
+ num_workers: int = 4,
178
+ shuffle: bool = True,
179
+ seed: int = 1234,
180
+ min_length: int = 1,
181
+ subseq_length: int | None = None,
182
+ ):
183
+ super().__init__()
184
+ self.train_file_list = train_file_list
185
+ self.val_file_list = val_file_list
186
+ self.hubert_path = hubert_path
187
+ self.key = key
188
+ self.val_split = val_split
189
+ self.batch_size = batch_size
190
+ self.num_workers = num_workers
191
+ self.shuffle = shuffle
192
+ self.seed = seed
193
+ self.min_length = min_length
194
+ self.subseq_length = subseq_length
195
+
196
+ def setup(self, stage=None):
197
+ with open(self.train_file_list, "r") as f:
198
+ train_paths = [line.strip() for line in f if line.strip()]
199
+ val_paths = None
200
+ if self.val_file_list is not None:
201
+ with open(self.val_file_list, "r") as f:
202
+ val_paths = [line.strip() for line in f if line.strip()]
203
+ # Split into train/val
204
+ if (
205
+ isinstance(self.val_split, float)
206
+ and 0 < self.val_split < 1
207
+ and val_paths is None
208
+ ):
209
+ train_paths, val_paths = train_test_split(
210
+ train_paths, test_size=self.val_split, random_state=self.seed
211
+ )
212
+
213
+ self.train_ds = SafeTensorDataset(
214
+ train_paths,
215
+ key=self.key,
216
+ min_length=self.min_length,
217
+ subseq_length=self.subseq_length,
218
+ hubert_path=self.hubert_path,
219
+ )
220
+ self.val_ds = SafeTensorDataset(
221
+ val_paths,
222
+ key=self.key,
223
+ min_length=self.min_length,
224
+ subseq_length=self.subseq_length,
225
+ )
226
+
227
+ def _collate_fn(
228
+ self, batch: list[torch.Tensor | None]
229
+ ) -> dict[str, torch.Tensor | None]:
230
+ seqs = [s for s in batch if s is not None]
231
+ return pad_tensor_list_raw(seqs, pad_idx=0)
232
+
233
+ def train_dataloader(self):
234
+ return DataLoader(
235
+ self.train_ds,
236
+ batch_size=self.batch_size,
237
+ shuffle=self.shuffle,
238
+ num_workers=self.num_workers,
239
+ collate_fn=self._collate_fn,
240
+ )
241
+
242
+ def val_dataloader(self):
243
+ return DataLoader(
244
+ self.val_ds,
245
+ batch_size=self.batch_size,
246
+ shuffle=False,
247
+ num_workers=self.num_workers,
248
+ collate_fn=self._collate_fn,
249
+ )
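
Below is a minimal usage sketch (not part of the commit) showing how the WavVAEDataConfig / WavVAEDataModule pair defined in codec/datamodules.py could be wired up; the filelist paths, sample rate, and batch size are placeholder assumptions.

# Sketch: assumed filelists ("train.txt", "val.txt") with one audio path per line.
from codec.datamodules import WavVAEDataConfig, WavVAEDataModule

train_cfg = WavVAEDataConfig(
    filelist_path="train.txt",   # hypothetical path
    sampling_rate=24000,         # assumed sample rate
    num_samples=24000,           # 1-second random crops during training
    batch_size=16,
    num_workers=4,
)
val_cfg = WavVAEDataConfig(
    filelist_path="val.txt",     # hypothetical path
    sampling_rate=24000,
    num_samples=24000,
    batch_size=16,
    num_workers=4,
)
dm = WavVAEDataModule(train_cfg, val_cfg)
batch = next(iter(dm.train_dataloader()))   # (batch_size, num_samples) waveform tensor
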
codec/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .patchvae.model import PatchVAE, PatchVAEConfig
2
+ from .wavvae.model import WavVAE, WavVAEConfig
codec/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (308 Bytes).
 
codec/models/components/__init__.py ADDED
File without changes
codec/models/components/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (174 Bytes).
 
codec/models/components/__pycache__/convnext.cpython-312.pyc ADDED
Binary file (11.3 kB).
 
codec/models/components/convnext.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class ConvNeXtBlock(nn.Module):
6
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
7
+
8
+ Args:
9
+ dim (int): Number of input channels.
10
+ intermediate_dim (int): Dimensionality of the intermediate layer.
11
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
12
+ Defaults to None.
13
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
14
+ None means non-conditional LayerNorm. Defaults to None.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ intermediate_dim: int | None = None,
21
+ layer_scale_init_value: float = 0.0,
22
+ elementwise_affine_ln: bool = True,
23
+ is_causal: bool = False,
24
+ ):
25
+ super().__init__()
26
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
27
+ self.dwconv = nn.Conv1d(
28
+ dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
29
+ ) # depthwise conv
30
+ self.norm = nn.LayerNorm(
31
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
32
+ )
33
+ self.pwconv1 = nn.Linear(
34
+ dim, intermediate_dim
35
+ ) # pointwise/1x1 convs, implemented with linear layers
36
+ self.act = nn.GELU()
37
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
38
+ self.gamma = (
39
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
40
+ if layer_scale_init_value > 0
41
+ else None
42
+ )
43
+ self.is_causal = is_causal
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
49
+ gate: torch.Tensor | None = None,
50
+ ) -> torch.Tensor:
51
+ residual = x
52
+ if self.is_causal:
53
+ x = torch.nn.functional.pad(x, (6, 0))
54
+ x = self.dwconv(x)
55
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
56
+ x = self.norm(x)
57
+ if scale_shift is not None:
58
+ scale, shift = scale_shift
59
+ x = x * scale[:, None] + shift[:, None]
60
+ x = self.pwconv1(x)
61
+ x = self.act(x)
62
+ x = self.pwconv2(x)
63
+ if self.gamma is not None:
64
+ x = self.gamma * x
65
+ if gate is not None:
66
+ x = gate[:, None] * x
67
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
68
+
69
+ x = residual + x
70
+ return x
71
+
72
+
73
+ class ConvNextNet(nn.Module):
74
+ def __init__(self, n_layers, dim, intermediate_dim: int | None = None):
75
+ super().__init__()
76
+ self.net = nn.Sequential(
77
+ *[
78
+ ConvNeXtBlock(
79
+ dim,
80
+ intermediate_dim,
81
+ )
82
+ for _ in range(n_layers)
83
+ ]
84
+ )
85
+
86
+ def forward(self, x):
87
+ return self.net(x)
88
+
89
+
90
+ class ConvNextPatchEncoder(nn.Module):
91
+ def __init__(
92
+ self,
93
+ patch_sizes: list[int],
94
+ n_layers_per_patch: int,
95
+ patch_expansion_factor: float = 1.5,
96
+ is_decoder: bool = False,
97
+ ):
98
+ super().__init__()
99
+ patch_to_dim = []
100
+ convnext = []
101
+ for i, patch_size in enumerate(patch_sizes):
102
+ in_dim = int((patch_expansion_factor if i > 0 else 1.0) * patch_size)
103
+ out_dim = int(patch_expansion_factor * patch_size)
104
+ if is_decoder:
105
+ in_dim, out_dim = out_dim, in_dim
106
+ patch_to_dim.append(
107
+ nn.Linear(
108
+ in_dim,
109
+ out_dim,
110
+ )
111
+ )
112
+ convnext += [
113
+ nn.Sequential(
114
+ *[
115
+ ConvNeXtBlock(int(patch_size * patch_expansion_factor))
116
+ for _ in range(n_layers_per_patch)
117
+ ]
118
+ )
119
+ ]
120
+ self.is_decoder = is_decoder
121
+ self.patch_sizes = patch_sizes
122
+ self.patch_expansion_factor = patch_expansion_factor
123
+ self.patch_to_dim = nn.ModuleList(patch_to_dim)
124
+ self.convnext = nn.ModuleList(convnext)
125
+
126
+ def forward(self, x):
127
+ if self.is_decoder:
128
+ for i, patch_size in reversed(list(enumerate(self.patch_sizes))):
129
+ B, P, N = x.shape
130
+ patch_expansion_factor_maybe = (
131
+ self.patch_expansion_factor if i > 0 else 1.0
132
+ )
133
+ x = x.reshape(B, int(patch_size * self.patch_expansion_factor), -1)
134
+ x = self.convnext[i](x)
135
+ x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
136
+ else:
137
+ for i, patch_size in enumerate(self.patch_sizes):
138
+ B, P, N = x.shape
139
+ patch_expansion_factor_maybe = (
140
+ self.patch_expansion_factor if i > 0 else 1.0
141
+ )
142
+ x = x.reshape(B, int(patch_size * patch_expansion_factor_maybe), -1)
143
+ x = self.patch_to_dim[i](x.transpose(1, 2)).transpose(1, 2)
144
+ x = self.convnext[i](x)
145
+ return x
146
+
147
+
148
+ class ConvNextEncoder(nn.Module):
149
+ def __init__(
150
+ self,
151
+ in_dim: int,
152
+ dim: int,
153
+ n_layers: int,
154
+ intermediate_dim: int | None = None,
155
+ stride: int = 1,
156
+ ):
157
+ super().__init__()
158
+ self.in_proj = nn.Linear(in_dim, dim)
159
+ if stride > 1:
160
+ self.stride = nn.Conv1d(
161
+ in_channels=dim,
162
+ out_channels=dim,
163
+ kernel_size=(stride * 2) + 1,
164
+ stride=stride,
165
+ padding=stride // 2,
166
+ )
167
+ else:
168
+ self.stride = nn.Identity()
169
+ self.net = ConvNextNet(n_layers, dim, intermediate_dim)
170
+
171
+ def forward(self, x):
172
+ x = self.in_proj(x.transpose(1, 2)).transpose(1, 2)
173
+ x = self.stride(x)
174
+ return self.net(x)
175
+
176
+
177
+ class ConvNextDecoder(nn.Module):
178
+ def __init__(
179
+ self,
180
+ out_dim: int,
181
+ dim: int,
182
+ n_layers: int,
183
+ intermediate_dim: int | None = None,
184
+ stride: int = 1,
185
+ stride_position: str = "before",
186
+ ):
187
+ super().__init__()
188
+ self.out_proj = nn.Linear(dim, out_dim)
189
+ if stride > 1:
190
+ self.stride = nn.ConvTranspose1d(
191
+ in_channels=dim,
192
+ out_channels=dim,
193
+ kernel_size=(stride * 2) + 1,
194
+ stride=stride,
195
+ padding=stride // 2,
196
+ output_padding=stride // 2,
197
+ )
198
+ else:
199
+ self.stride = nn.Identity()
200
+ self.stride_position = stride_position
201
+
202
+ self.net = ConvNextNet(n_layers, dim, intermediate_dim)
203
+
204
+ def forward(self, x):
205
+ if self.stride_position == "before":
206
+ x = self.stride(x)
207
+ x = self.net(x)
208
+ if self.stride_position == "after":
209
+ x = self.stride(x)
210
+ return self.out_proj(x.transpose(1, 2)).transpose(1, 2)
211
+
212
+
213
+ class SwiGLU(nn.Module):
214
+ def __init__(self, d_model: int, ffn_expansion_factor: int = 4):
215
+ super().__init__()
216
+ self.p_in = nn.Linear(d_model, (d_model * ffn_expansion_factor // 3) * 2)
217
+ self.p_out = nn.Linear(d_model * ffn_expansion_factor // 3, d_model)
218
+
219
+ def forward(self, x):
220
+ gate, x = self.p_in(x).chunk(2, dim=-1)
221
+ return self.p_out(nn.functional.silu(gate) * x)
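
A small sketch (not part of the commit) of the shape conventions used by the ConvNeXt components above: the blocks operate channels-first on (batch, channels, time) tensors, and ConvNextEncoder's strided Conv1d roughly halves the time axis when stride=2. The dimensions below are arbitrary.

import torch
from codec.models.components.convnext import ConvNeXtBlock, ConvNextEncoder

x = torch.randn(2, 64, 100)           # (batch, channels, time)
block = ConvNeXtBlock(dim=64)         # depthwise conv + pointwise MLP with residual
assert block(x).shape == (2, 64, 100)

enc = ConvNextEncoder(in_dim=64, dim=128, n_layers=4, stride=2)
z = enc(x)                            # in_proj on the channel dim, then strided Conv1d
print(z.shape)                        # e.g. torch.Size([2, 128, 49])
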
codec/models/components/transformer.py ADDED
@@ -0,0 +1,224 @@
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
8
+
9
+
10
+ class LocalSelfAttention(nn.Module):
11
+ def __init__(
12
+ self,
13
+ dim: int,
14
+ heads: int,
15
+ window_len: int = 32,
16
+ rotary: bool = True,
17
+ is_causal: bool = False,
18
+ ):
19
+ super().__init__()
20
+ self.heads = heads
21
+ assert dim % heads == 0, "dim must be divisible by heads"
22
+ self.qkv = nn.Linear(dim, 3 * dim)
23
+ self.o = nn.Linear(dim, dim)
24
+ self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
25
+ self.is_causal = is_causal
26
+ self.window_len = window_len
27
+
28
+ def forward(
29
+ self,
30
+ x: torch.Tensor,
31
+ mask: Optional[torch.Tensor] = None,
32
+ pos: Optional[torch.Tensor] = None,
33
+ cache: Optional[Dict[int, torch.Tensor]] = None,
34
+ layer_idx: Optional[int] = None,
35
+ time_step: int = 0,
36
+ ) -> torch.Tensor:
37
+ # x: (batch, seq_len, dim)
38
+ b, n, dim = x.shape
39
+ b, t_len, hd = x.shape
40
+ pad_len = (self.window_len - t_len % self.window_len) % self.window_len
41
+ padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
42
+ mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
43
+ mask = torch.nn.functional.pad(
44
+ mask, (0, pad_len), value=False
45
+ ) # False = masked
46
+ mask = mask.expand(b, -1) # [b, padded_len]
47
+ mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=self.window_len)
48
+ qkv = self.qkv(padded_x).chunk(3, dim=-1)
49
+ q, k, v = [
50
+ rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=self.window_len)
51
+ for t in qkv
52
+ ]
53
+ if cache is not None:
54
+ assert layer_idx is not None, "layer_idx must be set when using cache"
55
+ cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
56
+ cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
57
+ k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
58
+
59
+ # apply rotary embeddings
60
+ if self.rotary is not None:
61
+ if pos is not None:
62
+ rot = self.rotary(pos) # (b,1,n,head_dim)
63
+ q = apply_rotary_emb(rot, q)
64
+ k = apply_rotary_emb(rot, k)
65
+ else:
66
+ q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
67
+ k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
68
+
69
+ # scaled dot-product attention
70
+ y = F.scaled_dot_product_attention(
71
+ q,
72
+ k,
73
+ v,
74
+ attn_mask=None if self.is_causal else mask,
75
+ is_causal=self.is_causal,
76
+ )
77
+ y = rearrange(y, "b n h w d -> b (w n) (h d)")
78
+ y = self.o(y)
79
+ y = y[:, :t_len]
80
+ return y
81
+
82
+
83
+ class SelfAttention(nn.Module):
84
+ def __init__(
85
+ self, dim: int, heads: int, rotary: bool = True, is_causal: bool = False
86
+ ):
87
+ super().__init__()
88
+ self.heads = heads
89
+ assert dim % heads == 0, "dim must be divisible by heads"
90
+ self.qkv = nn.Linear(dim, 3 * dim)
91
+ self.o = nn.Linear(dim, dim)
92
+ self.rotary = RotaryEmbedding((dim // heads) // 2) if rotary else None
93
+ self.is_causal = is_causal
94
+
95
+ def forward(
96
+ self,
97
+ x: torch.Tensor,
98
+ mask: Optional[torch.Tensor] = None,
99
+ pos: Optional[torch.Tensor] = None,
100
+ cache: Optional[Dict[int, torch.Tensor]] = None,
101
+ layer_idx: Optional[int] = None,
102
+ time_step: int = 0,
103
+ ) -> torch.Tensor:
104
+ # x: (batch, seq_len, dim)
105
+ b, n, dim = x.shape
106
+ b, t_len, hd = x.shape
107
+ pad_len = (32 - t_len % 32) % 32
108
+ padded_x = torch.nn.functional.pad(x, (0, 0, 0, pad_len)) # pad on time dim
109
+ mask = torch.ones(t_len, dtype=torch.bool, device=x.device)
110
+ mask = torch.nn.functional.pad(
111
+ mask, (0, pad_len), value=False
112
+ ) # False = masked
113
+ mask = mask.expand(b, -1) # [b, padded_len]
114
+ mask = rearrange(mask, "b (w n) -> b n 1 1 w", w=32)
115
+ qkv = self.qkv(padded_x).chunk(3, dim=-1)
116
+ q, k, v = [
117
+ rearrange(t, "b (w n) (h d) -> b n h w d", h=self.heads, w=32) for t in qkv
118
+ ]
119
+ # caching for fast autoregressive
120
+ if cache is not None:
121
+ assert layer_idx is not None, "layer_idx must be set when using cache"
122
+ # append new keys/values
123
+ cache[layer_idx]["k"] = torch.cat([cache[layer_idx]["k"], k], dim=2)
124
+ cache[layer_idx]["v"] = torch.cat([cache[layer_idx]["v"], v], dim=2)
125
+ k, v = cache[layer_idx]["k"], cache[layer_idx]["v"]
126
+
127
+ # apply rotary embeddings
128
+ if self.rotary is not None:
129
+ if pos is not None:
130
+ rot = self.rotary(pos) # .unsqueeze(1) # (b,1,n,head_dim)
131
+ q = apply_rotary_emb(rot, q)
132
+ k = apply_rotary_emb(rot, k)
133
+ else:
134
+ q = self.rotary.rotate_queries_or_keys(q, offset=time_step)
135
+ k = self.rotary.rotate_queries_or_keys(k, offset=time_step)
136
+
137
+ # scaled dot-product attention
138
+ y = F.scaled_dot_product_attention(
139
+ q,
140
+ k,
141
+ v,
142
+ attn_mask=None if self.is_causal else mask,
143
+ is_causal=self.is_causal,
144
+ )
145
+ y = rearrange(y, "b n h w d -> b (w n) (h d)")
146
+ y = self.o(y)
147
+ y = y[:, :t_len]
148
+ return y
149
+
150
+
151
+ class SwiGLU(nn.Module):
152
+ def __init__(self, d_model: int):
153
+ super().__init__()
154
+ hidden = d_model * 4 // 3
155
+ self.p_in = nn.Linear(d_model, hidden * 2)
156
+ self.p_out = nn.Linear(hidden, d_model)
157
+
158
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
159
+ gate, data = self.p_in(x).chunk(2, dim=-1)
160
+ return self.p_out(F.silu(gate) * data)
161
+
162
+
163
+ class TransformerBlock(nn.Module):
164
+ """
165
+ Transformer block using custom SelfAttention and SwiGLU FFN.
166
+
167
+ Args:
168
+ dim: embedding dimension
169
+ head_size: size of each attention head (number of heads is dim // head_size)
170
+ rotary: whether to use rotary embeddings
171
+ is_causal: whether to apply causal masking
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ dim: int,
177
+ head_size: int,
178
+ rotary: bool = True,
179
+ is_causal: bool = False,
180
+ elementwise_affine_ln: bool = True,
181
+ ):
182
+ super().__init__()
183
+ assert dim % head_size == 0
184
+ heads = dim // head_size
185
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
186
+ self.attn = LocalSelfAttention(dim, heads, rotary=rotary, is_causal=is_causal)
187
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine_ln)
188
+ self.ffn = SwiGLU(dim)
189
+
190
+ def forward(
191
+ self,
192
+ x: torch.Tensor,
193
+ mask: Optional[torch.Tensor] = None,
194
+ pos: Optional[torch.Tensor] = None,
195
+ cache: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
196
+ layer_idx: Optional[int] = None,
197
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
198
+ gate: torch.Tensor | None = None,
199
+ time_step: int = 0,
200
+ ) -> torch.Tensor:
201
+ # Self-attention block
202
+ norm1_x = self.norm1(x)
203
+ if scale_shift is not None:
204
+ scale, shift = scale_shift
205
+ norm1_x = norm1_x * scale[:, None] + shift[:, None]
206
+
207
+ attn_out = self.attn(
208
+ norm1_x,
209
+ mask=mask,
210
+ pos=pos,
211
+ cache=cache,
212
+ layer_idx=layer_idx,
213
+ time_step=time_step,
214
+ )
215
+ x = x + attn_out
216
+
217
+ norm2_x = self.norm2(x)
218
+ if gate is not None:
219
+ norm2_x = gate[:, None] * norm2_x
220
+
221
+ # Feedforward block
222
+ ffn_out = self.ffn(norm2_x)
223
+ x = x + ffn_out
224
+ return x
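
A short sketch (not part of the commit) of TransformerBlock from this file: the head count is dim // head_size, and LocalSelfAttention internally pads the time axis to a multiple of its window length (32 by default), so arbitrary sequence lengths are accepted. Sizes below are arbitrary.

import torch
from codec.models.components.transformer import TransformerBlock

x = torch.randn(2, 50, 256)                      # (batch, time, dim)
block = TransformerBlock(dim=256, head_size=64)  # 4 heads, rotary positions, SwiGLU FFN
y = block(x)                                     # time is padded to a multiple of 32 internally
assert y.shape == x.shape
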
codec/models/pardi_tokenizer.py ADDED
@@ -0,0 +1,10 @@
1
+ import torch
+ from torch import nn
2
+
3
+ from zcodec.models import WavVAE, ZFlowAutoEncoder
4
+ from zcodec.models.wavvae.model import WavVAEConfig
5
+ from zcodec.models.zflowae.model import ZFlowAutoEncoderConfig
6
+
7
+
8
+ class PardiTokenizer(nn.Module):
9
+ def __init__(self, wavvae_cfg: WavVAEConfig, zflowae_cfg: ZFlowAutoEncoderConfig):
10
+     # minimal body (sketch) so the stub is importable: keep the two sub-codec configs
+     super().__init__()
+     self.wavvae_cfg = wavvae_cfg
+     self.zflowae_cfg = zflowae_cfg
codec/models/patchvae/__pycache__/model.cpython-312.pyc ADDED
Binary file (13 kB).
 
codec/models/patchvae/__pycache__/modules.cpython-312.pyc ADDED
Binary file (20.8 kB).
 
codec/models/patchvae/model.py ADDED
@@ -0,0 +1,262 @@
1
+ import json
2
+ import math
3
+ import os
4
+ import sys
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ from safetensors.torch import load_file
11
+ from torch import nn
12
+ from torchdyn.core import NeuralODE
13
+
14
+ from .modules import AdaLNFlowPredictor, AutoEncoder
15
+
16
+
17
+ @contextmanager
18
+ def suppress_stdout():
19
+ original_stdout = sys.stdout
20
+ try:
21
+ sys.stdout = open(os.devnull, "w")
22
+ yield
23
+ finally:
24
+ sys.stdout.close()
25
+ sys.stdout = original_stdout
26
+
27
+
28
+ def cosine_schedule_with_warmup(warmup_steps, total_steps, start_lr, end_lr):
29
+ def lr_lambda(step):
30
+ if step < warmup_steps:
31
+ return step / max(1, warmup_steps)
32
+ progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
33
+ cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
34
+ return (start_lr - end_lr) * cosine_decay / start_lr + end_lr / start_lr
35
+
36
+ return lr_lambda
37
+
38
+
39
+ @dataclass
40
+ class PatchVAEConfig:
41
+ latent_dim: int
42
+ hidden_dim: int
43
+ latent_scaling: tuple[list[float], list[float]] | None
44
+ flow_factory: str
45
+ num_flow_layers: int
46
+ autoencoder_factory: str
47
+ num_autoencoder_layers: int
48
+ convnextformer_num_conv_per_transformer: int = 3
49
+ wavvae_path: str | None = None
50
+ fsq_levels: list[int] | None = None
51
+ bottleneck_size: int | None = None
52
+ latent_stride: int = 2
53
+ vae: bool = False
54
+ causal_transformer: bool = False
55
+ cond_dim: int | None = None
56
+ is_causal: bool = False
57
+
58
+
59
+ class PatchVAE(nn.Module):
60
+ def __init__(self, cfg: PatchVAEConfig):
61
+ super().__init__()
62
+ self.flow_net = AdaLNFlowPredictor(
63
+ feat_dim=cfg.latent_dim * cfg.latent_stride,
64
+ dim=cfg.hidden_dim,
65
+ n_layer=cfg.num_flow_layers,
66
+ layer_factory=cfg.flow_factory,
67
+ cond_dim=cfg.cond_dim,
68
+ is_causal=cfg.is_causal,
69
+ )
70
+ self.autoencoder = AutoEncoder(
71
+ cfg.latent_dim * cfg.latent_stride,
72
+ cfg.hidden_dim,
73
+ cfg.num_autoencoder_layers,
74
+ cfg.autoencoder_factory,
75
+ out_dim=cfg.cond_dim,
76
+ vae=cfg.vae,
77
+ bottleneck_size=cfg.bottleneck_size,
78
+ convnextformer_num_conv_per_transformer=cfg.convnextformer_num_conv_per_transformer,
79
+ is_causal=cfg.is_causal,
80
+ )
81
+ if cfg.latent_scaling is not None:
82
+ mean, std = cfg.latent_scaling
83
+ self.register_buffer("mean_latent_scaling", torch.tensor(mean))
84
+ self.register_buffer("std_latent_scaling", torch.tensor(std))
85
+ else:
86
+ self.mean_latent_scaling = None
87
+ self.std_latent_scaling = None
88
+
89
+ self.latent_stride = cfg.latent_stride
90
+ self.latent_dim = cfg.latent_dim
91
+ self.wavvae = None
92
+
93
+ @classmethod
94
+ def from_pretrained(
95
+ cls,
96
+ pretrained_model_name_or_path: str,
97
+ map_location: str = "cpu",
98
+ ):
99
+ if Path(pretrained_model_name_or_path).exists():
100
+ path = pretrained_model_name_or_path
101
+ else:
102
+ from huggingface_hub import snapshot_download
103
+
104
+ path = snapshot_download(pretrained_model_name_or_path)
105
+
106
+ with open(Path(path) / "config.json", "r") as f:
107
+ config = json.load(f)
108
+ config = PatchVAEConfig(**config)
109
+ model = cls(config).to(map_location)
110
+ state_dict = load_file(
111
+ Path(path) / "model.st",
112
+ device=map_location,
113
+ )
114
+ model.load_state_dict(state_dict, assign=True)
115
+ if config.wavvae_path is not None:
116
+ from .. import WavVAE
117
+
118
+ model.wavvae = WavVAE.from_pretrained(config.wavvae_path).to(map_location)
119
+ else:
120
+ model.wavvae = None
121
+
122
+ return model
123
+
124
+ def wavvae_from_pretrained(
125
+ self,
126
+ pretrained_model_name_or_path: str,
127
+ *args,
128
+ **kwargs,
129
+ ):
130
+ from .. import WavVAE
131
+
132
+ self.wavvae = WavVAE.from_pretrained(
133
+ pretrained_model_name_or_path,
134
+ *args,
135
+ **kwargs,
136
+ )
137
+
138
+ def encode(self, wav: torch.Tensor):
139
+ assert self.wavvae is not None, (
140
+ "please provide WavVAE model to encode from waveform"
141
+ )
142
+ z = self.wavvae.encode(wav)
143
+ zz = self.encode_patch(z)
144
+ return zz
145
+
146
+ def decode(self, patchvae_latent: torch.Tensor, **kwargs):
147
+ assert self.wavvae is not None, (
148
+ "please provide WavVAE model to decode to waveform"
149
+ )
150
+ z = self.decode_patch(patchvae_latent, **kwargs)
151
+ wav = self.wavvae.decode(z)
152
+ return wav
153
+
154
+ def normalize_z(self, z: torch.Tensor):
155
+ if self.mean_latent_scaling is not None:
156
+ z = (z - self.mean_latent_scaling) / self.std_latent_scaling
157
+ return z
158
+
159
+ def denormalize_z(self, z: torch.Tensor):
160
+ if self.std_latent_scaling is not None:
161
+ z = z * self.std_latent_scaling + self.mean_latent_scaling
162
+ return z
163
+
164
+ def encode_patch(self, z: torch.Tensor, deterministic: bool = False):
165
+ B, T, D = z.shape
166
+ z = self.normalize_z(z)
167
+ if self.latent_stride > 1:
168
+ z = z[:, : T - T % self.latent_stride]
169
+ z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
170
+ return self.autoencoder.encode(z, deterministic=deterministic)
171
+
172
+ def decode_patch(
173
+ self,
174
+ latent: torch.Tensor,
175
+ cfg: float = 2.0,
176
+ num_steps: int = 15,
177
+ solver: str = "euler",
178
+ sensitivity: str = "adjoint",
179
+ temperature: float = 1.0,
180
+ **kwargs,
181
+ ):
182
+ with torch.no_grad():
183
+ z_cond = self.autoencoder.decode(latent).transpose(1, 2)
184
+ if cfg == 1.0:
185
+
186
+ def solver_fn(t, Xt, *args, **kwargs):
187
+ flow = self.flow_net(Xt, z_cond, t.unsqueeze(0))
188
+ return flow
189
+ else:
190
+ z_cond_uncond = torch.cat((z_cond, torch.zeros_like(z_cond)), dim=0)
191
+
192
+ def solver_fn(t, Xt, *args, **kwargs):
193
+ flow = self.flow_net(
194
+ Xt.repeat(2, 1, 1), z_cond_uncond, t.unsqueeze(0)
195
+ )
196
+ cond, uncond = flow.chunk(2, dim=0)
197
+
198
+ return uncond + cfg * (cond - uncond)
199
+
200
+ with suppress_stdout():
201
+ node_ = NeuralODE(
202
+ solver_fn,
203
+ solver=solver,
204
+ sensitivity=sensitivity,
205
+ **kwargs,
206
+ )
207
+ t_span = torch.linspace(0, 1, num_steps + 1, device=z_cond.device)
208
+ patch_dim = self.latent_dim * self.latent_stride
209
+ x0 = torch.randn(
210
+ z_cond.shape[0],
211
+ patch_dim,
212
+ z_cond.shape[2],
213
+ device=z_cond.device,
214
+ )
215
+ traj = node_.trajectory(
216
+ x0 * temperature,
217
+ t_span=t_span,
218
+ )
219
+
220
+ y_hat = traj[-1]
221
+ y_hat = y_hat.transpose(1, 2)
222
+ B, T, D = y_hat.shape
223
+ y_hat = y_hat.reshape(B, T * self.latent_stride, D // self.latent_stride)
224
+ y_hat = self.denormalize_z(y_hat)
225
+ return y_hat
226
+
227
+ def forward(
228
+ self,
229
+ z: torch.Tensor,
230
+ t: torch.Tensor,
231
+ drop_cond_rate: float = 0.0,
232
+ drop_vae_rate: float = 0.0,
233
+ sigma: float = 1e-4,
234
+ ):
235
+ z = self.normalize_z(z)
236
+ B, T, D = z.shape
237
+ if self.latent_stride > 1:
238
+ z = z.reshape(B, T // self.latent_stride, D * self.latent_stride)
239
+
240
+ prior, ae_loss = self.autoencoder(z, drop_vae_rate=drop_vae_rate)
241
+
242
+ if drop_cond_rate > 0.0:
243
+ to_drop = torch.rand(prior.shape[0], device=prior.device) < drop_cond_rate
244
+ prior[to_drop] = 0.0
245
+
246
+ x0 = torch.randn_like(z)
247
+ x1 = z
248
+
249
+ flow_target = x1 - (1 - sigma) * x0
250
+
251
+ alpha = (1 - (1 - sigma) * t).view(-1, 1, 1)
252
+ xt = alpha * x0 + t.view(-1, 1, 1) * x1
253
+
254
+ pred = self.flow_net(
255
+ xt.transpose(1, 2),
256
+ prior.transpose(1, 2),
257
+ t,
258
+ )
259
+
260
+ flow_loss = nn.functional.mse_loss(flow_target.transpose(1, 2), pred)
261
+
262
+ return flow_loss, ae_loss, prior
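
A hedged sketch (not part of the commit) of how PatchVAE might be loaded and used to round-trip WavVAE latents; "path/to/patchvae" is a placeholder for a directory (or Hub repo id) containing config.json and model.st, and the latent shapes depend on the checkpoint's config.

import torch
from codec.models import PatchVAE

model = PatchVAE.from_pretrained("path/to/patchvae", map_location="cpu")  # placeholder path
z = torch.randn(1, 64, model.latent_dim)           # (B, T, D) latents, e.g. from WavVAE.encode
latent = model.encode_patch(z, deterministic=True)
z_hat = model.decode_patch(latent, cfg=2.0, num_steps=15)  # flow-matching sampling via NeuralODE
print(z_hat.shape)                                 # back to (B, T, D), up to latent_stride rounding
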
codec/models/patchvae/modules.py ADDED
@@ -0,0 +1,396 @@
1
+ from random import random
2
+ from typing import Literal
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from vector_quantize_pytorch import FSQ
8
+
9
+ from zcodec.models.components.transformer import TransformerBlock
10
+
11
+
12
+ class AdaLayerNormScale(nn.Module):
13
+ def __init__(self, dim: int):
14
+ super().__init__()
15
+ self.linear = nn.Linear(dim, dim * 3)
16
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
17
+
18
+ def forward(self, x, c):
19
+ x = self.norm(x)
20
+ scale, bias, gate = self.linear(F.silu(c)).chunk(3, dim=1)
21
+ shape = [x.shape[0]] + [1] * (x.dim() - 2) + [x.shape[-1]]
22
+ scale, bias, gate = map(lambda x: x.view(*shape), (scale, bias, gate))
23
+ x = x * (1 + scale) + bias
24
+ return x, gate
25
+
26
+
27
+ class GaussianFourierTimeEmbedding(nn.Module):
28
+ def __init__(self, dim: int):
29
+ super().__init__()
30
+ self.weight = nn.Parameter(torch.randn(dim), requires_grad=False)
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+ x = x[:, None] * self.weight[None, :] * 2 * torch.pi
34
+ x = torch.cat((torch.sin(x), torch.cos(x)), dim=1)
35
+ return x
36
+
37
+
38
+ LAYER_FACTORIES = {}
39
+
40
+
41
+ def register_flow_layer_factory(name):
42
+ def decorator(fn):
43
+ LAYER_FACTORIES[name] = fn
44
+ return fn
45
+
46
+ return decorator
47
+
48
+
49
+ @register_flow_layer_factory("convnext")
50
+ def SimpleConvNextFactory(dim: int, i: int, n_layer: int, is_causal: bool = False):
51
+ return ConvNeXtBlock(dim, elementwise_affine_ln=False, is_causal=is_causal)
52
+
53
+
54
+ @register_flow_layer_factory("mlp")
55
+ def MLP(dim: int, i: int, n_layer: int, is_causal: bool = False):
56
+ return AdaLNMLP(dim)
57
+
58
+
59
+ @register_flow_layer_factory("sa_transformer")
60
+ def SelfAttentionTransformer(dim: int, i: int, n_layer: int, is_causal: bool = False):
61
+ return TransformerBlock(dim, 64, elementwise_affine_ln=False, is_causal=is_causal)
62
+
63
+
64
+ def init_weights(m: nn.Module):
65
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
66
+ nn.init.trunc_normal_(m.weight, std=0.02)
67
+ nn.init.constant_(m.bias, 0)
68
+
69
+
70
+ def init_adaln_weights(m: nn.Module):
71
+ nn.init.trunc_normal_(m.weight, std=0.02)
72
+ nn.init.zeros_(m.bias)
73
+
74
+
75
+ def modulate(x, scale, shift):
76
+ return x * (1 + scale[:, None]) + shift[:, None]
77
+
78
+
79
+ class AdaLNFlowPredictor(nn.Module):
80
+ def __init__(
81
+ self,
82
+ feat_dim: int,
83
+ dim: int,
84
+ n_layer: int,
85
+ layer_factory: str,
86
+ cond_dim: int | None = None,
87
+ is_causal: bool = False,
88
+ ):
89
+ super().__init__()
90
+
91
+ layer_factory = LAYER_FACTORIES[layer_factory]
92
+ self.layers = nn.ModuleList(
93
+ [
94
+ layer_factory(dim, i, n_layer, is_causal=is_causal)
95
+ for i in range(n_layer)
96
+ ]
97
+ )
98
+ if cond_dim is None:
99
+ cond_dim = feat_dim
100
+ self.initial_proj = nn.Linear(feat_dim + cond_dim, dim)
101
+ self.adaln_proj = nn.ModuleList([nn.Linear(dim, dim * 3) for _ in self.layers])
102
+ self.final_adaln_proj = nn.Linear(dim, dim * 2)
103
+ self.out_proj = nn.Linear(dim, feat_dim)
104
+ self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
105
+ self.time_emb = GaussianFourierTimeEmbedding(dim // 2)
106
+
107
+ self.apply(init_weights)
108
+ for l in self.adaln_proj:
109
+ init_adaln_weights(l)
110
+ init_adaln_weights(self.final_adaln_proj)
111
+
112
+ def forward(
113
+ self,
114
+ x_t: torch.Tensor,
115
+ x_mu: torch.Tensor,
116
+ t: torch.Tensor,
117
+ ):
118
+ x_t, x_mu = map(lambda x: x.transpose(1, 2), (x_t, x_mu))
119
+ x = self.initial_proj(torch.cat((x_t, x_mu), dim=-1)).transpose(1, 2)
120
+
121
+ t_emb = self.time_emb(t)
122
+
123
+ for i, (l, adaln) in enumerate(zip(self.layers, self.adaln_proj)):
124
+ scale, shift, gate = F.silu(adaln(t_emb)).chunk(3, dim=1)
125
+ x = l(x, scale_shift=(scale, shift), gate=gate)
126
+
127
+ scale, shift = F.silu(self.final_adaln_proj(t_emb)).chunk(2, dim=1)
128
+ x = self.final_norm(x.transpose(1, 2))
129
+ x = modulate(x, scale, shift)
130
+
131
+ x = self.out_proj(x).transpose(1, 2)
132
+
133
+ return x
134
+
135
+
136
+ class AdaLNMLP(nn.Module):
137
+ def __init__(self, hidden_dim):
138
+ super().__init__()
139
+ self.hidden_dim = hidden_dim
140
+
141
+ self.in_ln = nn.LayerNorm(hidden_dim, eps=1e-6, elementwise_affine=False)
142
+ self.mlp = nn.Sequential(
143
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
144
+ nn.SiLU(),
145
+ nn.Linear(hidden_dim, hidden_dim, bias=True),
146
+ )
147
+
148
+ self.adaLN_modulation = nn.Sequential(
149
+ nn.SiLU(), nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
150
+ )
151
+
152
+ def forward(self, x, scale_shift, gate):
153
+ x = x.transpose(-1, -2)
154
+ h = modulate(self.in_ln(x), *scale_shift)
155
+ h = self.mlp(h)
156
+ return (x + gate[:, None] * h).transpose(-1, -2)
157
+
158
+
159
+ class ConvNeXtBlock(nn.Module):
160
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
161
+
162
+ Args:
163
+ dim (int): Number of input channels.
164
+ intermediate_dim (int): Dimensionality of the intermediate layer.
165
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
166
+ Defaults to None.
167
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
168
+ None means non-conditional LayerNorm. Defaults to None.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ dim: int,
174
+ intermediate_dim: int | None = None,
175
+ layer_scale_init_value: float = 0.0,
176
+ elementwise_affine_ln: bool = True,
177
+ is_causal: bool = False,
178
+ ):
179
+ super().__init__()
180
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
181
+ self.dwconv = nn.Conv1d(
182
+ dim, dim, kernel_size=7, padding=0 if is_causal else 3, groups=dim
183
+ ) # depthwise conv
184
+ self.norm = nn.LayerNorm(
185
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
186
+ )
187
+ self.pwconv1 = nn.Linear(
188
+ dim, intermediate_dim
189
+ ) # pointwise/1x1 convs, implemented with linear layers
190
+ self.act = nn.GELU()
191
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
192
+ self.gamma = (
193
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
194
+ if layer_scale_init_value > 0
195
+ else None
196
+ )
197
+ self.is_causal = is_causal
198
+
199
+ def forward(
200
+ self,
201
+ x: torch.Tensor,
202
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
203
+ gate: torch.Tensor | None = None,
204
+ ) -> torch.Tensor:
205
+ residual = x
206
+ if self.is_causal:
207
+ x = torch.nn.functional.pad(x, (6, 0))
208
+ x = self.dwconv(x)
209
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
210
+ x = self.norm(x)
211
+ if scale_shift is not None:
212
+ scale, shift = scale_shift
213
+ x = x * scale[:, None] + shift[:, None]
214
+ x = self.pwconv1(x)
215
+ x = self.act(x)
216
+ x = self.pwconv2(x)
217
+ if self.gamma is not None:
218
+ x = self.gamma * x
219
+ if gate is not None:
220
+ x = gate[:, None] * x
221
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
222
+
223
+ x = residual + x
224
+ return x
225
+
226
+
227
+ class ConvNextNet(nn.Module):
228
+ def __init__(
229
+ self,
230
+ dim: int,
231
+ n_layers: int,
232
+ intermediate_dim: int | None = None,
233
+ is_causal: bool = False,
234
+ ):
235
+ super().__init__()
236
+ self.net = nn.Sequential(
237
+ *[
238
+ ConvNeXtBlock(dim, intermediate_dim, is_causal=is_causal)
239
+ for _ in range(n_layers)
240
+ ]
241
+ )
242
+
243
+ def forward(self, x):
244
+ return self.net(x.transpose(1, 2)).transpose(1, 2)
245
+
246
+
247
+ def convnext_factory(dim, n_layers, is_causal=False):
248
+ return ConvNextNet(dim, n_layers, is_causal=is_causal)
249
+
250
+
251
+ def convnextformer_factory(
252
+ dim, n_layers, n_convnext_per_transformer_block, is_causal=False
253
+ ):
254
+ layers = []
255
+ for i in range(0, n_layers, n_convnext_per_transformer_block + 1):
256
+ layers.append(
257
+ ConvNextNet(dim, n_convnext_per_transformer_block, is_causal=is_causal)
258
+ )
259
+ layers.append(TransformerBlock(dim, 64, is_causal=is_causal))
260
+ return nn.Sequential(*layers)
261
+
262
+
263
+ class AutoEncoder(nn.Module):
264
+ def __init__(
265
+ self,
266
+ feat_dim: int,
267
+ hidden_dim: int,
268
+ num_layers: int,
269
+ net_factory: Literal["convnext", "convnextformer_decoder", "convnextformer"],
270
+ out_dim: int | None = None,
271
+ convnextformer_num_conv_per_transformer: int = 3,
272
+ causal_transformer: bool = False,
273
+ bottleneck_size: int | None = None,
274
+ vae: bool = False,
275
+ is_causal: bool = False,
276
+ ):
277
+ super().__init__()
278
+
279
+ self.embed = nn.Linear(feat_dim, hidden_dim)
280
+ if out_dim is None:
281
+ out_dim = feat_dim
282
+ self.unembed = nn.Linear(hidden_dim, out_dim)
283
+
284
+ if net_factory == "convnext":
285
+ self.encoder_net = convnext_factory(
286
+ hidden_dim, num_layers, is_causal=is_causal
287
+ )
288
+ self.decoder_net = convnext_factory(
289
+ hidden_dim, num_layers, is_causal=is_causal
290
+ )
291
+ elif net_factory == "convnextformer_decoder":
292
+ self.encoder_net = convnext_factory(
293
+ hidden_dim, num_layers, is_causal=is_causal
294
+ )
295
+ self.decoder_net = convnextformer_factory(
296
+ hidden_dim,
297
+ num_layers,
298
+ convnextformer_num_conv_per_transformer,
299
+ is_causal=is_causal,
300
+ )
301
+ elif net_factory == "convnextformer":
302
+ self.encoder_net = convnextformer_factory(
303
+ hidden_dim,
304
+ num_layers,
305
+ convnextformer_num_conv_per_transformer,
306
+ is_causal=is_causal,
307
+ )
308
+ self.decoder_net = convnextformer_factory(
309
+ hidden_dim,
310
+ num_layers,
311
+ convnextformer_num_conv_per_transformer,
312
+ is_causal=is_causal,
313
+ )
314
+
315
+ self.bottleneck = (
316
+ nn.Linear(hidden_dim, bottleneck_size * (1 + vae))
317
+ if bottleneck_size is not None
318
+ else nn.Identity()
319
+ )
320
+ self.unbottleneck = (
321
+ nn.Linear(bottleneck_size, hidden_dim)
322
+ if bottleneck_size is not None
323
+ else nn.Identity()
324
+ )
325
+ self.vae = vae
326
+
327
+ def reparameterize(
328
+ self,
329
+ mu: torch.Tensor,
330
+ logvar: torch.Tensor,
331
+ deterministic: bool = False,
332
+ drop_vae_rate: float = 0.0,
333
+ ) -> torch.Tensor:
334
+ logvar = torch.clamp(logvar, -30.0, 20.0)
335
+ std = torch.exp(0.5 * logvar)
336
+ if drop_vae_rate > 0.0:
337
+ to_drop = torch.rand(std.shape[0], device=std.device) < drop_vae_rate
338
+ eps = torch.randn_like(std)
339
+ eps[to_drop] = 0.0
340
+ else:
341
+ if deterministic:
342
+ eps = torch.zeros_like(std)
343
+ else:
344
+ eps = torch.randn_like(std)
345
+ return mu + eps * std
346
+
347
+ def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
348
+ kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
349
+ return kl.sum(dim=-1).mean()
350
+
351
+ def forward(self, x: torch.Tensor, drop_vae_rate: float = 0.0) -> torch.Tensor:
352
+ # Encode
353
+ x = self.embed(x)
354
+ x = self.encoder_net(x)
355
+ x = self.bottleneck(x)
356
+ if self.vae:
357
+ mu, logvar = x.chunk(2, dim=-1)
358
+ loss = {
359
+ "kl_div": self.kl_divergence(mu, logvar),
360
+ "_mu_mean": mu.mean(),
361
+ "_mu_std": mu.std(),
362
+ "_logvar_mean": logvar.mean(),
363
+ "_logvar_std": logvar.std(),
364
+ }
365
+ x = self.reparameterize(
366
+ mu,
367
+ logvar,
368
+ drop_vae_rate=drop_vae_rate,
369
+ )
370
+ else:
371
+ loss = {}
372
+
373
+ # Decode
374
+ x = self.unbottleneck(x)
375
+ x = self.decoder_net(x)
376
+ x = self.unembed(x)
377
+
378
+ return x, loss
379
+
380
+ def encode(self, x: torch.Tensor, deterministic: bool = False):
381
+ x = self.embed(x)
382
+ x = self.encoder_net(x)
383
+ x = self.bottleneck(x)
384
+
385
+ if self.vae:
386
+ x = self.reparameterize(*x.chunk(2, dim=-1), deterministic=deterministic)
387
+ return x
388
+
389
+ def decode(
390
+ self,
391
+ latent: torch.Tensor | None = None,
392
+ ):
393
+ x = self.unbottleneck(latent)
394
+ x = self.decoder_net(x)
395
+ x = self.unembed(x)
396
+ return x
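
A sketch (not part of the commit) exercising the AutoEncoder's VAE bottleneck in isolation, since PatchVAE only drives it indirectly; all dimensions below are made up.

import torch
from codec.models.patchvae.modules import AutoEncoder

ae = AutoEncoder(
    feat_dim=128,               # e.g. latent_dim * latent_stride in PatchVAE
    hidden_dim=256,
    num_layers=4,
    net_factory="convnext",
    bottleneck_size=32,
    vae=True,
)
x = torch.randn(2, 50, 128)     # (B, T, feat_dim)
recon, losses = ae(x)           # losses contains "kl_div" plus mu/logvar statistics
latent = ae.encode(x, deterministic=True)   # (B, T, 32) bottleneck codes
print(recon.shape, latent.shape, float(losses["kl_div"]))
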
codec/models/wavvae/__init__.py ADDED
File without changes
codec/models/wavvae/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (170 Bytes).
 
codec/models/wavvae/__pycache__/discriminators.cpython-312.pyc ADDED
Binary file (13.1 kB).
 
codec/models/wavvae/__pycache__/heads.cpython-312.pyc ADDED
Binary file (10.4 kB).
 
codec/models/wavvae/__pycache__/layers.cpython-312.pyc ADDED
Binary file (13.1 kB).
 
codec/models/wavvae/__pycache__/loss.cpython-312.pyc ADDED
Binary file (7.01 kB).
 
codec/models/wavvae/__pycache__/model.cpython-312.pyc ADDED
Binary file (8.79 kB).
 
codec/models/wavvae/__pycache__/modules.cpython-312.pyc ADDED
Binary file (11.1 kB).
 
codec/models/wavvae/__pycache__/spectral_ops.cpython-312.pyc ADDED
Binary file (12.4 kB).
 
codec/models/wavvae/dataset.py ADDED
@@ -0,0 +1,84 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from pytorch_lightning import LightningDataModule
7
+ from torch.utils.data import DataLoader, Dataset
8
+
9
+ torch.set_num_threads(1)
10
+
11
+
12
+ @dataclass
13
+ class WavVAEDataConfig:
14
+ filelist_path: str
15
+ sampling_rate: int
16
+ num_samples: int
17
+ batch_size: int
18
+ num_workers: int
19
+
20
+
21
+ class WavVAEDataModule(LightningDataModule):
22
+ def __init__(self, train_params: WavVAEDataConfig, val_params: WavVAEDataConfig):
23
+ super().__init__()
24
+ self.train_config = train_params
25
+ self.val_config = val_params
26
+
27
+ def _get_dataloder(self, cfg: WavVAEDataConfig, train: bool):
28
+ dataset = WavVAEDataset(cfg, train=train)
29
+ dataloader = DataLoader(
30
+ dataset,
31
+ batch_size=cfg.batch_size,
32
+ num_workers=cfg.num_workers,
33
+ shuffle=train,
34
+ pin_memory=True,
35
+ )
36
+ return dataloader
37
+
38
+ def train_dataloader(self) -> DataLoader:
39
+ return self._get_dataloder(self.train_config, train=True)
40
+
41
+ def val_dataloader(self) -> DataLoader:
42
+ return self._get_dataloder(self.val_config, train=False)
43
+
44
+
45
+ class WavVAEDataset(Dataset):
46
+ def __init__(self, cfg: WavVAEDataConfig, train: bool):
47
+ with open(cfg.filelist_path) as f:
48
+ self.filelist = f.read().splitlines()
49
+ self.sampling_rate = cfg.sampling_rate
50
+ self.num_samples = cfg.num_samples
51
+ self.train = train
52
+
53
+ def __len__(self) -> int:
54
+ return len(self.filelist)
55
+
56
+ def __getitem__(self, index: int) -> torch.Tensor:
57
+ audio_path = self.filelist[index]
58
+ y, sr = torchaudio.load(audio_path)
59
+ if y.size(0) > 1:
60
+ # mix to mono
61
+ y = y.mean(dim=0, keepdim=True)
62
+ gain = np.random.uniform(-1, -6) if self.train else -3
63
+ y, _ = torchaudio.sox_effects.apply_effects_tensor(
64
+ y, sr, [["norm", f"{gain:.2f}"]]
65
+ )
66
+ try:
67
+ if sr != self.sampling_rate:
68
+ y = torchaudio.functional.resample(
69
+ y, orig_freq=sr, new_freq=self.sampling_rate
70
+ )
71
+ except Exception:
72
+ print(audio_path, y.shape)
73
+ if y.size(-1) < self.num_samples:
74
+ pad_length = self.num_samples - y.size(-1)
75
+ padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
76
+ y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
77
+ elif self.train:
78
+ start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
79
+ y = y[:, start : start + self.num_samples]
80
+ else:
81
+ # During validation, take always the first segment for determinism
82
+ y = y[:, : self.num_samples]
83
+
84
+ return y[0]
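
For reference, a tiny sketch (not part of the commit) of the sox "norm" gain step this dataset applies in __getitem__, run on a synthetic waveform rather than a filelist entry; the sample rate is an assumption.

import torch
import torchaudio

sr = 24000                                   # assumed sample rate
y = 0.1 * torch.randn(1, sr)                 # 1 s of noise, shape (channels, samples)
y_norm, _ = torchaudio.sox_effects.apply_effects_tensor(
    y, sr, [["norm", "-3.00"]]               # peak-normalize to -3 dBFS (the validation setting)
)
print(y.abs().max().item(), y_norm.abs().max().item())
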
codec/models/wavvae/discriminators.py ADDED
@@ -0,0 +1,211 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import Conv2d
7
+ from torch.nn.utils import weight_norm
8
+ from torchaudio.transforms import Spectrogram
9
+
10
+
11
+ class MultiPeriodDiscriminator(nn.Module):
12
+ """
13
+ Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
14
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
15
+
16
+ Args:
17
+ periods (tuple[int]): Tuple of periods for each discriminator.
18
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
19
+ Defaults to None.
20
+ """
21
+
22
+ def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
23
+ super().__init__()
24
+ self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
25
+
26
+ def forward(
27
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
28
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
29
+ y_d_rs = []
30
+ y_d_gs = []
31
+ fmap_rs = []
32
+ fmap_gs = []
33
+ for d in self.discriminators:
34
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
35
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
36
+ y_d_rs.append(y_d_r)
37
+ fmap_rs.append(fmap_r)
38
+ y_d_gs.append(y_d_g)
39
+ fmap_gs.append(fmap_g)
40
+
41
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
42
+
43
+
44
+ class DiscriminatorP(nn.Module):
45
+ def __init__(
46
+ self,
47
+ period: int,
48
+ in_channels: int = 1,
49
+ kernel_size: int = 5,
50
+ stride: int = 3,
51
+ lrelu_slope: float = 0.1,
52
+ num_embeddings: Optional[int] = None,
53
+ ):
54
+ super().__init__()
55
+ self.period = period
56
+ self.convs = nn.ModuleList(
57
+ [
58
+ weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
59
+ weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
60
+ weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
61
+ weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
62
+ weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
63
+ ]
64
+ )
65
+ if num_embeddings is not None:
66
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
67
+ torch.nn.init.zeros_(self.emb.weight)
68
+
69
+ self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
70
+ self.lrelu_slope = lrelu_slope
71
+
72
+ def forward(
73
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
74
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
75
+ x = x.unsqueeze(1)
76
+ fmap = []
77
+ # 1d to 2d
78
+ b, c, t = x.shape
79
+ if t % self.period != 0: # pad first
80
+ n_pad = self.period - (t % self.period)
81
+ x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
82
+ t = t + n_pad
83
+ x = x.view(b, c, t // self.period, self.period)
84
+
85
+ for i, l in enumerate(self.convs):
86
+ x = l(x)
87
+ x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
88
+ if i > 0:
89
+ fmap.append(x)
90
+ if cond_embedding_id is not None:
91
+ emb = self.emb(cond_embedding_id)
92
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
93
+ else:
94
+ h = 0
95
+ x = self.conv_post(x)
96
+ fmap.append(x)
97
+ x += h
98
+ x = torch.flatten(x, 1, -1)
99
+
100
+ return x, fmap
101
+
102
+
103
+ class MultiResolutionDiscriminator(nn.Module):
104
+ def __init__(
105
+ self,
106
+ fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
107
+ num_embeddings: Optional[int] = None,
108
+ ):
109
+ """
110
+ Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
111
+ Additionally, it allows incorporating conditional information with a learned embeddings table.
112
+
113
+ Args:
114
+ fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
115
+ num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
116
+ Defaults to None.
117
+ """
118
+
119
+ super().__init__()
120
+ self.discriminators = nn.ModuleList(
121
+ [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
122
+ )
123
+
124
+ def forward(
125
+ self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
126
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
127
+ y_d_rs = []
128
+ y_d_gs = []
129
+ fmap_rs = []
130
+ fmap_gs = []
131
+
132
+ for d in self.discriminators:
133
+ y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
134
+ y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
135
+ y_d_rs.append(y_d_r)
136
+ fmap_rs.append(fmap_r)
137
+ y_d_gs.append(y_d_g)
138
+ fmap_gs.append(fmap_g)
139
+
140
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
141
+
142
+
143
+ class DiscriminatorR(nn.Module):
144
+ def __init__(
145
+ self,
146
+ window_length: int,
147
+ num_embeddings: Optional[int] = None,
148
+ channels: int = 32,
149
+ hop_factor: float = 0.25,
150
+ bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
151
+ ):
152
+ super().__init__()
153
+ self.window_length = window_length
154
+ self.hop_factor = hop_factor
155
+ self.spec_fn = Spectrogram(
156
+ n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
157
+ )
158
+ n_fft = window_length // 2 + 1
159
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
160
+ self.bands = bands
161
+ convs = lambda: nn.ModuleList(
162
+ [
163
+ weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
164
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
165
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
166
+ weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
167
+ weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
168
+ ]
169
+ )
170
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
171
+
172
+ if num_embeddings is not None:
173
+ self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
174
+ torch.nn.init.zeros_(self.emb.weight)
175
+
176
+ self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
177
+
178
+ def spectrogram(self, x):
179
+ # Remove DC offset
180
+ x = x - x.mean(dim=-1, keepdims=True)
181
+ # Peak normalize the volume of input audio
182
+ x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
183
+ x = self.spec_fn(x)
184
+ x = torch.view_as_real(x)
185
+ x = rearrange(x, "b f t c -> b c t f")
186
+ # Split into bands
187
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
188
+ return x_bands
189
+
190
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None):
191
+ x_bands = self.spectrogram(x)
192
+ fmap = []
193
+ x = []
194
+ for band, stack in zip(x_bands, self.band_convs):
195
+ for i, layer in enumerate(stack):
196
+ band = layer(band)
197
+ band = torch.nn.functional.leaky_relu(band, 0.1)
198
+ if i > 0:
199
+ fmap.append(band)
200
+ x.append(band)
201
+ x = torch.cat(x, dim=-1)
202
+ if cond_embedding_id is not None:
203
+ emb = self.emb(cond_embedding_id)
204
+ h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
205
+ else:
206
+ h = 0
207
+ x = self.conv_post(x)
208
+ fmap.append(x)
209
+ x += h
210
+
211
+ return x, fmap
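Usage sketch: the discriminators above are typically driven with hinge-style losses, as in HiFi-GAN / Vocos. A minimal sketch follows, assuming the module is importable at the path shown; shapes are illustrative, and the actual training step lives in the training modules of this commit, not here.

import torch
from codec.models.wavvae.discriminators import MultiPeriodDiscriminator  # assumed import path

mpd = MultiPeriodDiscriminator()      # MultiResolutionDiscriminator is driven the same way
y = torch.randn(2, 24000)             # real waveforms (B, T)
y_hat = torch.randn(2, 24000)         # generated waveforms (B, T)

# Discriminator step (hinge loss): push real logits above +1, fakes below -1.
real_out, fake_out, _, _ = mpd(y=y, y_hat=y_hat.detach())
loss_d = sum(
    torch.mean(torch.clamp(1 - dr, min=0)) + torch.mean(torch.clamp(1 + dg, min=0))
    for dr, dg in zip(real_out, fake_out)
)

# Generator step: adversarial term plus feature matching on the returned feature maps.
_, fake_out, fmap_r, fmap_g = mpd(y=y, y_hat=y_hat)
loss_g = sum(torch.mean(torch.clamp(1 - dg, min=0)) for dg in fake_out)
loss_fm = sum(torch.mean(torch.abs(r - g)) for fr, fg in zip(fmap_r, fmap_g) for r, g in zip(fr, fg))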
codec/models/wavvae/experiment.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+
3
+
codec/models/wavvae/heads.py ADDED
@@ -0,0 +1,194 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
7
+
8
+ from .modules import symexp
9
+ from .spectral_ops import IMDCT, ISTFT
10
+
11
+
12
+ class FourierHead(nn.Module):
13
+ """Base class for inverse fourier modules."""
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ Args:
18
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
19
+ L is the sequence length, and H denotes the model dimension.
20
+
21
+ Returns:
22
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
23
+ """
24
+ raise NotImplementedError("Subclasses must implement the forward method.")
25
+
26
+
27
+ class LinearNoBiasHead(FourierHead):
28
+ def __init__(self, dim: int, hop_length: int, n_fft: int):
29
+ super().__init__()
30
+ self.pre_head = nn.Linear(dim, n_fft + 2)
31
+ self.head = nn.Linear(n_fft + 2, hop_length, bias=False)
32
+
33
+ def forward(self, x):
34
+ y = self.pre_head(x)
35
+ y = self.head(y).clamp(min=-1.0, max=1.0)
36
+ B, _, _ = y.shape
37
+ return y.reshape(B, -1)
38
+
39
+
40
+ class ISTFTHead(FourierHead):
41
+ """
42
+ ISTFT Head module for predicting STFT complex coefficients.
43
+
44
+ Args:
45
+ dim (int): Hidden dimension of the model.
46
+ n_fft (int): Size of Fourier transform.
47
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
48
+ the resolution of the input features.
49
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
50
+ """
51
+
52
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
53
+ super().__init__()
54
+ out_dim = n_fft + 2
55
+ self.out = torch.nn.Linear(dim, out_dim)
56
+ self.hop_length = hop_length
57
+ self.istft = ISTFT(
58
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
59
+ )
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ """
63
+ Forward pass of the ISTFTHead module.
64
+
65
+ Args:
66
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
67
+ L is the sequence length, and H denotes the model dimension.
68
+
69
+ Returns:
70
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
71
+ """
72
+ x = self.out(x).transpose(1, 2)
73
+ mag, p = x.chunk(2, dim=1)
74
+ mag = torch.exp(mag)
75
+ mag = torch.clip(
76
+ mag, max=1e2
77
+ ) # safeguard to prevent excessively large magnitudes
78
+ # wrapping happens here. These two lines produce real and imaginary value
79
+ x = torch.cos(p)
80
+ y = torch.sin(p)
81
+ # recalculating phase here does not produce anything new
82
+ # only costs time
83
+ # phase = torch.atan2(y, x)
84
+ # S = mag * torch.exp(phase * 1j)
85
+ # better directly produce the complex value
86
+ S = mag * (x + 1j * y)
87
+ audio = self.istft(S)
88
+ audio = nn.functional.pad(audio, (self.hop_length // 2, self.hop_length // 2))
89
+ return audio
90
+
91
+
92
+ class IMDCTSymExpHead(FourierHead):
93
+ """
94
+ IMDCT Head module for predicting MDCT coefficients with a symmetric exponential (symexp) activation.
95
+
96
+ Args:
97
+ dim (int): Hidden dimension of the model.
98
+ mdct_frame_len (int): Length of the MDCT frame.
99
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
100
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
101
+ based on perceptual scaling. Defaults to None.
102
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ dim: int,
108
+ mdct_frame_len: int,
109
+ padding: str = "same",
110
+ sample_rate: Optional[int] = None,
111
+ clip_audio: bool = False,
112
+ ):
113
+ super().__init__()
114
+ out_dim = mdct_frame_len // 2
115
+ self.out = nn.Linear(dim, out_dim)
116
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
117
+ self.clip_audio = clip_audio
118
+
119
+ if sample_rate is not None:
120
+ # optionally init the last layer following mel-scale
121
+ m_max = _hz_to_mel(sample_rate // 2)
122
+ m_pts = torch.linspace(0, m_max, out_dim)
123
+ f_pts = _mel_to_hz(m_pts)
124
+ scale = 1 - (f_pts / f_pts.max())
125
+
126
+ with torch.no_grad():
127
+ self.out.weight.mul_(scale.view(-1, 1))
128
+
129
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
130
+ """
131
+ Forward pass of the IMDCTSymExpHead module.
132
+
133
+ Args:
134
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
135
+ L is the sequence length, and H denotes the model dimension.
136
+
137
+ Returns:
138
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
139
+ """
140
+ x = self.out(x)
141
+ x = symexp(x)
142
+ x = torch.clip(
143
+ x, min=-1e2, max=1e2
144
+ ) # safeguard to prevent excessively large magnitudes
145
+ audio = self.imdct(x)
146
+ if self.clip_audio:
147
+ audio = torch.clip(x, min=-1.0, max=1.0)
148
+
149
+ return audio
150
+
151
+
152
+ class IMDCTCosHead(FourierHead):
153
+ """
154
+ IMDCT Head module for predicting MDCT coefficients, parametrized as MDCT = exp(m) · cos(p)
155
+
156
+ Args:
157
+ dim (int): Hidden dimension of the model.
158
+ mdct_frame_len (int): Length of the MDCT frame.
159
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
160
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ dim: int,
166
+ mdct_frame_len: int,
167
+ padding: str = "same",
168
+ clip_audio: bool = False,
169
+ ):
170
+ super().__init__()
171
+ self.clip_audio = clip_audio
172
+ self.out = nn.Linear(dim, mdct_frame_len)
173
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
174
+
175
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
176
+ """
177
+ Forward pass of the IMDCTCosHead module.
178
+
179
+ Args:
180
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
181
+ L is the sequence length, and H denotes the model dimension.
182
+
183
+ Returns:
184
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
185
+ """
186
+ x = self.out(x)
187
+ m, p = x.chunk(2, dim=2)
188
+ m = torch.exp(m).clip(
189
+ max=1e2
190
+ ) # safeguard to prevent excessively large magnitudes
191
+ audio = self.imdct(m * torch.cos(p))
192
+ if self.clip_audio:
193
+ audio = torch.clip(x, min=-1.0, max=1.0)
194
+ return audio
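A minimal shape sketch for ISTFTHead, assuming the import path below; with "same" padding the output holds roughly L * hop_length samples, plus hop_length of edge padding added at the end of forward. Numbers are illustrative.

import torch
from codec.models.wavvae.heads import ISTFTHead  # assumed import path

head = ISTFTHead(dim=768, n_fft=1024, hop_length=256, padding="same")
features = torch.randn(2, 100, 768)   # (B, L, H) decoder output
audio = head(features)                # (B, T)
print(audio.shape)                    # torch.Size([2, 25856]) = 100 * 256 + 256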
codec/models/wavvae/helpers.py ADDED
@@ -0,0 +1,71 @@
1
+ import matplotlib
2
+ import numpy as np
3
+ import torch
4
+ from matplotlib import pyplot as plt
5
+ from pytorch_lightning import Callback
6
+
7
+ matplotlib.use("Agg")
8
+
9
+
10
+ def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
11
+ """
12
+ Save a matplotlib figure to a numpy array.
13
+
14
+ Args:
15
+ fig (Figure): Matplotlib figure object.
16
+
17
+ Returns:
18
+ ndarray: Numpy array representing the figure.
19
+ """
20
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
21
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
22
+ return data
23
+
24
+
25
+ def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
26
+ """
27
+ Plot a spectrogram and convert it to a numpy array.
28
+
29
+ Args:
30
+ spectrogram (ndarray): Spectrogram data.
31
+
32
+ Returns:
33
+ ndarray: Numpy array representing the plotted spectrogram.
34
+ """
35
+ spectrogram = spectrogram.astype(np.float32)
36
+ fig, ax = plt.subplots(figsize=(12, 3))
37
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
38
+ plt.colorbar(im, ax=ax)
39
+ plt.xlabel("Frames")
40
+ plt.ylabel("Channels")
41
+ plt.tight_layout()
42
+
43
+ fig.canvas.draw()
44
+ data = save_figure_to_numpy(fig)
45
+ plt.close()
46
+ return data
47
+
48
+
49
+ class GradNormCallback(Callback):
50
+ """
51
+ Callback to log the gradient norm.
52
+ """
53
+
54
+ def on_after_backward(self, trainer, model):
55
+ model.log("grad_norm", gradient_norm(model))
56
+
57
+
58
+ def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
59
+ """
60
+ Compute the gradient norm.
61
+
62
+ Args:
63
+ model (Module): PyTorch model.
64
+ norm_type (float, optional): Type of the norm. Defaults to 2.0.
65
+
66
+ Returns:
67
+ Tensor: Gradient norm.
68
+ """
69
+ grads = [p.grad for p in model.parameters() if p.grad is not None]
70
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
71
+ return total_norm
codec/models/wavvae/layers.py ADDED
@@ -0,0 +1,282 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+
7
+ class VocosDecoder(nn.Module):
8
+ def __init__(
9
+ self,
10
+ dim: int,
11
+ intermediate_dim: int,
12
+ num_layers: int,
13
+ ):
14
+ super().__init__()
15
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
16
+ self.convnext = nn.ModuleList(
17
+ [
18
+ ConvNeXtBlock(
19
+ dim=dim,
20
+ intermediate_dim=intermediate_dim,
21
+ layer_scale_init_value=0.0,
22
+ )
23
+ for _ in range(num_layers)
24
+ ]
25
+ )
26
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
27
+
28
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
29
+ x = self.norm(x)
30
+ x = x.transpose(1, 2)
31
+ for conv_block in self.convnext:
32
+ x = conv_block(x)
33
+ x = self.final_layer_norm(x.transpose(1, 2))
34
+ return x
35
+
36
+ class ConvNeXtBlock(nn.Module):
37
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
38
+
39
+ Args:
40
+ dim (int): Number of input channels.
41
+ intermediate_dim (int, optional): Dimensionality of the intermediate layer. Defaults to 3 * dim.
42
+ layer_scale_init_value (float, optional): Initial value for the layer scale. Values <= 0 disable
43
+ layer scaling. Defaults to 0.0.
44
+ elementwise_affine_ln (bool, optional): Whether the LayerNorm uses learnable affine parameters.
45
+ Defaults to True.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ dim: int,
51
+ intermediate_dim: int | None = None,
52
+ layer_scale_init_value: float = 0.0,
53
+ elementwise_affine_ln: bool = True,
54
+ ):
55
+ super().__init__()
56
+ intermediate_dim = intermediate_dim if intermediate_dim is not None else dim * 3
57
+ self.dwconv = nn.Conv1d(
58
+ dim, dim, kernel_size=7, padding=3, groups=dim
59
+ ) # depthwise conv
60
+ self.norm = nn.LayerNorm(
61
+ dim, eps=1e-6, elementwise_affine=elementwise_affine_ln
62
+ )
63
+ self.pwconv1 = nn.Linear(
64
+ dim, intermediate_dim
65
+ ) # pointwise/1x1 convs, implemented with linear layers
66
+ self.act = nn.GELU()
67
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
68
+ self.gamma = (
69
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
70
+ if layer_scale_init_value > 0
71
+ else None
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ x: torch.Tensor,
77
+ scale_shift: tuple[torch.Tensor, torch.Tensor] | None = None,
78
+ gate: torch.Tensor | None = None,
79
+ ) -> torch.Tensor:
80
+ residual = x
81
+ x = self.dwconv(x)
82
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
83
+ x = self.norm(x)
84
+ if scale_shift is not None:
85
+ scale, shift = scale_shift
86
+ x = x * scale[:, None] + shift[:, None]
87
+ x = self.pwconv1(x)
88
+ x = self.act(x)
89
+ x = self.pwconv2(x)
90
+ if self.gamma is not None:
91
+ x = self.gamma * x
92
+ if gate is not None:
93
+ x = gate[:, None] * x
94
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
95
+
96
+ x = residual + x
97
+ return x
98
+
99
+
100
+ class Encoder(nn.Module):
101
+ def __init__(
102
+ self,
103
+ d_model=32,
104
+ strides=[2, 4, 4, 8],
105
+ depthwise=False,
106
+ ):
107
+ super().__init__()
108
+ layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
109
+ for stride in strides:
110
+ d_model *= 2
111
+ groups = d_model // 2 if depthwise else 1
112
+ layers += [EncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
113
+ groups = d_model if depthwise else 1
114
+ layers += [
115
+ WNConv1d(d_model, d_model, kernel_size=7, padding=3, groups=groups),
116
+ ]
117
+ self.block = nn.Sequential(*layers)
118
+
119
+ def forward(self, x):
120
+ return self.block(x)
121
+
122
+
123
+ class Decoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ input_channel,
127
+ channels,
128
+ rates,
129
+ noise=False,
130
+ depthwise=False,
131
+ d_out=1,
132
+ ):
133
+ super().__init__()
134
+ if depthwise:
135
+ layers = [
136
+ WNConv1d(
137
+ input_channel,
138
+ input_channel,
139
+ kernel_size=7,
140
+ padding=3,
141
+ groups=input_channel,
142
+ ),
143
+ WNConv1d(input_channel, channels, kernel_size=1),
144
+ ]
145
+ else:
146
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
147
+
148
+ for i, stride in enumerate(rates):
149
+ input_dim = channels // 2**i
150
+ output_dim = channels // 2 ** (i + 1)
151
+ groups = output_dim if depthwise else 1
152
+ layers.append(
153
+ DecoderBlock(input_dim, output_dim, stride, noise, groups=groups)
154
+ )
155
+
156
+ layers += [
157
+ Snake1d(output_dim),
158
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
159
+ nn.Tanh(),
160
+ ]
161
+ self.model = nn.Sequential(*layers)
162
+
163
+ def forward(self, x):
164
+ x = self.model(x)
165
+ return x
166
+
167
+
168
+ class ResidualUnit(nn.Module):
169
+ def __init__(self, dim=16, dilation=1, kernel=7, groups=1):
170
+ super().__init__()
171
+ pad = ((kernel - 1) * dilation) // 2
172
+ self.block = nn.Sequential(
173
+ Snake1d(dim),
174
+ WNConv1d(
175
+ dim,
176
+ dim,
177
+ kernel_size=kernel,
178
+ dilation=dilation,
179
+ padding=pad,
180
+ groups=groups,
181
+ ),
182
+ Snake1d(dim),
183
+ WNConv1d(dim, dim, kernel_size=1),
184
+ )
185
+
186
+ def forward(self, x):
187
+ y = self.block(x)
188
+ pad = (x.shape[-1] - y.shape[-1]) // 2
189
+ if pad > 0:
190
+ x = x[..., pad:-pad]
191
+ return x + y
192
+
193
+
194
+ class EncoderBlock(nn.Module):
195
+ def __init__(self, output_dim=16, input_dim=None, stride=1, groups=1):
196
+ super().__init__()
197
+ input_dim = input_dim or output_dim // 2
198
+ self.block = nn.Sequential(
199
+ ResidualUnit(input_dim, dilation=1, groups=groups),
200
+ ResidualUnit(input_dim, dilation=3, groups=groups),
201
+ ResidualUnit(input_dim, dilation=9, groups=groups),
202
+ Snake1d(input_dim),
203
+ WNConv1d(
204
+ input_dim,
205
+ output_dim,
206
+ kernel_size=2 * stride,
207
+ stride=stride,
208
+ padding=math.ceil(stride / 2),
209
+ ),
210
+ )
211
+
212
+ def forward(self, x):
213
+ return self.block(x)
214
+
215
+
216
+ class NoiseBlock(nn.Module):
217
+ def __init__(self, dim):
218
+ super().__init__()
219
+ self.linear = WNConv1d(dim, dim, kernel_size=1, bias=False)
220
+
221
+ def forward(self, x):
222
+ B, C, T = x.shape
223
+ noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
224
+ h = self.linear(x)
225
+ n = noise * h
226
+ x = x + n
227
+ return x
228
+
229
+
230
+ class DecoderBlock(nn.Module):
231
+ def __init__(self, input_dim=16, output_dim=8, stride=1, noise=False, groups=1):
232
+ super().__init__()
233
+ layers = [
234
+ Snake1d(input_dim),
235
+ WNConvTranspose1d(
236
+ input_dim,
237
+ output_dim,
238
+ kernel_size=2 * stride,
239
+ stride=stride,
240
+ padding=math.ceil(stride / 2),
241
+ output_padding=stride % 2,
242
+ ),
243
+ ]
244
+ if noise:
245
+ layers.append(NoiseBlock(output_dim))
246
+ layers.extend(
247
+ [
248
+ ResidualUnit(output_dim, dilation=1, groups=groups),
249
+ ResidualUnit(output_dim, dilation=3, groups=groups),
250
+ ResidualUnit(output_dim, dilation=9, groups=groups),
251
+ ]
252
+ )
253
+ self.block = nn.Sequential(*layers)
254
+
255
+ def forward(self, x):
256
+ return self.block(x)
257
+
258
+
259
+ def WNConv1d(*args, **kwargs):
260
+ return weight_norm(nn.Conv1d(*args, **kwargs))
261
+
262
+
263
+ def WNConvTranspose1d(*args, **kwargs):
264
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
265
+
266
+
267
+ @torch.jit.script
268
+ def snake(x, alpha):
269
+ shape = x.shape
270
+ x = x.reshape(shape[0], shape[1], -1)
271
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
272
+ x = x.reshape(shape)
273
+ return x
274
+
275
+
276
+ class Snake1d(nn.Module):
277
+ def __init__(self, channels):
278
+ super().__init__()
279
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
280
+
281
+ def forward(self, x):
282
+ return snake(x, self.alpha)
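The Encoder downsamples by the product of its strides, which becomes the codec hop length. A minimal sketch, assuming the import path below; channel counts and lengths are illustrative.

import torch
from codec.models.wavvae.layers import Encoder  # assumed import path

strides = [2, 4, 4, 8]
hop = 1
for s in strides:
    hop *= s                          # 2 * 4 * 4 * 8 = 256 samples per latent frame

enc = Encoder(d_model=32, strides=strides, depthwise=True)
audio = torch.randn(1, 1, 24000)      # (B, 1, T): one second at 24 kHz
latents = enc(audio)                  # (B, 32 * 2**len(strides), ~T / hop)
print(hop, latents.shape)             # 256 torch.Size([1, 512, 93])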
codec/models/wavvae/loss.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import torch
4
+ import torchaudio
5
+ from torch import nn
6
+
7
+ from .modules import safe_log
8
+
9
+
10
+ class MelSpecReconstructionLoss(nn.Module):
11
+ """
12
+ L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
13
+ """
14
+
15
+ def __init__(
16
+ self,
17
+ sample_rate: int = 24000,
18
+ n_fft: int | None = None,
19
+ hop_length: int = 256,
20
+ n_mels: int = 100,
21
+ f_min: int = 0,
22
+ f_max: Optional[int] = None,
23
+ ):
24
+ super().__init__()
25
+ self.mel_spec = torchaudio.transforms.MelSpectrogram(
26
+ sample_rate=sample_rate,
27
+ n_fft=hop_length * 4 if n_fft is None else n_fft,
28
+ hop_length=hop_length,
29
+ n_mels=n_mels,
30
+ center=True,
31
+ power=1,
32
+ f_min=f_min,
33
+ f_max=f_max,
34
+ )
35
+
36
+ def forward(self, y_hat, y) -> torch.Tensor:
37
+ """
38
+ Args:
39
+ y_hat (Tensor): Predicted audio waveform.
40
+ y (Tensor): Ground truth audio waveform.
41
+
42
+ Returns:
43
+ Tensor: L1 loss between the mel-scaled magnitude spectrograms.
44
+ """
45
+ # B, C, Th = y_hat.shape
46
+ # B, C, T = y.shape
47
+ # crop = (Th - T) // 2
48
+ mel_hat = safe_log(self.mel_spec(y_hat))
49
+ # mel_hat = safe_log(self.mel_spec(y_hat[..., crop:-crop]))
50
+ # mel = safe_log(self.mel_spec(y[..., crop:-crop]))
51
+ mel = safe_log(self.mel_spec(y))
52
+
53
+ loss = torch.nn.functional.l1_loss(mel, mel_hat)
54
+
55
+ return loss
56
+
57
+
58
+ class GeneratorLoss(nn.Module):
59
+ """
60
+ Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
61
+ """
62
+
63
+ def forward(
64
+ self, disc_outputs: List[torch.Tensor]
65
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
66
+ """
67
+ Args:
68
+ disc_outputs (List[Tensor]): List of discriminator outputs.
69
+
70
+ Returns:
71
+ Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
72
+ the sub-discriminators
73
+ """
74
+ loss = torch.zeros(
75
+ 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype
76
+ )
77
+ gen_losses = []
78
+ for dg in disc_outputs:
79
+ l = torch.mean(torch.clamp(1 - dg, min=0))
80
+ gen_losses.append(l)
81
+ loss += l
82
+
83
+ return loss, gen_losses
84
+
85
+
86
+ class DiscriminatorLoss(nn.Module):
87
+ """
88
+ Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
89
+ """
90
+
91
+ def forward(
92
+ self,
93
+ disc_real_outputs: List[torch.Tensor],
94
+ disc_generated_outputs: List[torch.Tensor],
95
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
96
+ """
97
+ Args:
98
+ disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
99
+ disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
100
+
101
+ Returns:
102
+ Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
103
+ the sub-discriminators for real outputs, and a list of
104
+ loss values for generated outputs.
105
+ """
106
+ loss = torch.zeros(
107
+ 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype
108
+ )
109
+ r_losses = []
110
+ g_losses = []
111
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
112
+ r_loss = torch.mean(torch.clamp(1 - dr, min=0))
113
+ g_loss = torch.mean(torch.clamp(1 + dg, min=0))
114
+ loss += r_loss + g_loss
115
+ r_losses.append(r_loss)
116
+ g_losses.append(g_loss)
117
+
118
+ return loss, r_losses, g_losses
119
+
120
+
121
+ class FeatureMatchingLoss(nn.Module):
122
+ """
123
+ Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
124
+ """
125
+
126
+ def forward(
127
+ self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
128
+ ) -> torch.Tensor:
129
+ """
130
+ Args:
131
+ fmap_r (List[List[Tensor]]): List of feature maps from real samples.
132
+ fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
133
+
134
+ Returns:
135
+ Tensor: The calculated feature matching loss.
136
+ """
137
+ loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
138
+ for dr, dg in zip(fmap_r, fmap_g):
139
+ for rl, gl in zip(dr, dg):
140
+ loss += torch.mean(torch.abs(rl - gl))
141
+
142
+ return loss
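These modules are usually combined into a Vocos / HiFi-GAN style objective: mel reconstruction plus adversarial and feature-matching terms. A minimal sketch, assuming the import path below; the weights are hypothetical placeholders, not values taken from this repository.

import torch
from codec.models.wavvae.loss import (  # assumed import path
    MelSpecReconstructionLoss, GeneratorLoss, DiscriminatorLoss, FeatureMatchingLoss,
)

mel_loss_fn = MelSpecReconstructionLoss(sample_rate=24000)
gen_loss_fn = GeneratorLoss()
disc_loss_fn = DiscriminatorLoss()
fm_loss_fn = FeatureMatchingLoss()

y = torch.randn(2, 24000)        # ground-truth audio (B, T)
y_hat = torch.randn(2, 24000)    # reconstructed audio (B, T)
loss_mel = mel_loss_fn(y_hat, y)

# With discriminator outputs / feature maps from discriminators.py:
# loss_disc, _, _ = disc_loss_fn(real_outputs, fake_outputs)      # discriminator update
# loss_gen, _ = gen_loss_fn(fake_outputs)                         # generator adversarial term
# loss_fm = fm_loss_fn(fmap_real, fmap_fake)                      # feature matching
# total_gen = loss_gen + 2.0 * loss_fm + 45.0 * loss_mel          # hypothetical weighting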
codec/models/wavvae/model.py ADDED
@@ -0,0 +1,140 @@
1
+ import json
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from safetensors.torch import load_file
8
+ from torch import nn
9
+
10
+ from .heads import ISTFTHead, LinearNoBiasHead
11
+ from .layers import Encoder, VocosDecoder
12
+ from .modules import ConvNeXtBlock
13
+
14
+
15
+ @dataclass
16
+ class WavVAEConfig:
17
+ conv_dim: int = 48
18
+ latent_dim: int = 32
19
+ decoder_hidden_dim: int = 768
20
+ decoder_intermediate_dim: int = 1536
21
+ decoder_num_layers: int = 8
22
+ n_fft: int = 1024
23
+ hop_length: int = 256
24
+ padding: str = "center"
25
+ head_type: Literal["istft", "linear"] = "istft"
26
+ strides: list[int] = field(default_factory=lambda: [2, 4, 4, 8])
27
+ learnable_pre_norm: bool = False
28
+ sampling_rate: int = 24000
29
+
30
+
31
+ class WavVAE(nn.Module):
32
+ def __init__(self, cfg: WavVAEConfig):
33
+ super().__init__()
34
+ self.conv_encoder = Encoder(cfg.conv_dim, strides=cfg.strides, depthwise=True)
35
+ conv_final_dim = cfg.conv_dim * 2 ** len(cfg.strides)
36
+ self.bottleneck = nn.Linear(conv_final_dim, cfg.latent_dim * 2)
37
+ self.unbottleneck = nn.Linear(cfg.latent_dim, cfg.decoder_hidden_dim)
38
+ self.latent_norm = nn.LayerNorm(conv_final_dim)
39
+ self.vocos_decoder = VocosDecoder(
40
+ cfg.decoder_hidden_dim,
41
+ cfg.decoder_intermediate_dim,
42
+ cfg.decoder_num_layers,
43
+ )
44
+ if cfg.head_type == "istft":
45
+ self.head = ISTFTHead(
46
+ cfg.decoder_hidden_dim,
47
+ cfg.n_fft,
48
+ cfg.hop_length,
49
+ padding=cfg.padding,
50
+ )
51
+ elif cfg.head_type == "linear":
52
+ self.head = LinearNoBiasHead(
53
+ cfg.decoder_hidden_dim,
54
+ cfg.hop_length,
55
+ cfg.n_fft,
56
+ )
57
+
58
+ self._sampling_rate = cfg.sampling_rate
59
+ self._strides = cfg.strides
60
+ self.apply(self._init_weights)
61
+
62
+ def _init_weights(self, m):
63
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
64
+ nn.init.trunc_normal_(m.weight, std=0.02)
65
+ if m.bias is not None:
66
+ nn.init.constant_(m.bias, 0)
67
+
68
+ @property
69
+ def sampling_rate(self) -> int:
70
+ return self._sampling_rate
71
+
72
+ @property
73
+ def hop_length(self) -> int:
74
+ hop_length = 1
75
+ for s in self._strides:
76
+ hop_length *= s
77
+ return hop_length
78
+
79
+ @property
80
+ def frame_rate(self) -> float:
81
+ return self.sampling_rate / self.hop_length
82
+
83
+ @classmethod
84
+ def from_pretrained(
85
+ cls,
86
+ pretrained_model_name_or_path: str,
87
+ device: str = "cpu",
88
+ ):
89
+ if Path(pretrained_model_name_or_path).exists():
90
+ path = pretrained_model_name_or_path
91
+ else:
92
+ from huggingface_hub import snapshot_download
93
+
94
+ path = snapshot_download(pretrained_model_name_or_path)
95
+
96
+ with open(Path(path) / "config.json", "r") as f:
97
+ config = json.load(f)
98
+ config = WavVAEConfig(**config)
99
+ model = cls(config)
100
+ state_dict = load_file(
101
+ Path(path) / "model.st",
102
+ device=device,
103
+ )
104
+
105
+ model.load_state_dict(state_dict, assign=True)
106
+ return model
107
+
108
+ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
109
+ logvar = torch.clamp(logvar, -30.0, 20.0)
110
+ std = torch.exp(0.5 * logvar)
111
+ eps = torch.randn_like(std)
112
+ return mu + eps * std
113
+
114
+ def kl_divergence(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
115
+ kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
116
+ return kl.sum(dim=-1).mean()
117
+
118
+ def encode(self, audio: torch.Tensor) -> torch.Tensor:
119
+ y = self.conv_encoder(audio.unsqueeze(1)).transpose(1, 2)
120
+ y = self.latent_norm(y)
121
+ mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
122
+ z = self.reparameterize(mu, logvar)
123
+ return z
124
+
125
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
126
+ y = self.unbottleneck(z)
127
+ y = self.vocos_decoder(y)
128
+ return self.head(y)
129
+
130
+ def forward(self, audio_input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
131
+ y = self.conv_encoder(audio_input.unsqueeze(1)).transpose(1, 2)
132
+ y = self.latent_norm(y)
133
+ mu, logvar = self.bottleneck(y).chunk(2, dim=-1)
134
+ kl_div = self.kl_divergence(mu, logvar)
135
+ z = self.reparameterize(mu, logvar)
136
+ y = self.unbottleneck(z)
137
+ y = self.vocos_decoder(y)
138
+ audio_output = self.head(y)
139
+
140
+ return audio_output, kl_div
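A round-trip sketch through WavVAE with the default WavVAEConfig, assuming the import path below; with hop_length = 256 at 24 kHz the latent frame rate is 24000 / 256 = 93.75 Hz.

import torch
from codec.models.wavvae.model import WavVAE, WavVAEConfig  # assumed import path

cfg = WavVAEConfig()                 # latent_dim=32, hop_length=256, 24 kHz
model = WavVAE(cfg).eval()

audio = torch.randn(1, 24000)        # (B, T): one second of audio
with torch.no_grad():
    z = model.encode(audio)          # (B, ~T/256, 32) latent frames (sampled via the reparameterization)
    recon = model.decode(z)          # (B, T') reconstructed waveform
print(model.frame_rate, z.shape)     # 93.75 torch.Size([1, 93, 32])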
codec/models/wavvae/modules.py ADDED
@@ -0,0 +1,213 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+
8
+ class ConvNeXtBlock(nn.Module):
9
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
10
+
11
+ Args:
12
+ dim (int): Number of input channels.
13
+ intermediate_dim (int): Dimensionality of the intermediate layer.
14
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
15
+ Defaults to None.
16
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
17
+ None means non-conditional LayerNorm. Defaults to None.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ dim: int,
23
+ intermediate_dim: int,
24
+ layer_scale_init_value: float,
25
+ adanorm_num_embeddings: Optional[int] = None,
26
+ ):
27
+ super().__init__()
28
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
29
+ self.adanorm = adanorm_num_embeddings is not None
30
+ if adanorm_num_embeddings:
31
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
32
+ else:
33
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
34
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
35
+ self.act = nn.GELU()
36
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
37
+ self.gamma = (
38
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
39
+ if layer_scale_init_value > 0
40
+ else None
41
+ )
42
+
43
+ def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
44
+ residual = x
45
+ x = self.dwconv(x)
46
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
47
+ if self.adanorm:
48
+ assert cond_embedding_id is not None
49
+ x = self.norm(x, cond_embedding_id)
50
+ else:
51
+ x = self.norm(x)
52
+ x = self.pwconv1(x)
53
+ x = self.act(x)
54
+ x = self.pwconv2(x)
55
+ if self.gamma is not None:
56
+ x = self.gamma * x
57
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
58
+
59
+ x = residual + x
60
+ return x
61
+
62
+
63
+ class AdaLayerNorm(nn.Module):
64
+ """
65
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
66
+
67
+ Args:
68
+ num_embeddings (int): Number of embeddings.
69
+ embedding_dim (int): Dimension of the embeddings.
70
+ """
71
+
72
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
73
+ super().__init__()
74
+ self.eps = eps
75
+ self.dim = embedding_dim
76
+ self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
77
+ self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
78
+ torch.nn.init.ones_(self.scale.weight)
79
+ torch.nn.init.zeros_(self.shift.weight)
80
+
81
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
82
+ scale = self.scale(cond_embedding_id)
83
+ shift = self.shift(cond_embedding_id)
84
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
85
+ x = x * scale + shift
86
+ return x
87
+
88
+
89
+ class ResBlock1(nn.Module):
90
+ """
91
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
92
+ but without upsampling layers.
93
+
94
+ Args:
95
+ dim (int): Number of input channels.
96
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
97
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
98
+ Defaults to (1, 3, 5).
99
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
100
+ Defaults to 0.1.
101
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
102
+ Defaults to None.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ dim: int,
108
+ kernel_size: int = 3,
109
+ dilation: Tuple[int, int, int] = (1, 3, 5),
110
+ lrelu_slope: float = 0.1,
111
+ layer_scale_init_value: Optional[float] = None,
112
+ ):
113
+ super().__init__()
114
+ self.lrelu_slope = lrelu_slope
115
+ self.convs1 = nn.ModuleList(
116
+ [
117
+ weight_norm(
118
+ nn.Conv1d(
119
+ dim,
120
+ dim,
121
+ kernel_size,
122
+ 1,
123
+ dilation=dilation[0],
124
+ padding=self.get_padding(kernel_size, dilation[0]),
125
+ )
126
+ ),
127
+ weight_norm(
128
+ nn.Conv1d(
129
+ dim,
130
+ dim,
131
+ kernel_size,
132
+ 1,
133
+ dilation=dilation[1],
134
+ padding=self.get_padding(kernel_size, dilation[1]),
135
+ )
136
+ ),
137
+ weight_norm(
138
+ nn.Conv1d(
139
+ dim,
140
+ dim,
141
+ kernel_size,
142
+ 1,
143
+ dilation=dilation[2],
144
+ padding=self.get_padding(kernel_size, dilation[2]),
145
+ )
146
+ ),
147
+ ]
148
+ )
149
+
150
+ self.convs2 = nn.ModuleList(
151
+ [
152
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
153
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
154
+ weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
155
+ ]
156
+ )
157
+
158
+ self.gamma = nn.ParameterList(
159
+ [
160
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
161
+ if layer_scale_init_value is not None
162
+ else None,
163
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
164
+ if layer_scale_init_value is not None
165
+ else None,
166
+ nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
167
+ if layer_scale_init_value is not None
168
+ else None,
169
+ ]
170
+ )
171
+
172
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
173
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
174
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
175
+ xt = c1(xt)
176
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
177
+ xt = c2(xt)
178
+ if gamma is not None:
179
+ xt = gamma * xt
180
+ x = xt + x
181
+ return x
182
+
183
+ def remove_weight_norm(self):
184
+ for l in self.convs1:
185
+ remove_weight_norm(l)
186
+ for l in self.convs2:
187
+ remove_weight_norm(l)
188
+
189
+ @staticmethod
190
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
191
+ return int((kernel_size * dilation - dilation) / 2)
192
+
193
+
194
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
195
+ """
196
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
197
+
198
+ Args:
199
+ x (Tensor): Input tensor.
200
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
201
+
202
+ Returns:
203
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
204
+ """
205
+ return torch.log(torch.clip(x, min=clip_val))
206
+
207
+
208
+ def symlog(x: torch.Tensor) -> torch.Tensor:
209
+ return torch.sign(x) * torch.log1p(x.abs())
210
+
211
+
212
+ def symexp(x: torch.Tensor) -> torch.Tensor:
213
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
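A quick sketch of the helpers at the end of this file: symlog and symexp are mutual inverses, and safe_log clips its input away from zero before taking the log. Assumes the import path below.

import torch
from codec.models.wavvae.modules import symlog, symexp, safe_log  # assumed import path

x = torch.tensor([-10.0, -0.5, 0.0, 0.5, 10.0])
assert torch.allclose(symexp(symlog(x)), x, atol=1e-5)

print(safe_log(torch.tensor([0.0, 1e-9, 1.0])))   # zeros are clipped to 1e-7 before the log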
codec/models/wavvae/spectral_ops.py ADDED
@@ -0,0 +1,192 @@
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+ from torch import nn, view_as_real, view_as_complex
5
+
6
+
7
+ class ISTFT(nn.Module):
8
+ """
9
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
10
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
11
+ See issue: https://github.com/pytorch/pytorch/issues/62323
12
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
13
+ The NOLA constraint is met as we trim padded samples anyway.
14
+
15
+ Args:
16
+ n_fft (int): Size of Fourier transform.
17
+ hop_length (int): The distance between neighboring sliding window frames.
18
+ win_length (int): The size of window frame and STFT filter.
19
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
20
+ """
21
+
22
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
23
+ super().__init__()
24
+ if padding not in ["center", "same"]:
25
+ raise ValueError("Padding must be 'center' or 'same'.")
26
+ self.padding = padding
27
+ self.n_fft = n_fft
28
+ self.hop_length = hop_length
29
+ self.win_length = win_length
30
+ window = torch.hann_window(win_length)
31
+ self.register_buffer("window", window)
32
+
33
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
36
+
37
+ Args:
38
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
39
+ N is the number of frequency bins, and T is the number of time frames.
40
+
41
+ Returns:
42
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
43
+ """
44
+ if self.padding == "center":
45
+ # Fallback to pytorch native implementation
46
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
47
+ elif self.padding == "same":
48
+ pad = (self.win_length - self.hop_length) // 2
49
+ else:
50
+ raise ValueError("Padding must be 'center' or 'same'.")
51
+
52
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
53
+ B, N, T = spec.shape
54
+
55
+ # Inverse FFT
56
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
57
+ ifft = ifft * self.window[None, :, None]
58
+
59
+ # Overlap and Add
60
+ output_size = (T - 1) * self.hop_length + self.win_length
61
+ y = torch.nn.functional.fold(
62
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
63
+ )[:, 0, 0, pad:-pad]
64
+
65
+ # Window envelope
66
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
67
+ window_envelope = torch.nn.functional.fold(
68
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
69
+ ).squeeze()[pad:-pad]
70
+
71
+ # Normalize
72
+ assert (window_envelope > 1e-11).all()
73
+ y = y / window_envelope
74
+
75
+ return y
76
+
77
+
78
+ class MDCT(nn.Module):
79
+ """
80
+ Modified Discrete Cosine Transform (MDCT) module.
81
+
82
+ Args:
83
+ frame_len (int): Length of the MDCT frame.
84
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
85
+ """
86
+
87
+ def __init__(self, frame_len: int, padding: str = "same"):
88
+ super().__init__()
89
+ if padding not in ["center", "same"]:
90
+ raise ValueError("Padding must be 'center' or 'same'.")
91
+ self.padding = padding
92
+ self.frame_len = frame_len
93
+ N = frame_len // 2
94
+ n0 = (N + 1) / 2
95
+ window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
96
+ self.register_buffer("window", window)
97
+
98
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
99
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
100
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
101
+ # https://github.com/pytorch/pytorch/issues/71613
102
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
103
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
104
+
105
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
106
+ """
107
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
108
+
109
+ Args:
110
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
111
+ and T is the length of the audio.
112
+
113
+ Returns:
114
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
115
+ and N is the number of frequency bins.
116
+ """
117
+ if self.padding == "center":
118
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
119
+ elif self.padding == "same":
120
+ # hop_length is 1/2 frame_len
121
+ audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
122
+ else:
123
+ raise ValueError("Padding must be 'center' or 'same'.")
124
+
125
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
126
+ N = self.frame_len // 2
127
+ x = x * self.window.expand(x.shape)
128
+ X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
129
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
130
+ return torch.real(res) * np.sqrt(2)
131
+
132
+
133
+ class IMDCT(nn.Module):
134
+ """
135
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
136
+
137
+ Args:
138
+ frame_len (int): Length of the MDCT frame.
139
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
140
+ """
141
+
142
+ def __init__(self, frame_len: int, padding: str = "same"):
143
+ super().__init__()
144
+ if padding not in ["center", "same"]:
145
+ raise ValueError("Padding must be 'center' or 'same'.")
146
+ self.padding = padding
147
+ self.frame_len = frame_len
148
+ N = frame_len // 2
149
+ n0 = (N + 1) / 2
150
+ window = torch.from_numpy(scipy.signal.windows.cosine(frame_len)).float()
151
+ self.register_buffer("window", window)
152
+
153
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
154
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
155
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
156
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
157
+
158
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
159
+ """
160
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
161
+
162
+ Args:
163
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
164
+ L is the number of frames, and N is the number of frequency bins.
165
+
166
+ Returns:
167
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
168
+ """
169
+ B, L, N = X.shape
170
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
171
+ Y[..., :N] = X
172
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
173
+ y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
174
+ y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
175
+ result = y * self.window.expand(y.shape)
176
+ output_size = (1, (L + 1) * N)
177
+ audio = torch.nn.functional.fold(
178
+ result.transpose(1, 2),
179
+ output_size=output_size,
180
+ kernel_size=(1, self.frame_len),
181
+ stride=(1, self.frame_len // 2),
182
+ )[:, 0, 0, :]
183
+
184
+ if self.padding == "center":
185
+ pad = self.frame_len // 2
186
+ elif self.padding == "same":
187
+ pad = self.frame_len // 4
188
+ else:
189
+ raise ValueError("Padding must be 'center' or 'same'.")
190
+
191
+ audio = audio[:, pad:-pad]
192
+ return audio
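A round-trip sketch for the MDCT / IMDCT pair with "same" padding (hop = frame_len // 2), assuming the import path below; the input length is chosen as a multiple of the hop so the shapes line up exactly.

import torch
from codec.models.wavvae.spectral_ops import MDCT, IMDCT  # assumed import path

frame_len = 512
mdct = MDCT(frame_len=frame_len, padding="same")
imdct = IMDCT(frame_len=frame_len, padding="same")

audio = torch.randn(1, 94 * frame_len // 2)   # (B, T), T a multiple of the hop (256)
coeffs = mdct(audio)                          # (B, L, frame_len // 2)
recon = imdct(coeffs)                         # (B, T), exact apart from boundary frames
print(coeffs.shape, recon.shape)              # torch.Size([1, 94, 256]) torch.Size([1, 24064])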
codec/scripts/compare_codecs.py ADDED
@@ -0,0 +1,441 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ import os
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Any, Tuple
9
+
10
+ import torch
11
+ from torchaudio import load as ta_load
12
+ from torchaudio.functional import resample as ta_resample
13
+ import torchaudio
14
+
15
+ # Your libs
16
+ from zcodec.models import WavVAE, ZFlowAutoEncoder
17
+
18
+
19
+ # -------------------------
20
+ # Data structures
21
+ # -------------------------
22
+
23
+
24
+ @dataclass
25
+ class DecodeParams:
26
+ num_steps: int = 10
27
+ cfg: float = 2.0
28
+
29
+
30
+ @dataclass
31
+ class ModelPairSpec:
32
+ name: str
33
+ wavvae_dir: str
34
+ zflowae_dir: str
35
+ decode: DecodeParams
36
+
37
+
38
+ # -------------------------
39
+ # Utilities
40
+ # -------------------------
41
+
42
+
43
+ def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
44
+ if path.is_file():
45
+ try:
46
+ with path.open("r", encoding="utf-8") as f:
47
+ return json.load(f)
48
+ except Exception:
49
+ return None
50
+ return None
51
+
52
+
53
+ def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
54
+ """
55
+ Try to read config.json (or a few common fallbacks) from a checkpoint dir.
56
+ Returns {} if nothing could be parsed.
57
+ """
58
+ cand = [
59
+ Path(checkpoint_dir) / "config.json",
60
+ Path(checkpoint_dir)
61
+ / "config.yaml", # won't parse yaml here, we only display path
62
+ Path(checkpoint_dir) / "model_config.json",
63
+ ]
64
+ for p in cand:
65
+ if p.exists():
66
+ if p.suffix == ".json":
67
+ j = load_json_if_exists(p)
68
+ if j is not None:
69
+ return j
70
+ else:
71
+ # For YAML or unknown, just show filename rather than failing
72
+ return {"_config_file": str(p)}
73
+ return {}
74
+
75
+
76
+ def sanitize_name(s: str) -> str:
77
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
78
+
79
+
80
+ def ensure_mono_and_resample(
81
+ wav: torch.Tensor, sr: int, target_sr: int
82
+ ) -> Tuple[torch.Tensor, int]:
83
+ """
84
+ wav: (channels, samples)
85
+ returns mono float32 in [-1,1], resampled to target_sr
86
+ """
87
+ if wav.ndim != 2:
88
+ raise ValueError(f"Expected 2D waveform (C, T), got shape {tuple(wav.shape)}")
89
+ # to mono
90
+ if wav.size(0) > 1:
91
+ wav = wav.mean(dim=0, keepdim=True)
92
+ # resample if needed
93
+ if sr != target_sr:
94
+ wav = ta_resample(wav, sr, target_sr)
95
+ sr = target_sr
96
+ return wav.to(torch.float32), sr
97
+
98
+
99
+ def save_wav(path: Path, wav: torch.Tensor, sr: int):
100
+ path.parent.mkdir(parents=True, exist_ok=True)
101
+ # (C, T)
102
+ if wav.ndim == 1:
103
+ wav = wav.unsqueeze(0)
104
+ # Clamp to [-1,1]
105
+ wav = wav.clamp(-1, 1).contiguous().cpu()
106
+ torchaudio.save(
107
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
108
+ )
109
+
110
+
111
+ # -------------------------
112
+ # Core inference
113
+ # -------------------------
114
+
115
+
116
+ @torch.inference_mode()
117
+ def reconstruct_full_pipeline(
118
+ wav_mono: torch.Tensor,
119
+ sr: int,
120
+ wavvae: WavVAE,
121
+ zflowae: ZFlowAutoEncoder,
122
+ decode_params: DecodeParams,
123
+ device: str,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Full path: audio -> WavVAE.encode -> ZFlowAE.encode -> ZFlowAE.decode -> WavVAE.decode -> audio_hat
127
+ """
128
+ wav_mono = wav_mono.to(device)
129
+ # WavVAE expects (B, C, T); assume C=1
130
+ x = wav_mono.unsqueeze(0) # (1, 1, T)
131
+ # Encode to high-framerate latents
132
+ z = wavvae.encode(x)
133
+ # Compress latents
134
+ y = zflowae.encode(z)
135
+ # Decompress
136
+ z_hat = zflowae.decode(y, num_steps=decode_params.num_steps, cfg=decode_params.cfg)
137
+ # Decode to waveform
138
+ wav_hat = wavvae.decode(z_hat) # (1, 1, T)
139
+ # Return mono 1D
140
+ return wav_hat.squeeze(0).squeeze(0).detach()
141
+
142
+
143
+ def load_model_pair(spec: ModelPairSpec, device: str):
144
+ wavvae = WavVAE.from_pretrained_local(spec.wavvae_dir).to(device)
145
+ zflowae = ZFlowAutoEncoder.from_pretrained_local(spec.zflowae_dir).to(device)
146
+ # try to get sampling rate from WavVAE
147
+ target_sr = getattr(wavvae, "sampling_rate", None)
148
+ if target_sr is None:
149
+ # reasonable fallback
150
+ target_sr = 24000
151
+ return wavvae, zflowae, int(target_sr)
152
+
153
+
154
+ def parse_manifest(path: str) -> List[ModelPairSpec]:
155
+ """
156
+ Manifest format (JSON list):
157
+ [
158
+ {
159
+ "name": "zdim32x8",
160
+ "wavvae": "/path/to/WavVAE_framerate100_zdim32/",
161
+ "zflowae": "/path/to/ZFlowAutoEncoder_stride4_zdim32_vae8_.../",
162
+ "decode": {"num_steps": 10, "cfg": 2.0}
163
+ }
164
+ ]
165
+ """
166
+ with open(path, "r", encoding="utf-8") as f:
167
+ raw = json.load(f)
168
+ out: List[ModelPairSpec] = []
169
+ for item in raw:
170
+ name = item["name"]
171
+ wavvae_dir = item["wavvae"]
172
+ zflowae_dir = item["zflowae"]
173
+ d = item.get("decode", {})
174
+ out.append(
175
+ ModelPairSpec(
176
+ name=name,
177
+ wavvae_dir=wavvae_dir,
178
+ zflowae_dir=zflowae_dir,
179
+ decode=DecodeParams(
180
+ num_steps=int(d.get("num_steps", 10)),
181
+ cfg=float(d.get("cfg", 2.0)),
182
+ ),
183
+ )
184
+ )
185
+ return out
186
+
187
+
188
+ # -------------------------
189
+ # HTML generation
190
+ # -------------------------
191
+
192
+
193
+ def html_escape(s: str) -> str:
194
+ return (
195
+ s.replace("&", "&amp;")
196
+ .replace("<", "&lt;")
197
+ .replace(">", "&gt;")
198
+ .replace('"', "&quot;")
199
+ .replace("'", "&#39;")
200
+ )
201
+
202
+
203
+ def make_html(
204
+ output_dir: Path,
205
+ audio_files: List[Path],
206
+ models: List[ModelPairSpec],
207
+ sr_by_model: Dict[str, int],
208
+ wavvae_cfg: Dict[str, Dict[str, Any]],
209
+ zflow_cfg: Dict[str, Dict[str, Any]],
210
+ ) -> str:
211
+ """
212
+ Build a static HTML page with a table:
213
+ Row = input audio file
214
+ Col 1 = Original
215
+ Col 2..N = each model reconstruction
216
+ Also shows minimal model config info above the table.
217
+ """
218
+
219
+ def player(src_rel: str, controls: bool = True) -> str:
220
+ return f'<audio {"controls" if controls else ""} preload="none" src="{html_escape(src_rel)}"></audio>'
221
+
222
+ # Model cards
223
+ model_cards = []
224
+ for spec in models:
225
+ wcfg = wavvae_cfg.get(spec.name, {})
226
+ zcfg = zflow_cfg.get(spec.name, {})
227
+ w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
228
+ :1200
229
+ ]
230
+ z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
231
+ :1200
232
+ ]
233
+ card = f"""
234
+ <div class="model-card">
235
+ <h3>{html_escape(spec.name)}</h3>
236
+ <p><b>Sample rate</b>: {sr_by_model.get(spec.name, "N/A")} Hz</p>
237
+ <details>
238
+ <summary>WavVAE config</summary>
239
+ <pre>{html_escape(w_short)}</pre>
240
+ </details>
241
+ <details>
242
+ <summary>ZFlowAE config</summary>
243
+ <pre>{html_escape(z_short)}</pre>
244
+ </details>
245
+ <p><b>Decode</b>: num_steps={spec.decode.num_steps}, cfg={spec.decode.cfg}</p>
246
+ </div>
247
+ """
248
+ model_cards.append(card)
249
+
250
+ # Table header
251
+ th = "<th>Input</th><th>Original</th>" + "".join(
252
+ f"<th>{html_escape(m.name)}</th>" for m in models
253
+ )
254
+
255
+ # Rows
256
+ rows = []
257
+ for af in audio_files:
258
+ base = af.stem
259
+ orig_rel = f"original/{html_escape(af.name)}"
260
+ tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
261
+ for m in models:
262
+ rec_rel = f"recon/{html_escape(m.name)}/{html_escape(base)}.wav"
263
+ tds.append(f"<td>{player(rec_rel)}</td>")
264
+ rows.append("<tr>" + "".join(tds) + "</tr>")
265
+
266
+ # Simple CSS to keep it clean
267
+ css = """
268
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
269
+ h1 { margin-bottom: 0.2rem; }
270
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
271
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
272
+ table { border-collapse: collapse; width: 100%; }
273
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
274
+ th { background: #fafafa; position: sticky; top: 0; }
275
+ audio { width: 260px; }
276
+ """
277
+
278
+ html = f"""<!doctype html>
279
+ <html>
280
+ <head>
281
+ <meta charset="utf-8"/>
282
+ <title>Codec Comparison</title>
283
+ <style>{css}</style>
284
+ </head>
285
+ <body>
286
+ <h1>Codec Comparison</h1>
287
+ <p>This page compares reconstructions across model checkpoints. Click play in each cell.</p>
288
+
289
+ <h2>Models</h2>
290
+ <div class="cards">
291
+ {"".join(model_cards)}
292
+ </div>
293
+
294
+ <h2>Audio</h2>
295
+ <table>
296
+ <thead><tr>{th}</tr></thead>
297
+ <tbody>
298
+ {"".join(rows)}
299
+ </tbody>
300
+ </table>
301
+ </body>
302
+ </html>
303
+ """
304
+ out = output_dir / "index.html"
305
+ out.write_text(html, encoding="utf-8")
306
+ return str(out)
307
+
308
+
309
+ # -------------------------
310
+ # Main
311
+ # -------------------------
312
+
313
+
314
+ def main():
315
+ p = argparse.ArgumentParser(
316
+ description="Compare Z-Codec configurations and generate a static HTML page."
317
+ )
318
+ p.add_argument(
319
+ "--manifest",
320
+ type=str,
321
+ required=True,
322
+ help="JSON file listing model pairs. See docstring in parse_manifest().",
323
+ )
324
+ p.add_argument(
325
+ "--audio", type=str, nargs="+", required=True, help="List of input audio files."
326
+ )
327
+ p.add_argument(
328
+ "--out",
329
+ type=str,
330
+ default="codec_compare_out",
331
+ help="Output directory for reconstructions and HTML.",
332
+ )
333
+ p.add_argument(
334
+ "--device",
335
+ type=str,
336
+ default="cuda",
337
+ help="Device to run inference on (cuda or cpu).",
338
+ )
339
+ p.add_argument(
340
+ "--force",
341
+ action="store_true",
342
+ help="Recompute even if target wav already exists.",
343
+ )
344
+ args = p.parse_args()
345
+
346
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
347
+ out_dir = Path(args.out)
348
+ orig_dir = out_dir / "original"
349
+ recon_dir = out_dir / "recon"
350
+ orig_dir.mkdir(parents=True, exist_ok=True)
351
+ recon_dir.mkdir(parents=True, exist_ok=True)
352
+
353
+ # Parse models
354
+ specs = parse_manifest(args.manifest)
355
+ if not specs:
356
+ print("No models in manifest.", file=sys.stderr)
357
+ sys.exit(1)
358
+
359
+ # Load models
360
+ loaded: Dict[str, Dict[str, Any]] = {}
361
+ sr_by_model: Dict[str, int] = {}
362
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
363
+ zflow_cfg: Dict[str, Dict[str, Any]] = {}
364
+
365
+ for spec in specs:
366
+ print(f"[Load] {spec.name}")
367
+ wavvae, zflowae, target_sr = load_model_pair(spec, device)
368
+ loaded[spec.name] = {"wavvae": wavvae, "zflowae": zflowae, "sr": target_sr}
369
+ sr_by_model[spec.name] = target_sr
370
+ wavvae_cfg[spec.name] = read_config_any(spec.wavvae_dir)
371
+ zflow_cfg[spec.name] = read_config_any(spec.zflowae_dir)
372
+
373
+ # Process audio files
374
+ audio_files = [Path(a) for a in args.audio]
375
+ for af in audio_files:
376
+ if not af.exists():
377
+ print(f"[Skip] Missing: {af}", file=sys.stderr)
378
+ continue
379
+
380
+ # Copy the original into the output folder for direct playback in the page.
381
+ # Keep the native sample rate and only downmix to mono;
382
+ # every input is re-saved as 16-bit WAV so the browser can play it,
383
+ # regardless of its source format.
384
+ out_orig = orig_dir / af.name
385
+ if args.force or not out_orig.exists():
386
+ # Load and resave as wav to ensure browser-compat
387
+ wav, sr = ta_load(str(af))
388
+ # make it mono for fair listening
389
+ wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
390
+ save_wav(out_orig.with_suffix(".wav"), wav_mono, sr)
391
+ # keep the name consistent in the HTML (use .wav)
392
+ af = af.with_suffix(".wav")
393
+ # point out_orig at the .wav file that was actually written
394
+ if out_orig.suffix != ".wav":
395
+ # Clean: ensure HTML references the .wav filename
396
+ out_orig = out_orig.with_suffix(".wav")
397
+
398
+ # For each model, run full pipeline and save
399
+ base = af.stem
400
+ # Reload the saved original .wav so every model starts from the same source
401
+ wav0, sr0 = ta_load(str(out_orig if out_orig.exists() else orig_dir / af.name))
402
+ # Make mono only once; resample per-model to each target SR
403
+ if wav0.size(0) > 1:
404
+ wav0 = wav0.mean(dim=0, keepdim=True)
405
+
406
+ for spec in specs:
407
+ mname = spec.name
408
+ target_sr = sr_by_model[mname]
409
+ # resample to model's SR
410
+ if sr0 != target_sr:
411
+ wav_mono = ta_resample(wav0, sr0, target_sr)
412
+ else:
413
+ wav_mono = wav0
414
+
415
+ # reconstruct
416
+ out_path = recon_dir / mname / f"{sanitize_name(base)}.wav"
417
+ if args.force or not out_path.exists():
418
+ print(f"[Reconstruct] {mname} ← {base}")
419
+ wavvae = loaded[mname]["wavvae"]
420
+ zflowae = loaded[mname]["zflowae"]
421
+ wav_hat = reconstruct_full_pipeline(
422
+ wav_mono, target_sr, wavvae, zflowae, spec.decode, device
423
+ )
424
+ save_wav(out_path, wav_hat.unsqueeze(0), target_sr)
425
+
426
+ # Build HTML
427
+ # Rebuild the list of files actually present in original/ (use .wav names)
428
+ actual_audio = sorted([p for p in (orig_dir).glob("*.wav")])
429
+ html_path = make_html(
430
+ out_dir,
431
+ actual_audio,
432
+ specs,
433
+ sr_by_model,
434
+ wavvae_cfg,
435
+ zflow_cfg,
436
+ )
437
+ print(f"\nDone. Open: {html_path}")
438
+
439
+
440
+ if __name__ == "__main__":
441
+ main()
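+
+ # Example invocation (an illustrative sketch; file names below are placeholders
+ # and the manifest schema is defined by parse_manifest() earlier in this script):
+ #
+ #   python codec/scripts/compare_codecs.py \
+ #       --manifest models.json \
+ #       --audio samples/a.wav samples/b.flac \
+ #       --out codec_compare_out \
+ #       --device cuda
+ #
+ # Each manifest entry is expected to name a model and point to its WavVAE and
+ # ZFlowAE checkpoint directories plus decode settings (num_steps, cfg).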
codec/scripts/compare_wavvae.py ADDED
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torchaudio
11
+ from torchaudio import load as ta_load
12
+ from torchaudio.functional import resample as ta_resample
13
+ from zcodec.models import WavVAE
14
+
15
+ # -------------------------
16
+ # Data structures
17
+ # -------------------------
18
+
19
+
20
+ @dataclass
21
+ class WavVaeSpec:
22
+ name: str
23
+ wavvae_dir: str
24
+
25
+
26
+ # -------------------------
27
+ # Utilities
28
+ # -------------------------
29
+
30
+
31
+ def load_json_if_exists(path: Path) -> Optional[Dict[str, Any]]:
32
+ if path.is_file():
33
+ try:
34
+ return json.load(path.open("r", encoding="utf-8"))
35
+ except Exception:
36
+ return None
37
+ return None
38
+
39
+
40
+ def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
41
+ cand = [
42
+ Path(checkpoint_dir) / "config.json",
43
+ Path(checkpoint_dir) / "model_config.json",
44
+ Path(checkpoint_dir) / "config.yaml", # shown as path only
45
+ ]
46
+ for p in cand:
47
+ if p.exists():
48
+ if p.suffix == ".json":
49
+ j = load_json_if_exists(p)
50
+ if j is not None:
51
+ return j
52
+ else:
53
+ return {"_config_file": str(p)}
54
+ return {}
55
+
56
+
57
+ def sanitize_name(s: str) -> str:
58
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
59
+
60
+
61
+ def ensure_mono_and_resample(
62
+ wav: torch.Tensor, sr: int, target_sr: int
63
+ ) -> Tuple[torch.Tensor, int]:
64
+ if wav.ndim != 2:
65
+ raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
66
+ if wav.size(0) > 1:
67
+ wav = wav.mean(dim=0, keepdim=True)
68
+ if sr != target_sr:
69
+ wav = ta_resample(wav, sr, target_sr)
70
+ sr = target_sr
71
+ return wav.to(torch.float32), sr
72
+
73
+
74
+ def save_wav(path: Path, wav: torch.Tensor, sr: int):
75
+ path.parent.mkdir(parents=True, exist_ok=True)
76
+ if wav.ndim == 1:
77
+ wav = wav.unsqueeze(0)
78
+ wav = wav.clamp(-1, 1).contiguous().cpu()
79
+ torchaudio.save(
80
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
81
+ )
82
+
83
+
84
+ def read_audio_manifest(txt_path: str) -> List[Path]:
85
+ lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
86
+ files = [
87
+ Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
88
+ ]
89
+ return files
90
+
91
+
92
+ def html_escape(s: str) -> str:
93
+ return (
94
+ s.replace("&", "&amp;")
95
+ .replace("<", "&lt;")
96
+ .replace(">", "&gt;")
97
+ .replace('"', "&quot;")
98
+ .replace("'", "&#39;")
99
+ )
100
+
101
+
102
+ def make_html(
103
+ output_dir: Path,
104
+ audio_files: List[Path],
105
+ specs: List[WavVaeSpec],
106
+ sr_by_model: Dict[str, int],
107
+ wavvae_cfg: Dict[str, Dict[str, Any]],
108
+ ) -> str:
109
+ def player(src_rel: str) -> str:
110
+ return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'
111
+
112
+ # cards
113
+ cards = []
114
+ for s in specs:
115
+ cfg = wavvae_cfg.get(s.name, {})
116
+ cfg_short = json.dumps(cfg if cfg else {"_": "no JSON config found"}, indent=2)[
117
+ :1200
118
+ ]
119
+ card = f"""
120
+ <div class="model-card">
121
+ <h3>{html_escape(s.name)}</h3>
122
+ <p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
123
+ <details><summary>WavVAE config</summary><pre>{html_escape(cfg_short)}</pre></details>
124
+ </div>
125
+ """
126
+ cards.append(card)
127
+
128
+ css = """
129
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
130
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
131
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
132
+ table { border-collapse: collapse; width: 100%; }
133
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
134
+ th { background: #fafafa; position: sticky; top: 0; }
135
+ audio { width: 260px; }
136
+ """
137
+
138
+ th = "<th>Input</th><th>Original</th>" + "".join(
139
+ f"<th>{html_escape(s.name)}</th>" for s in specs
140
+ )
141
+ rows = []
142
+ for af in audio_files:
143
+ base = af.stem
144
+ orig_rel = f"original/{html_escape(af.name)}"
145
+ tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
146
+ for s in specs:
147
+ rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
148
+ tds.append(f"<td>{player(rec_rel)}</td>")
149
+ rows.append("<tr>" + "".join(tds) + "</tr>")
150
+
151
+ html = f"""<!doctype html>
152
+ <html>
153
+ <head><meta charset="utf-8"/><title>WavVAE Comparison</title><style>{css}</style></head>
154
+ <body>
155
+ <h1>WavVAE Comparison</h1>
156
+ <div class="cards">{"".join(cards)}</div>
157
+ <table>
158
+ <thead><tr>{th}</tr></thead>
159
+ <tbody>{"".join(rows)}</tbody>
160
+ </table>
161
+ </body>
162
+ </html>
163
+ """
164
+ out = output_dir / "index.html"
165
+ out.write_text(html, encoding="utf-8")
166
+ return str(out)
167
+
168
+
169
+ # -------------------------
170
+ # Core
171
+ # -------------------------
172
+
173
+
174
+ @torch.inference_mode()
175
+ def reconstruct_wavvae(
176
+ wav_mono: torch.Tensor, wavvae: WavVAE, device: str
177
+ ) -> torch.Tensor:
178
+ x = wav_mono.to(device) # (1,T)
179
+ z = wavvae.encode(x)
180
+ wav_hat = wavvae.decode(z) # (1,1,T)
181
+ return wav_hat.squeeze(0).squeeze(0).detach()
182
+
183
+
184
+ def parse_models_manifest(path: str) -> List[WavVaeSpec]:
185
+ """
186
+ JSON list of:
187
+ {"name": "...", "wavvae": "/path/to/WavVAE_dir"}
188
+ """
189
+ raw = json.loads(Path(path).read_text(encoding="utf-8"))
190
+ specs = []
191
+ for it in raw:
192
+ specs.append(WavVaeSpec(name=it["name"], wavvae_dir=it["wavvae"]))
193
+ return specs
194
+
195
+
196
+ def main():
197
+ ap = argparse.ArgumentParser(
198
+ description="Compare WavVAE checkpoints and generate a static HTML page."
199
+ )
200
+ ap.add_argument("--models", required=True, help="JSON manifest of WavVAE models.")
201
+ ap.add_argument(
202
+ "--audio_manifest", required=True, help="TXT file: one audio path per line."
203
+ )
204
+ ap.add_argument("--out", default="compare_wavvae_out")
205
+ ap.add_argument("--device", default="cuda")
206
+ ap.add_argument("--force", action="store_true")
207
+ args = ap.parse_args()
208
+
209
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
210
+ out_dir = Path(args.out)
211
+ (out_dir / "original").mkdir(parents=True, exist_ok=True)
212
+ recon_dir = out_dir / "recon"
213
+ recon_dir.mkdir(parents=True, exist_ok=True)
214
+
215
+ specs = parse_models_manifest(args.models)
216
+ if not specs:
217
+ print("No models.", file=sys.stderr)
218
+ sys.exit(1)
219
+
220
+ # load models
221
+ wavvae_by_name: Dict[str, WavVAE] = {}
222
+ sr_by_model: Dict[str, int] = {}
223
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
224
+ for s in specs:
225
+ print(f"[Load] {s.name}")
226
+ w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
227
+ wavvae_by_name[s.name] = w
228
+ sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
229
+ wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
230
+
231
+ audio_paths = read_audio_manifest(args.audio_manifest)
232
+ # normalize originals to wav+mono (browser-friendly); keep native sr for original column
233
+ actual_audio = []
234
+ for audio_path in audio_paths:
235
+ if not audio_path.exists():
236
+ print(f"[Skip missing] {audio_path}", file=sys.stderr)
237
+ continue
238
+ wav, sr = ta_load(str(audio_path))
239
+ wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
240
+ out_orig = out_dir / "original" / (audio_path.stem + ".wav")
241
+ if args.force or not out_orig.exists():
242
+ save_wav(out_orig, wav_mono, sr)
243
+ actual_audio.append(out_orig)
244
+
245
+ # recon per model
246
+ for out_orig in actual_audio:
247
+ wav0, sr0 = ta_load(str(out_orig))
248
+ if wav0.size(0) > 1:
249
+ wav0 = wav0.mean(dim=0, keepdim=True)
250
+ for s in specs:
251
+ target_sr = sr_by_model[s.name]
252
+ wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
253
+ out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
254
+ if args.force or not out_path.exists():
255
+ print(f"[Reconstruct] {s.name} ← {out_orig.name}")
256
+ wav_hat = reconstruct_wavvae(wav_in, wavvae_by_name[s.name], device)
257
+ save_wav(out_path, wav_hat, target_sr)
258
+
259
+ html_path = make_html(out_dir, actual_audio, specs, sr_by_model, wavvae_cfg)
260
+ print(f"Done. Open: {html_path}")
261
+
262
+
263
+ if __name__ == "__main__":
264
+ main()
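+
+ # Example inputs (an illustrative sketch; paths are placeholders):
+ #
+ # models.json
+ #   [
+ #     {"name": "wavvae_a", "wavvae": "/path/to/WavVAE_dir_a"},
+ #     {"name": "wavvae_b", "wavvae": "/path/to/WavVAE_dir_b"}
+ #   ]
+ #
+ # audio.txt (one audio path per line; lines starting with '#' are skipped)
+ #   /data/samples/utt1.wav
+ #   /data/samples/utt2.flac
+ #
+ #   python codec/scripts/compare_wavvae.py --models models.json \
+ #       --audio_manifest audio.txt --out compare_wavvae_out --device cuda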
codec/scripts/compare_zcodec.py ADDED
@@ -0,0 +1,312 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import torch
10
+ import torchaudio
11
+ from torchaudio import load as ta_load
12
+ from torchaudio.functional import resample as ta_resample
13
+ from zcodec.models import WavVAE, ZFlowAutoEncoder
14
+
15
+ # -------------------------
16
+ # Data structures
17
+ # -------------------------
18
+
19
+
20
+ @dataclass
21
+ class DecodeParams:
22
+ num_steps: int = 10
23
+ cfg: float = 2.0
24
+
25
+
26
+ @dataclass
27
+ class StackSpec:
28
+ name: str
29
+ wavvae_dir: str
30
+ zflowae_dir: str
31
+ decode: DecodeParams
32
+
33
+
34
+ # -------------------------
35
+ # Utilities (same helpers as in compare_wavvae.py)
36
+ # -------------------------
37
+
38
+
39
+ def load_json_if_exists(path: Path):
40
+ if path.is_file():
41
+ try:
42
+ return json.load(path.open("r", encoding="utf-8"))
43
+ except Exception:
44
+ return None
45
+ return None
46
+
47
+
48
+ def read_config_any(checkpoint_dir: str) -> Dict[str, Any]:
49
+ cand = [
50
+ Path(checkpoint_dir) / "config.json",
51
+ Path(checkpoint_dir) / "model_config.json",
52
+ Path(checkpoint_dir) / "config.yaml",
53
+ ]
54
+ for p in cand:
55
+ if p.exists():
56
+ if p.suffix == ".json":
57
+ j = load_json_if_exists(p)
58
+ if j is not None:
59
+ return j
60
+ else:
61
+ return {"_config_file": str(p)}
62
+ return {}
63
+
64
+
65
+ def sanitize_name(s: str) -> str:
66
+ return "".join(c if c.isalnum() or c in "-_." else "_" for c in s)
67
+
68
+
69
+ def ensure_mono_and_resample(
70
+ wav: torch.Tensor, sr: int, target_sr: int
71
+ ) -> Tuple[torch.Tensor, int]:
72
+ if wav.ndim != 2:
73
+ raise ValueError(f"Expected (C,T), got {tuple(wav.shape)}")
74
+ if wav.size(0) > 1:
75
+ wav = wav.mean(dim=0, keepdim=True)
76
+ if sr != target_sr:
77
+ wav = ta_resample(wav, sr, target_sr)
78
+ sr = target_sr
79
+ return wav.to(torch.float32), sr
80
+
81
+
82
+ def save_wav(path: Path, wav: torch.Tensor, sr: int):
83
+ path.parent.mkdir(parents=True, exist_ok=True)
84
+ if wav.ndim == 1:
85
+ wav = wav.unsqueeze(0)
86
+ wav = wav.clamp(-1, 1).contiguous().cpu()
87
+ torchaudio.save(
88
+ str(path), wav, sample_rate=sr, encoding="PCM_S", bits_per_sample=16
89
+ )
90
+
91
+
92
+ def read_audio_manifest(txt_path: str) -> List[Path]:
93
+ lines = Path(txt_path).read_text(encoding="utf-8").splitlines()
94
+ return [
95
+ Path(l.strip()) for l in lines if l.strip() and not l.strip().startswith("#")
96
+ ]
97
+
98
+
99
+ def html_escape(s: str) -> str:
100
+ return (
101
+ s.replace("&", "&amp;")
102
+ .replace("<", "&lt;")
103
+ .replace(">", "&gt;")
104
+ .replace('"', "&quot;")
105
+ .replace("'", "&#39;")
106
+ )
107
+
108
+
109
+ def make_html(
110
+ output_dir: Path,
111
+ audio_files: List[Path],
112
+ specs: List[StackSpec],
113
+ sr_by_model: Dict[str, int],
114
+ wavvae_cfg: Dict[str, Dict[str, Any]],
115
+ zflow_cfg: Dict[str, Dict[str, Any]],
116
+ ) -> str:
117
+ def player(src_rel: str) -> str:
118
+ return f'<audio controls preload="none" src="{html_escape(src_rel)}"></audio>'
119
+
120
+ cards = []
121
+ for s in specs:
122
+ wcfg = wavvae_cfg.get(s.name, {})
123
+ zcfg = zflow_cfg.get(s.name, {})
124
+ w_short = json.dumps(wcfg if wcfg else {"_": "no JSON config found"}, indent=2)[
125
+ :1200
126
+ ]
127
+ z_short = json.dumps(zcfg if zcfg else {"_": "no JSON config found"}, indent=2)[
128
+ :1200
129
+ ]
130
+ card = f"""
131
+ <div class="model-card">
132
+ <h3>{html_escape(s.name)}</h3>
133
+ <p><b>Sample rate</b>: {sr_by_model.get(s.name, "N/A")} Hz</p>
134
+ <p><b>Decode</b>: steps={s.decode.num_steps}, cfg={s.decode.cfg}</p>
135
+ <details><summary>WavVAE config</summary><pre>{html_escape(w_short)}</pre></details>
136
+ <details><summary>ZFlowAE config</summary><pre>{html_escape(z_short)}</pre></details>
137
+ </div>
138
+ """
139
+ cards.append(card)
140
+
141
+ css = """
142
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif; padding: 20px; }
143
+ .cards { display: grid; grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); gap: 12px; margin-bottom: 18px; }
144
+ .model-card { border: 1px solid #ddd; border-radius: 12px; padding: 12px; }
145
+ table { border-collapse: collapse; width: 100%; }
146
+ th, td { border: 1px solid #eee; padding: 8px; vertical-align: top; }
147
+ th { background: #fafafa; position: sticky; top: 0; }
148
+ audio { width: 260px; }
149
+ """
150
+
151
+ th = "<th>Input</th><th>Original</th>" + "".join(
152
+ f"<th>{html_escape(s.name)}</th>" for s in specs
153
+ )
154
+ rows = []
155
+ for af in audio_files:
156
+ base = af.stem
157
+ orig_rel = f"original/{html_escape(af.name)}"
158
+ tds = [f"<td>{html_escape(base)}</td>", f"<td>{player(orig_rel)}</td>"]
159
+ for s in specs:
160
+ rec_rel = f"recon/{html_escape(s.name)}/{html_escape(base)}.wav"
161
+ tds.append(f"<td>{player(rec_rel)}</td>")
162
+ rows.append("<tr>" + "".join(tds) + "</tr>")
163
+
164
+ html = f"""<!doctype html>
165
+ <html>
166
+ <head><meta charset="utf-8"/><title>Stacked Codec Comparison</title><style>{css}</style></head>
167
+ <body>
168
+ <h1>WavVAE + ZFlowAE Comparison</h1>
169
+ <div class="cards">{"".join(cards)}</div>
170
+ <table>
171
+ <thead><tr>{th}</tr></thead>
172
+ <tbody>{"".join(rows)}</tbody>
173
+ </table>
174
+ </body>
175
+ </html>
176
+ """
177
+ out = output_dir / "index.html"
178
+ out.write_text(html, encoding="utf-8")
179
+ return str(out)
180
+
181
+
182
+ # -------------------------
183
+ # Core
184
+ # -------------------------
185
+
186
+
187
+ @torch.inference_mode()
188
+ def reconstruct_stack(
189
+ wav_mono: torch.Tensor,
190
+ wavvae: WavVAE,
191
+ zflow: ZFlowAutoEncoder,
192
+ steps: int,
193
+ cfg: float,
194
+ device: str,
195
+ ) -> torch.Tensor:
196
+ x = wav_mono.to(device) # (1,T)
197
+ z = wavvae.encode(x) # high-framerate latents
198
+ y, _ = zflow.encode(z) # compressed latents
199
+ z_hat = zflow.decode(y, num_steps=steps, cfg=cfg)
200
+ wav_hat = wavvae.decode(z_hat) # (1,1,T)
201
+ return wav_hat.squeeze(0).squeeze(0).detach()
202
+
203
+
204
+ def parse_models_manifest(path: str) -> List[StackSpec]:
205
+ """
206
+ JSON list of:
207
+ {
208
+ "name": "...",
209
+ "wavvae": "/path/to/WavVAE_dir",
210
+ "zflowae": "/path/to/ZFlowAE_dir",
211
+ "decode": {"num_steps": 10, "cfg": 2.0}
212
+ }
213
+ """
214
+ raw = json.loads(Path(path).read_text(encoding="utf-8"))
215
+ specs = []
216
+ for it in raw:
217
+ d = it.get("decode", {})
218
+ specs.append(
219
+ StackSpec(
220
+ name=it["name"],
221
+ wavvae_dir=it["wavvae"],
222
+ zflowae_dir=it["zflowae"],
223
+ decode=DecodeParams(
224
+ num_steps=int(d.get("num_steps", 10)), cfg=float(d.get("cfg", 2.0))
225
+ ),
226
+ )
227
+ )
228
+ return specs
229
+
230
+
231
+ def main():
232
+ ap = argparse.ArgumentParser(
233
+ description="Compare WavVAE+ZFlowAE stacks and generate a static HTML page."
234
+ )
235
+ ap.add_argument("--models", required=True, help="JSON manifest of stacks.")
236
+ ap.add_argument(
237
+ "--audio_manifest", required=True, help="TXT file: one audio path per line."
238
+ )
239
+ ap.add_argument("--out", default="compare_stack_out")
240
+ ap.add_argument("--device", default="cuda")
241
+ ap.add_argument("--force", action="store_true")
242
+ args = ap.parse_args()
243
+
244
+ device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"
245
+ out_dir = Path(args.out)
246
+ (out_dir / "original").mkdir(parents=True, exist_ok=True)
247
+ recon_dir = out_dir / "recon"
248
+ recon_dir.mkdir(parents=True, exist_ok=True)
249
+
250
+ specs = parse_models_manifest(args.models)
251
+ if not specs:
252
+ print("No models.", file=sys.stderr)
253
+ sys.exit(1)
254
+
255
+ # load models
256
+ wavvae_by_name: Dict[str, WavVAE] = {}
257
+ zflow_by_name: Dict[str, ZFlowAutoEncoder] = {}
258
+ sr_by_model: Dict[str, int] = {}
259
+ wavvae_cfg: Dict[str, Dict[str, Any]] = {}
260
+ zflow_cfg: Dict[str, Dict[str, Any]] = {}
261
+ for s in specs:
262
+ print(f"[Load] {s.name}")
263
+ w = WavVAE.from_pretrained_local(s.wavvae_dir).to(device)
264
+ z = ZFlowAutoEncoder.from_pretrained_local(s.zflowae_dir).to(device)
265
+ wavvae_by_name[s.name] = w
266
+ zflow_by_name[s.name] = z
267
+ sr_by_model[s.name] = int(getattr(w, "sampling_rate", 24000))
268
+ wavvae_cfg[s.name] = read_config_any(s.wavvae_dir)
269
+ zflow_cfg[s.name] = read_config_any(s.zflowae_dir)
270
+
271
+ audio_paths = read_audio_manifest(args.audio_manifest)
272
+
273
+ actual_audio = []
274
+ for audio_path in audio_paths:
275
+ if not audio_path.exists():
276
+ print(f"[Skip missing] {audio_path}", file=sys.stderr)
277
+ continue
278
+ wav, sr = ta_load(str(audio_path))
279
+ wav_mono, sr = ensure_mono_and_resample(wav, sr, sr)
280
+ out_orig = out_dir / "original" / (audio_path.stem + ".wav")
281
+ if args.force or not out_orig.exists():
282
+ save_wav(out_orig, wav_mono, sr)
283
+ actual_audio.append(out_orig)
284
+
285
+ for out_orig in actual_audio:
286
+ wav0, sr0 = ta_load(str(out_orig))
287
+ if wav0.size(0) > 1:
288
+ wav0 = wav0.mean(dim=0, keepdim=True)
289
+ for s in specs:
290
+ target_sr = sr_by_model[s.name]
291
+ wav_in = ta_resample(wav0, sr0, target_sr) if sr0 != target_sr else wav0
292
+ out_path = recon_dir / s.name / f"{sanitize_name(out_orig.stem)}.wav"
293
+ if args.force or not out_path.exists():
294
+ print(f"[Reconstruct] {s.name} ← {out_orig.name}")
295
+ wav_hat = reconstruct_stack(
296
+ wav_in,
297
+ wavvae_by_name[s.name],
298
+ zflow_by_name[s.name],
299
+ s.decode.num_steps,
300
+ s.decode.cfg,
301
+ device,
302
+ )
303
+ save_wav(out_path, wav_hat, target_sr)
304
+
305
+ html_path = make_html(
306
+ out_dir, actual_audio, specs, sr_by_model, wavvae_cfg, zflow_cfg
307
+ )
308
+ print(f"Done. Open: {html_path}")
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()
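+
+ # Example models manifest (an illustrative sketch; paths are placeholders):
+ #
+ # stacks.json
+ #   [
+ #     {
+ #       "name": "stack_a",
+ #       "wavvae": "/path/to/WavVAE_dir",
+ #       "zflowae": "/path/to/ZFlowAE_dir",
+ #       "decode": {"num_steps": 10, "cfg": 2.0}
+ #     }
+ #   ]
+ #
+ #   python codec/scripts/compare_zcodec.py --models stacks.json \
+ #       --audio_manifest audio.txt --out compare_stack_out --device cuda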
codec/scripts/compute_stats.py ADDED
@@ -0,0 +1,76 @@
1
+ import argparse
2
+ import random
3
+
4
+ import torch
5
+ from safetensors.torch import safe_open, save_file
6
+ from tqdm import tqdm
7
+
8
+
9
+ def load_tensor(path: str, key: str = "embedding") -> torch.Tensor:
10
+ with safe_open(path, framework="pt", device="cpu") as f:
11
+ return f.get_tensor(key)
12
+
13
+
14
+ def compute_global_stats(file_list, key="embedding", length_weighted=True):
15
+ sum_all = None
16
+ sum_sq_all = None
17
+ count_all = 0
18
+
19
+ for path in tqdm(file_list, desc="Computing stats"):
20
+ tensor = load_tensor(path, key) # shape: [B, T, D]
21
+ flat = tensor.reshape(-1, tensor.shape[-1]) # [B*T, D]
22
+
23
+ sum_ = flat.sum(dim=0) # [D]
24
+ sum_sq = (flat**2).sum(dim=0) # [D]
25
+ count = flat.shape[0] # B*T
26
+
27
+ if sum_all is None:
28
+ sum_all = sum_
29
+ sum_sq_all = sum_sq
30
+ else:
31
+ sum_all += sum_
32
+ sum_sq_all += sum_sq
33
+
34
+ count_all += count
35
+
36
+ mean = sum_all / count_all
37
+ var = sum_sq_all / count_all - mean**2
38
+ std = torch.sqrt(torch.clamp(var, min=1e-8))
39
+
40
+ return mean, std
41
+
42
+
43
+ def main():
44
+ parser = argparse.ArgumentParser()
45
+ parser.add_argument(
46
+ "filelist", type=str, help="Text file with list of safetensors paths"
47
+ )
48
+ parser.add_argument("output", type=str, help="Path to output stats.safetensors")
49
+ parser.add_argument(
50
+ "--key", type=str, default="audio_z", help="Key of tensor in safetensors file"
51
+ )
52
+ parser.add_argument(
53
+ "--max-files", type=int, default=None, help="Max number of files to process"
54
+ )
55
+ parser.add_argument(
56
+ "--seed", type=int, default=42, help="Random seed for shuffling"
57
+ )
58
+
59
+ args = parser.parse_args()
60
+
61
+ with open(args.filelist) as f:
62
+ files = [line.strip() for line in f if line.strip()]
63
+
64
+ if args.max_files:
65
+ random.seed(args.seed)
66
+ files = random.sample(files, k=min(args.max_files, len(files)))
67
+
68
+ mean, std = compute_global_stats(files, key=args.key)
69
+
70
+ save_file({"mean": mean, "std": std}, args.output)
71
+ print(f"✅ Saved to {args.output}")
72
+ print("Example mean/std:", mean[:5], std[:5])
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
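+
+ # Illustrative follow-up (a sketch, assuming the stats are used to standardize
+ # latents downstream; the "mean"/"std" keys match what this script saves):
+ #
+ #   from safetensors.torch import safe_open
+ #   with safe_open("stats.safetensors", framework="pt", device="cpu") as f:
+ #       mean, std = f.get_tensor("mean"), f.get_tensor("std")
+ #   z_norm = (z - mean) / std   # z: [B, T, D]; broadcasts over the last dim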
codec/scripts/compute_wer.py ADDED
@@ -0,0 +1,48 @@
1
+ import argparse
2
+ import json
3
+ import string
4
+
5
+ from jiwer import wer
6
+
7
+
8
+ def normalize_text(text: str) -> str:
9
+ """
10
+ Lowercase and remove punctuation from a string.
11
+
12
+ Args:
13
+ text (str): Input string
14
+
15
+ Returns:
16
+ str: Normalized string
17
+ """
18
+ # Lowercase
19
+ text = text.lower()
20
+ # Remove punctuation
21
+ text = text.translate(str.maketrans("", "", string.punctuation))
22
+ return text
23
+
24
+
25
+ def load_transcripts(jsonl_path):
26
+ originals = []
27
+ reconstructions = []
28
+ with open(jsonl_path, "r", encoding="utf-8") as f:
29
+ for line in f:
30
+ data = json.loads(line)
31
+ originals.append(data["original_text"])
32
+ reconstructions.append(data["reconstructed_text"])
33
+ return originals, reconstructions
34
+
35
+
36
+ def main(args):
37
+ originals, reconstructions = load_transcripts(args.jsonl)
+ # normalize each transcript string individually (normalize_text expects a str, not a list)
+ originals = [normalize_text(t) for t in originals]
+ reconstructions = [normalize_text(t) for t in reconstructions]
38
+ score = wer(originals, reconstructions)
39
+ print(f"WER: {score:.3%}")
40
+
41
+
42
+ if __name__ == "__main__":
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument(
45
+ "--jsonl", type=str, required=True, help="Path to the transcript JSONL file"
46
+ )
47
+ args = parser.parse_args()
48
+ main(args)
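+
+ # Expected JSONL layout, one object per line (field values are illustrative);
+ # this matches what eval_asr.py writes:
+ #   {"file": "...", "original_text": "hello world", "reconstructed_text": "hello word"}
+ #
+ #   python codec/scripts/compute_wer.py --jsonl transcripts.jsonl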
codec/scripts/compute_wer_from_refs.py ADDED
@@ -0,0 +1,64 @@
1
+ import argparse
2
+ import json
3
+ import string
4
+ from pathlib import Path
5
+
6
+ from jiwer import cer, wer
7
+
8
+
9
+ def normalize_text(text: str) -> str:
10
+ """
11
+ Lowercase and remove punctuation from a string.
12
+
13
+ Args:
14
+ text (str): Input string
15
+
16
+ Returns:
17
+ str: Normalized string
18
+ """
19
+ # Lowercase
20
+ text = text.lower()
21
+ # Remove punctuation
22
+ text = text.translate(str.maketrans("", "", string.punctuation))
23
+ return text
24
+
25
+
26
+ def load_jsonl_dict(path):
27
+ transcripts = {}
28
+ with open(path, "r", encoding="utf-8") as f:
29
+ for line in f:
30
+ data = json.loads(line)
31
+ transcripts[Path(data["file"]).name] = data["transcript"]
32
+ return transcripts
33
+
34
+
35
+ def main(args):
36
+ ref_dict = load_jsonl_dict(args.reference)
37
+ hyp_dict = load_jsonl_dict(args.hypothesis)
38
+
39
+ common_files = set(ref_dict.keys()) & set(hyp_dict.keys())
40
+
41
+ if not common_files:
42
+ print("No common files between reference and hypothesis.")
43
+ return
44
+
45
+ refs = [normalize_text(ref_dict[f]) for f in sorted(common_files)]
46
+ hyps = [normalize_text(hyp_dict[f]) for f in sorted(common_files)]
47
+
48
+ cer_score = cer(refs, hyps)
49
+ wer_score = wer(refs, hyps)
50
+ print(f"CER: {cer_score:.3%}")
51
+ print(f"WER: {wer_score:.3%}")
52
+ print(f"Evaluated on {len(common_files)} files.")
53
+
54
+
55
+ if __name__ == "__main__":
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument(
58
+ "--reference", type=str, required=True, help="Path to reference JSONL"
59
+ )
60
+ parser.add_argument(
61
+ "--hypothesis", type=str, required=True, help="Path to hypothesis JSONL"
62
+ )
63
+ args = parser.parse_args()
64
+ main(args)
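+
+ # Expected JSONL layout for both --reference and --hypothesis (values illustrative);
+ # entries are matched by file name, as produced by eval_asr_from_filelist.py:
+ #   {"file": "/data/orig/utt1.wav", "transcript": "hello world"}
+ #
+ #   python codec/scripts/compute_wer_from_refs.py \
+ #       --reference ref_transcripts.jsonl --hypothesis recon_transcripts.jsonl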
codec/scripts/download_expresso.py ADDED
@@ -0,0 +1,10 @@
1
+ from pathlib import Path
+
+ import soundfile as sf
2
+
3
+ from datasets import load_dataset
4
+
5
+ dataset = load_dataset("ylacombe/expresso", split="train")
6
+ print(dataset)
+ # make sure the output directory exists before writing
+ Path("expresso/org").mkdir(parents=True, exist_ok=True)
7
+ for i, x in enumerate(dataset):
8
+ audio = x["audio"]
9
+ wav, sr = audio["array"], audio["sampling_rate"]
10
+ sf.write(f"expresso/org/{i}.wav", wav, sr)
codec/scripts/download_gigaspeech.py ADDED
@@ -0,0 +1,14 @@
1
+ from pathlib import Path
+ from random import sample
2
+
3
+ import soundfile as sf
4
+ from datasets import load_dataset
5
+
6
+ # dataset = load_dataset("keithito/lj_speech", split="train")
7
+ # dataset = load_dataset("parler-tts/mls_eng", split="train")
8
+ dataset = load_dataset("speechcolab/gigaspeech", "xl", split="train", token=True)
9
+ Is = sample(list(range(len(dataset))), k=100000)
10
+ print(dataset)
+ # make sure the output directory exists before writing
+ Path("gigaspeech").mkdir(parents=True, exist_ok=True)
11
+ for i, I in enumerate(Is):
12
+ audio = dataset[I]["audio"]
13
+ wav, sr = audio["array"], audio["sampling_rate"]
14
+ sf.write(f"gigaspeech/{I}.wav", wav, sr)
codec/scripts/download_lj.py ADDED
@@ -0,0 +1,9 @@
1
+ from pathlib import Path
+
+ import soundfile as sf
2
+ from datasets import load_dataset
3
+
4
+ dataset = load_dataset("keithito/lj_speech", split="train")
5
+ print(dataset)
+ # make sure the output directory exists before writing
+ Path("ljspeech").mkdir(parents=True, exist_ok=True)
6
+ for i, x in enumerate(dataset):
7
+ audio = x["audio"]
8
+ wav, sr = audio["array"], audio["sampling_rate"]
9
+ sf.write(f"ljspeech/{i}.wav", wav, sr)
codec/scripts/download_ltts.py ADDED
@@ -0,0 +1,16 @@
1
+ from pathlib import Path
2
+
3
+ import soundfile as sf
4
+
5
+ from datasets import load_dataset
6
+
7
+ dataset = load_dataset("mythicinfinity/libritts", "clean")
8
+ for split in dataset.keys():
9
+ Path(f"libritts/{split}").mkdir(parents=True, exist_ok=True)
10
+ for i, x in enumerate(dataset[split]):
11
+ # audio = x["audio"]
12
+ text = x["text_normalized"]
13
+ # wav, sr = audio["array"], audio["sampling_rate"]
14
+ # sf.write(f"libritts/{split}/{i}.wav", wav, sr)
15
+ with open(f"libritts/{split}/{i}.txt", "w") as f:
16
+ f.write(text)
codec/scripts/download_mlseng10k.py ADDED
@@ -0,0 +1,13 @@
1
+ from pathlib import Path
+ from random import sample
2
+
3
+ import soundfile as sf
4
+ from datasets import load_dataset
5
+
6
+ # dataset = load_dataset("keithito/lj_speech", split="train")
7
+ dataset = load_dataset("parler-tts/mls_eng", split="train")
8
+ Is = sample(list(range(len(dataset))), k=100000)
9
+ print(dataset)
+ # make sure the output directory exists before writing
+ Path("mls10keng").mkdir(parents=True, exist_ok=True)
10
+ for i, I in enumerate(Is):
11
+ audio = dataset[I]["audio"]
12
+ wav, sr = audio["array"], audio["sampling_rate"]
13
+ sf.write(f"mls10keng/{i}.wav", wav, sr)
codec/scripts/eval_asr.py ADDED
@@ -0,0 +1,100 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import nemo.collections.asr as nemo_asr
6
+ import torch
7
+ import yaml
8
+ from jiwer import wer
9
+ from torchaudio import load
10
+ from torchaudio.functional import resample
11
+ from tqdm import tqdm
12
+
13
+ from zcodec.models import WavVAE, ZFlowAutoEncoder
14
+
15
+
16
+ def load_config(config_path):
17
+ with open(config_path, "r") as f:
18
+ return yaml.safe_load(f)
19
+
20
+
21
+ def transcribe(audio: torch.Tensor, asr_model) -> str:
22
+ audio = audio.cpu().numpy(force=True)
23
+ with torch.inference_mode():
24
+ return asr_model.transcribe([audio[0]])[0].text
25
+
26
+
27
+ def main(args):
28
+ config = load_config(args.config)
29
+
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+ # Load models
33
+ wavvae = WavVAE.from_pretrained_local(config["wavvae_ckpt"]).to(device).eval()
34
+ zflowae = (
35
+ ZFlowAutoEncoder.from_pretrained_local(config["zflowae_ckpt"]).to(device).eval()
36
+ )
37
+
38
+ # Load ASR model
39
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(
40
+ model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
41
+ )
42
+
43
+ # Read file list
44
+ with open(config["file_list"], "r") as f:
45
+ wav_files = [line.strip() for line in f if line.strip()]
46
+
47
+ results = []
48
+
49
+ for wav_path in tqdm(wav_files, desc="Processing files"):
50
+ wav, sr = load(wav_path)
51
+ wav = resample(wav, orig_freq=sr, new_freq=wavvae.sampling_rate).to(device)
52
+
53
+ with torch.inference_mode():
54
+ # Transcribe original
55
+ original_text = transcribe(wav, asr_model)
56
+
57
+ # Compress and decompress
58
+ z = wavvae.encode(wav)
59
+ zz, _ = zflowae.encode(z)
60
+ z_hat = zflowae.decode(
61
+ zz, num_steps=config.get("num_steps", 10), cfg=config.get("cfg", 2.0)
62
+ )
63
+ wav_hat = wavvae.decode(z_hat)
64
+
65
+ # Transcribe reconstructed
66
+ reconstructed_text = transcribe(wav_hat, asr_model)
67
+
68
+ results.append(
69
+ {
70
+ "file": wav_path,
71
+ "original_text": original_text,
72
+ "reconstructed_text": reconstructed_text,
73
+ }
74
+ )
75
+
76
+ # Save output
77
+ out_path = Path(config.get("output_jsonl", "transcripts.jsonl"))
78
+ with out_path.open("w") as f:
79
+ for entry in results:
80
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
81
+
82
+ print(f"\nSaved {len(results)} transcript pairs to {out_path}")
83
+
84
+ # Optionally compute WER
85
+ if args.compute_wer:
86
+ original_texts = [r["original_text"] for r in results]
87
+ reconstructed_texts = [r["reconstructed_text"] for r in results]
88
+ score = wer(original_texts, reconstructed_texts)
89
+ print(f"WER: {score:.3%}")
90
+
91
+
92
+ if __name__ == "__main__":
93
+ parser = argparse.ArgumentParser()
94
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
95
+ parser.add_argument(
96
+ "--compute_wer", action="store_true", help="Compute WER after decoding"
97
+ )
98
+ args = parser.parse_args()
99
+
100
+ main(args)
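+
+ # Illustrative YAML config (a sketch; paths are placeholders, keys match the
+ # config.get()/config[...] lookups above):
+ #
+ #   wavvae_ckpt: /path/to/WavVAE_dir
+ #   zflowae_ckpt: /path/to/ZFlowAE_dir
+ #   asr_model: nvidia/parakeet-tdt-0.6b-v2
+ #   file_list: eval_files.txt        # one wav path per line
+ #   num_steps: 10
+ #   cfg: 2.0
+ #   output_jsonl: transcripts.jsonl
+ #
+ #   python codec/scripts/eval_asr.py --config eval_asr.yaml --compute_wer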
codec/scripts/eval_asr_from_filelist.py ADDED
@@ -0,0 +1,60 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import nemo.collections.asr as nemo_asr
6
+ import torch
7
+ import yaml
8
+ from torchaudio import load
9
+ from torchaudio.functional import resample
10
+ from tqdm import tqdm
11
+
12
+
13
+ def load_config(config_path):
14
+ with open(config_path, "r") as f:
15
+ return yaml.safe_load(f)
16
+
17
+
18
+ def transcribe(audio: torch.Tensor, asr_model) -> str:
19
+ audio = audio.cpu().numpy(force=True)
20
+ with torch.inference_mode():
21
+ return asr_model.transcribe([audio[0]])[0].text
22
+
23
+
24
+ def main(args):
25
+ config = load_config(args.config)
26
+
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ # Load ASR model
30
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(
31
+ model_name=config.get("asr_model", "nvidia/parakeet-tdt-0.6b-v2")
32
+ )
33
+
34
+ # Read file list
35
+ with open(config["file_list"], "r") as f:
36
+ wav_files = [line.strip() for line in f if line.strip()]
37
+
38
+ results = []
39
+
40
+ for wav_path in tqdm(wav_files, desc="Transcribing"):
41
+ wav, sr = load(wav_path)
42
+ wav = resample(wav, orig_freq=sr, new_freq=16000).to(device)
43
+
44
+ transcript = transcribe(wav, asr_model)
45
+ results.append({"file": wav_path, "transcript": transcript})
46
+
47
+ # Save output
48
+ out_path = Path(config.get("output_jsonl", "asr_transcripts.jsonl"))
49
+ with out_path.open("w") as f:
50
+ for entry in results:
51
+ f.write(json.dumps(entry, ensure_ascii=False) + "\n")
52
+
53
+ print(f"\nSaved {len(results)} transcripts to {out_path}")
54
+
55
+
56
+ if __name__ == "__main__":
57
+ parser = argparse.ArgumentParser()
58
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
59
+ args = parser.parse_args()
60
+ main(args)
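+
+ # Illustrative YAML config (a sketch; paths are placeholders):
+ #
+ #   asr_model: nvidia/parakeet-tdt-0.6b-v2
+ #   file_list: recon_files.txt       # one wav path per line
+ #   output_jsonl: asr_transcripts.jsonl
+ #
+ #   python codec/scripts/eval_asr_from_filelist.py --config eval_filelist.yaml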