Rthur2003 committed on
Commit
7dc37a8
·
1 Parent(s): 337d9ae

feat: add wav2vec2 classifier for AI music detection with training loop and dataset handling

Browse files
Files changed (1) hide show
  1. app/training/wav2vec2_classifier.py +324 -0
app/training/wav2vec2_classifier.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
wav2vec2 fine-tuned classifier for AI music detection.

Architecture:
    wav2vec2-base (frozen CNN encoder, trainable transformer)
    → Global Average Pooling (768-dim)
    → Linear(768, 256) + ReLU + Dropout(0.3)
    → Linear(256, 1)
    → Binary: AI (1) vs Human (0)

This module defines both the model and the training loop.
A GPU is strongly recommended for training (~2-4 hours on T4 for 10K
samples); training falls back to CPU when CUDA is unavailable.
CPU inference: ~0.5s for 30s audio.
"""
15
+
16
+ from __future__ import annotations
17
+
18
+ import sys
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import Optional
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ from torch.utils.data import Dataset, DataLoader
27
+
28
+
29
@dataclass
class Wav2Vec2Config:
    """Training configuration for the wav2vec2 music classifier."""

    # Checkpoint name handed to ``Wav2Vec2Model.from_pretrained``.
    model_name: str = "facebook/wav2vec2-base"
    # Every clip is truncated or zero-padded to this many seconds.
    max_audio_sec: float = 30.0
    # Sample rate (Hz) that audio files are loaded at.
    sample_rate: int = 16000
    # Width of the hidden layer in the classification head.
    hidden_dim: int = 256
    # Dropout probability applied inside the classification head.
    dropout: float = 0.3
    # Learning rate for the classification head parameter group.
    learning_rate_head: float = 1e-3
    # Learning rate for the wav2vec2 encoder parameter group.
    learning_rate_encoder: float = 1e-5
    # Weight decay passed to AdamW.
    weight_decay: float = 0.01
    # Mini-batch size for both the training and validation loaders.
    batch_size: int = 8
    # Maximum number of training epochs.
    epochs: int = 10
    # Epochs without validation-AUC improvement before early stopping.
    patience: int = 3
    # "auto" selects CUDA when available, otherwise CPU; any other value
    # is passed straight to ``torch.device``.
    device: str = "auto"
45
+
46
+
47
class Wav2Vec2MusicClassifier(nn.Module):
    """
    wav2vec2 with a classification head for AI music detection.

    The CNN feature encoder is frozen (robust low-level audio
    representation). The transformer layers are fine-tuned to
    learn task-specific temporal patterns.
    """

    def __init__(self, config: Wav2Vec2Config | None = None):
        """
        Load the pretrained encoder and build the classification head.

        Args:
            config: Model/training configuration; defaults to
                ``Wav2Vec2Config()`` when omitted.
        """
        super().__init__()
        self.config = config or Wav2Vec2Config()

        # Local import: transformers is only needed once a model is built.
        from transformers import Wav2Vec2Model

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(
            self.config.model_name
        )
        # Freeze CNN encoder, fine-tune transformer
        self.wav2vec2.feature_extractor._freeze_parameters()

        # Take the encoder width from the loaded checkpoint instead of
        # hard-coding 768, so checkpoints with another hidden size still
        # produce a head with matching input features.
        encoder_dim = self.wav2vec2.config.hidden_size

        self.classifier = nn.Sequential(
            nn.Linear(encoder_dim, self.config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.hidden_dim, 1),
        )

    def forward(
        self, input_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_values: (batch, samples) raw audio waveform at 16kHz.

        Returns:
            logits: (batch, 1) classification logit.
            hidden: (batch, encoder_dim) pooled hidden states for
                meta-classifier (768 for wav2vec2-base).
        """
        outputs = self.wav2vec2(input_values)
        # Mean pool over the time dimension.
        hidden = outputs.last_hidden_state.mean(dim=1)
        logits = self.classifier(hidden)  # (batch, 1)
        return logits, hidden

    def predict_proba(self, input_values: torch.Tensor) -> np.ndarray:
        """
        Get probability of the AI-generated class.

        Switches the module to eval mode (and leaves it there), moves the
        input to the device the model parameters live on, and runs
        without gradient tracking.

        Args:
            input_values: (batch, samples) raw audio waveform.

        Returns:
            1-D numpy array with one sigmoid probability per batch item.
        """
        self.eval()
        # Avoid a device-mismatch error when the model sits on GPU and the
        # caller passes a CPU tensor.
        device = next(self.parameters()).device
        with torch.no_grad():
            logits, _ = self(input_values.to(device))
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
        return probs
101
+
102
+
103
class AudioDataset(Dataset):
    """Simple dataset that loads audio files and labels."""

    def __init__(
        self,
        file_paths: list[str],
        labels: list[int],
        sample_rate: int = 16000,
        max_sec: float = 30.0,
    ):
        """
        Store the file list and derive the fixed clip length.

        Args:
            file_paths: Paths of the audio files.
            labels: One label per entry of ``file_paths``.
            sample_rate: Rate (Hz) that audio is loaded at.
            max_sec: Clip duration; each clip is cut or zero-padded to
                ``int(max_sec * sample_rate)`` samples.

        Raises:
            ValueError: If ``file_paths`` and ``labels`` differ in length.
        """
        # Fail early: a mismatch would otherwise silently ignore extra
        # labels or raise IndexError in the middle of an epoch.
        if len(file_paths) != len(labels):
            raise ValueError(
                f"file_paths ({len(file_paths)}) and labels "
                f"({len(labels)}) must have the same length"
            )
        self.file_paths = file_paths
        self.labels = labels
        self.sample_rate = sample_rate
        self.max_samples = int(max_sec * sample_rate)

    def __len__(self) -> int:
        """Return the number of audio files."""
        return len(self.file_paths)

    def _fit_length(self, y: np.ndarray) -> np.ndarray:
        """Truncate or right-pad ``y`` with zeros to ``max_samples``."""
        if len(y) > self.max_samples:
            return y[:self.max_samples]
        if len(y) < self.max_samples:
            return np.pad(y, (0, self.max_samples - len(y)))
        return y

    def __getitem__(self, idx):
        """
        Load one clip.

        Returns:
            A ``(waveform, label)`` pair where ``waveform`` is a float32
            tensor of exactly ``max_samples`` samples.
        """
        # Local import: librosa is only needed when audio is decoded.
        import librosa

        path = self.file_paths[idx]
        label = self.labels[idx]

        # Load as mono at the configured rate (16kHz for wav2vec2).
        y, _ = librosa.load(path, sr=self.sample_rate, mono=True)
        y = self._fit_length(y)

        return torch.tensor(y, dtype=torch.float32), label
137
+
138
+
139
def collate_fn(batch):
    """
    Collate audio tensors and labels.

    Args:
        batch: Sequence of ``(waveform, label)`` pairs; all waveforms
            must have the same shape so they can be stacked.

    Returns:
        ``(audios, labels)`` where ``audios`` is the stacked waveform
        tensor and ``labels`` is a float32 tensor (the dtype expected by
        ``BCEWithLogitsLoss`` targets).
    """
    audios, labels = zip(*batch)
    audios = torch.stack(audios)
    labels = torch.tensor(labels, dtype=torch.float32)
    return audios, labels
145
+
146
+
147
def _resolve_device(name: str) -> torch.device:
    """Map a config device string to a device ("auto" prefers CUDA)."""
    if name == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(name)


def _load_manifest(manifest_csv: str | Path) -> tuple[list[str], list[int]]:
    """Read the file_path and label_int columns from the manifest CSV."""
    import csv

    file_paths: list[str] = []
    labels: list[int] = []
    # newline="" is what the csv module expects from file objects it reads.
    with open(manifest_csv, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            file_paths.append(row["file_path"])
            labels.append(int(row["label_int"]))
    return file_paths, labels


def train_wav2vec2(
    manifest_csv: str | Path,
    output_dir: str | Path = "models",
    config: Wav2Vec2Config | None = None,
) -> dict:
    """
    Fine-tune wav2vec2 on the training dataset.

    The manifest is split 80/10/10 into train/val/test with a fixed seed.
    Only train and val are used here; the held-out test indices are
    returned so the caller can evaluate on them later.

    Args:
        manifest_csv: CSV with file_path, label_int columns.
        output_dir: Directory to save trained model.
        config: Training configuration.

    Returns:
        Dict with training metrics: ``best_val_auc``, ``history`` (one
        entry per epoch), ``model_path`` and ``test_indices`` (row
        indices of the manifest that were held out).

    Raises:
        ValueError: If the train or validation split is empty, or the
            validation split does not contain at least two classes
            (ROC AUC would be undefined).
    """
    # Hoisted out of the epoch loop; still local so the module imports
    # without scikit-learn installed.
    from sklearn.metrics import accuracy_score, roc_auc_score

    config = config or Wav2Vec2Config()
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model_path = output_dir / "wav2vec2_auris_v1.pt"

    device = _resolve_device(config.device)
    print(f"Device: {device}")

    file_paths, labels = _load_manifest(manifest_csv)

    # Split: 80% train, 10% val, 10% test (fixed seed => reproducible).
    n = len(labels)
    indices = np.random.RandomState(42).permutation(n)
    train_end = int(n * 0.8)
    val_end = int(n * 0.9)

    train_idx = indices[:train_end]
    val_idx = indices[train_end:val_end]
    test_idx = indices[val_end:]

    # Fail before the expensive model load / first epoch rather than with
    # a ZeroDivisionError or an sklearn error afterwards.
    if len(train_idx) == 0 or len(val_idx) == 0:
        raise ValueError(
            f"manifest has {n} rows; too few to form non-empty "
            "train and validation splits"
        )
    if len({labels[i] for i in val_idx}) < 2:
        raise ValueError(
            "validation split must contain both classes to compute ROC AUC"
        )

    train_ds = AudioDataset(
        [file_paths[i] for i in train_idx],
        [labels[i] for i in train_idx],
        config.sample_rate,
        config.max_audio_sec,
    )
    val_ds = AudioDataset(
        [file_paths[i] for i in val_idx],
        [labels[i] for i in val_idx],
        config.sample_rate,
        config.max_audio_sec,
    )

    train_loader = DataLoader(
        train_ds, batch_size=config.batch_size,
        shuffle=True, collate_fn=collate_fn,
        num_workers=0,
    )
    val_loader = DataLoader(
        val_ds, batch_size=config.batch_size,
        shuffle=False, collate_fn=collate_fn,
        num_workers=0,
    )

    # Build model
    model = Wav2Vec2MusicClassifier(config).to(device)

    # Different learning rates for encoder vs head
    optimizer = torch.optim.AdamW([
        {
            "params": model.wav2vec2.parameters(),
            "lr": config.learning_rate_encoder,
        },
        {
            "params": model.classifier.parameters(),
            "lr": config.learning_rate_head,
        },
    ], weight_decay=config.weight_decay)

    criterion = nn.BCEWithLogitsLoss()

    # Training loop
    best_val_auc = 0.0
    model_saved = False
    patience_counter = 0
    history = []

    for epoch in range(config.epochs):
        model.train()
        train_loss = 0.0

        for batch_audio, batch_labels in train_loader:
            batch_audio = batch_audio.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            logits, _ = model(batch_audio)
            loss = criterion(logits.squeeze(-1), batch_labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_probs = []
        val_labels = []

        with torch.no_grad():
            for batch_audio, batch_labels in val_loader:
                batch_audio = batch_audio.to(device)
                logits, _ = model(batch_audio)
                probs = torch.sigmoid(logits.squeeze(-1))
                val_probs.extend(probs.cpu().numpy())
                # Labels were never moved to the device, so no .cpu() here.
                val_labels.extend(batch_labels.numpy())

        val_probs = np.array(val_probs)
        val_labels = np.array(val_labels)
        val_preds = (val_probs > 0.5).astype(int)

        val_acc = accuracy_score(val_labels, val_preds)
        val_auc = roc_auc_score(val_labels, val_probs)

        print(
            f"Epoch {epoch + 1}/{config.epochs} | "
            f"Loss: {avg_train_loss:.4f} | "
            f"Val Acc: {val_acc:.4f} | "
            f"Val AUC: {val_auc:.4f}"
        )

        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_accuracy": val_acc,
            "val_auc": val_auc,
        })

        # Early stopping. The first completed epoch always writes a
        # checkpoint, so model_path exists even if AUC never exceeds 0.0.
        if not model_saved or val_auc > best_val_auc:
            best_val_auc = val_auc
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
            model_saved = True
            print(f"  → Saved best model (AUC={val_auc:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f"  → Early stopping at epoch {epoch + 1}")
                break

    print(f"\nBest validation AUC: {best_val_auc:.4f}")
    if model_saved:
        print(f"Model saved: {model_path}")

    return {
        "best_val_auc": best_val_auc,
        "history": history,
        "model_path": str(model_path),
        "test_indices": test_idx.tolist(),
    }
319
+
320
+
321
if __name__ == "__main__":
    # Usage: python wav2vec2_classifier.py [manifest_csv] [output_dir]
    # Both arguments are optional and fall back to the defaults below.
    manifest = sys.argv[1] if len(sys.argv) > 1 else "data/sonics/manifest.csv"
    out_dir = sys.argv[2] if len(sys.argv) > 2 else "models"
    train_wav2vec2(manifest, out_dir)