Spaces:
Sleeping
Sleeping
| """ | |
| wav2vec2 fine-tuned classifier for AI music detection. | |
| Architecture: | |
| wav2vec2-base (frozen CNN encoder, trainable transformer) | |
| → Global Average Pooling (768-dim) | |
| → Linear(768, 256) + ReLU + Dropout(0.3) | |
| → Linear(256, 1) | |
| → Binary: AI (1) vs Human (0) | |
| This module defines both the model and the training loop. | |
| Requires GPU for training (~2-4 hours on T4 for 10K samples). | |
| CPU inference: ~0.5s for 30s audio. | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
@dataclass
class Wav2Vec2Config:
    """Model and training hyper-parameters for the wav2vec2 classifier.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``Wav2Vec2Config(batch_size=4, epochs=5)``;
    without the decorator the class only supports the all-defaults form.
    """

    # Hugging Face checkpoint loaded as the backbone.
    model_name: str = "facebook/wav2vec2-base"
    # Clips are truncated / zero-padded to this many seconds.
    max_audio_sec: float = 30.0
    # Audio is resampled to this rate when loaded.
    sample_rate: int = 16000
    # Width of the hidden layer in the classification head.
    hidden_dim: int = 256
    # Dropout probability inside the classification head.
    dropout: float = 0.3
    # AdamW learning rate for the classification head.
    learning_rate_head: float = 1e-3
    # AdamW learning rate for the wav2vec2 backbone.
    learning_rate_encoder: float = 1e-5
    # AdamW weight decay shared by both parameter groups.
    weight_decay: float = 0.01
    # Per-step batch size (the training loop additionally accumulates gradients).
    batch_size: int = 2
    # Maximum number of training epochs.
    epochs: int = 10
    # Epochs without validation-AUC improvement before early stopping.
    patience: int = 3
    # "auto" picks CUDA when available, else CPU; otherwise any torch device string.
    device: str = "auto"
class Wav2Vec2MusicClassifier(nn.Module):
    """
    wav2vec2 with a small MLP head for AI music detection.

    The CNN feature encoder is frozen (robust low-level audio
    representation). The transformer layers stay trainable to learn
    task-specific temporal patterns. Hidden states are mean-pooled over
    time and mapped to a single logit per clip.
    """

    def __init__(self, config: Wav2Vec2Config | None = None) -> None:
        """Load the pretrained backbone, freeze its CNN encoder and build the head.

        Args:
            config: Model configuration; ``Wav2Vec2Config()`` when omitted.
        """
        super().__init__()
        self.config = config or Wav2Vec2Config()
        # Local import: importing this module does not require transformers.
        from transformers import Wav2Vec2Model

        self.wav2vec2 = Wav2Vec2Model.from_pretrained(self.config.model_name)
        # Freeze CNN encoder, fine-tune transformer (public API for the same
        # operation as feature_extractor._freeze_parameters()).
        self.wav2vec2.freeze_feature_encoder()

        # Take the embedding width from the checkpoint instead of hard-coding
        # 768: identical for wav2vec2-base, and larger checkpoints also work.
        embed_dim = self.wav2vec2.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, self.config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.config.dropout),
            nn.Linear(self.config.hidden_dim, 1),
        )

    def forward(
        self, input_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            input_values: (batch, samples) raw audio waveform at 16kHz.

        Returns:
            logits: (batch, 1) classification logit.
            hidden: (batch, hidden_size) time-averaged hidden states
                (768 for wav2vec2-base), e.g. for a meta-classifier.
        """
        outputs = self.wav2vec2(input_values)
        # Mean pool over the time dimension.
        hidden = outputs.last_hidden_state.mean(dim=1)
        logits = self.classifier(hidden)
        return logits, hidden

    def predict_proba(self, input_values: torch.Tensor) -> np.ndarray:
        """Return the probability of the AI-generated class for each clip.

        Runs without gradients in eval mode. Afterwards the module is put back
        into whatever train/eval mode it was in, so calling this mid-training
        does not silently disable dropout.

        Args:
            input_values: (batch, samples) raw waveform; moved to the model's
                device before the forward pass.

        Returns:
            1-D array of length ``batch`` with sigmoid probabilities.
        """
        was_training = self.training
        self.eval()
        try:
            device = next(self.parameters()).device
            with torch.no_grad():
                logits, _ = self(input_values.to(device))
        finally:
            self.train(was_training)
        return torch.sigmoid(logits).cpu().numpy().flatten()
class AudioDataset(Dataset):
    """Map-style dataset of (waveform, label) pairs read from audio files.

    Every clip is decoded to mono at ``sample_rate`` and forced to exactly
    ``int(max_sec * sample_rate)`` samples: longer clips keep only their
    beginning, shorter ones are zero-padded at the end.
    """

    def __init__(
        self,
        file_paths: list[str],
        labels: list[int],
        sample_rate: int = 16000,
        max_sec: float = 30.0,
    ) -> None:
        """Remember the file list, the labels and the target clip length."""
        self.file_paths = file_paths
        self.labels = labels
        self.sample_rate = sample_rate
        # One fixed length for all items so they can be stacked into a batch.
        self.max_samples = int(max_sec * sample_rate)

    def __len__(self) -> int:
        """Number of clips, i.e. the number of file paths."""
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        """Decode clip ``idx`` and return it as a float32 tensor with its label."""
        import librosa

        source, target = self.file_paths[idx], self.labels[idx]
        wave, _ = librosa.load(source, sr=self.sample_rate, mono=True)
        excess = len(wave) - self.max_samples
        if excess > 0:
            wave = wave[: self.max_samples]
        elif excess < 0:
            wave = np.pad(wave, (0, -excess))
        return torch.tensor(wave, dtype=torch.float32), target
def collate_fn(
    batch: list[tuple[torch.Tensor, int]],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Merge (waveform, label) pairs into a batch.

    Returns the waveforms stacked along a new leading dimension and the
    labels as a float32 tensor (the dtype BCE-with-logits targets use).
    """
    clips, raw_targets = zip(*batch)
    stacked_clips = torch.stack(clips)
    target_tensor = torch.tensor(raw_targets, dtype=torch.float32)
    return stacked_clips, target_tensor
def _resolve_device(spec: str) -> torch.device:
    """Turn a device spec into a ``torch.device``; ``"auto"`` prefers CUDA."""
    if spec == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(spec)


def _load_manifest(manifest_csv: str | Path) -> tuple[list[str], list[int]]:
    """Read the ``file_path`` and integer ``label_int`` columns of a manifest CSV."""
    import csv

    file_paths: list[str] = []
    labels: list[int] = []
    # newline="" is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(manifest_csv, "r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            file_paths.append(row["file_path"])
            labels.append(int(row["label_int"]))
    return file_paths, labels


def _split_indices(
    n: int, seed: int = 42
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Shuffle ``range(n)`` with a fixed seed and cut it 80/10/10 into train/val/test."""
    indices = np.random.RandomState(seed).permutation(n)
    train_end = int(n * 0.8)
    val_end = int(n * 0.9)
    return indices[:train_end], indices[train_end:val_end], indices[val_end:]


def _make_dataset(
    file_paths: list[str],
    labels: list[int],
    idx: np.ndarray,
    config: Wav2Vec2Config,
) -> AudioDataset:
    """Build an ``AudioDataset`` over the manifest rows selected by ``idx``."""
    return AudioDataset(
        [file_paths[i] for i in idx],
        [labels[i] for i in idx],
        config.sample_rate,
        config.max_audio_sec,
    )


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    scaler: torch.amp.GradScaler | None,
    accum_steps: int,
) -> float:
    """Run one epoch with gradient accumulation (AMP when ``scaler`` is set).

    Returns the mean (un-scaled) per-batch training loss.
    """
    from contextlib import nullcontext

    model.train()
    total_loss = 0.0
    n_batches = len(loader)
    optimizer.zero_grad()
    for step, (batch_audio, batch_labels) in enumerate(loader):
        batch_audio = batch_audio.to(device)
        batch_labels = batch_labels.to(device)
        with torch.amp.autocast("cuda") if scaler is not None else nullcontext():
            logits, _ = model(batch_audio)
            # Divide so the accumulated gradient matches one large batch.
            loss = criterion(logits.squeeze(-1), batch_labels) / accum_steps
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        # Step every accum_steps micro-batches and flush the final partial window.
        if (step + 1) % accum_steps == 0 or (step + 1) == n_batches:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
        total_loss += loss.item() * accum_steps
    return total_loss / max(n_batches, 1)


def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    use_amp: bool,
) -> tuple[float, float]:
    """Score ``model`` on ``loader``; return (accuracy at threshold 0.5, ROC-AUC)."""
    from contextlib import nullcontext

    from sklearn.metrics import accuracy_score, roc_auc_score

    model.eval()
    prob_chunks: list[np.ndarray] = []
    label_chunks: list[np.ndarray] = []
    with torch.no_grad():
        for batch_audio, batch_labels in loader:
            with torch.amp.autocast("cuda") if use_amp else nullcontext():
                logits, _ = model(batch_audio.to(device))
            # Upcast before the sigmoid: under autocast the head may emit fp16,
            # whose coarse resolution near 0/1 creates spurious ties in the AUC.
            probs = torch.sigmoid(logits.squeeze(-1).float())
            prob_chunks.append(probs.cpu().numpy())
            label_chunks.append(batch_labels.numpy())
    val_probs = np.concatenate(prob_chunks)
    val_labels = np.concatenate(label_chunks)
    val_preds = (val_probs > 0.5).astype(int)
    return (
        float(accuracy_score(val_labels, val_preds)),
        float(roc_auc_score(val_labels, val_probs)),
    )


def train_wav2vec2(
    manifest_csv: str | Path,
    output_dir: str | Path = "models",
    config: Wav2Vec2Config | None = None,
) -> dict:
    """
    Fine-tune wav2vec2 on the training dataset.

    The manifest is split 80/10/10 (fixed seed 42) into train/val/test; the
    test tenth is held out and not used here. The checkpoint with the best
    validation ROC-AUC is written to ``<output_dir>/wav2vec2_auris_v1.pt``.
    Training stops early after ``config.patience`` epochs without improvement.

    Args:
        manifest_csv: CSV with file_path, label_int columns.
        output_dir: Directory to save trained model.
        config: Training configuration.

    Returns:
        Dict with ``best_val_auc``, per-epoch ``history``, ``model_path`` and
        ``best_epoch`` (0 if no epoch was run).

    Raises:
        ValueError: If the manifest yields no training rows, or the validation
            split does not contain both classes (ROC-AUC would be undefined).
            Both are checked before the model is loaded or any epoch runs.
    """
    config = config or Wav2Vec2Config()
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model_path = output_dir / "wav2vec2_auris_v1.pt"

    device = _resolve_device(config.device)
    print(f"Device: {device}")

    file_paths, labels = _load_manifest(manifest_csv)
    n = len(labels)
    train_idx, val_idx, _held_out_test_idx = _split_indices(n)

    # Fail fast: previously these cases crashed only after model download
    # and/or a full (multi-hour) training epoch.
    if len(train_idx) == 0:
        raise ValueError(
            f"Manifest {manifest_csv} produced an empty training split "
            f"({n} rows)."
        )
    val_classes = sorted({labels[i] for i in val_idx})
    if len(val_classes) < 2:
        raise ValueError(
            "Validation split must contain both classes to compute ROC-AUC; "
            f"found labels {val_classes} in {len(val_idx)} validation rows."
        )

    train_loader = DataLoader(
        _make_dataset(file_paths, labels, train_idx, config),
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
    )
    val_loader = DataLoader(
        _make_dataset(file_paths, labels, val_idx, config),
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
    )

    model = Wav2Vec2MusicClassifier(config).to(device)
    # Different learning rates for encoder vs head; the frozen CNN encoder
    # never receives gradients, so only trainable backbone params are passed.
    optimizer = torch.optim.AdamW(
        [
            {
                "params": [
                    p for p in model.wav2vec2.parameters() if p.requires_grad
                ],
                "lr": config.learning_rate_encoder,
            },
            {
                "params": model.classifier.parameters(),
                "lr": config.learning_rate_head,
            },
        ],
        weight_decay=config.weight_decay,
    )
    criterion = nn.BCEWithLogitsLoss()

    # Mixed precision only on CUDA; effective batch = batch_size * accum_steps.
    scaler_amp = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
    accum_steps = 4

    best_val_auc = 0.0
    best_epoch = 0
    patience_counter = 0
    history = []
    for epoch in range(config.epochs):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            scaler_amp, accum_steps,
        )
        val_acc, val_auc = _evaluate(
            model, val_loader, device, use_amp=scaler_amp is not None
        )
        print(
            f"Epoch {epoch + 1}/{config.epochs} | "
            f"Loss: {avg_train_loss:.4f} | "
            f"Val Acc: {val_acc:.4f} | "
            f"Val AUC: {val_auc:.4f}",
            flush=True,
        )
        history.append({
            "epoch": epoch + 1,
            "train_loss": avg_train_loss,
            "val_accuracy": val_acc,
            "val_auc": val_auc,
        })

        # Early stopping. The first epoch always saves, so the returned
        # model_path exists whenever at least one epoch ran.
        if best_epoch == 0 or val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch + 1
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
            print(f"  → Saved best model (AUC={val_auc:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f"  → Early stopping at epoch {epoch + 1}")
                break

    print(f"\nBest validation AUC: {best_val_auc:.4f}")
    if best_epoch:
        print(f"Model saved: {model_path}")
    return {
        "best_val_auc": best_val_auc,
        "history": history,
        "model_path": str(model_path),
        "best_epoch": best_epoch,
    }
if __name__ == "__main__":
    # CLI usage: python <this file> [manifest_csv] [output_dir]
    _cli_args = sys.argv[1:]
    manifest = _cli_args[0] if _cli_args else "data/sonics/manifest.csv"
    out_dir = _cli_args[1] if len(_cli_args) > 1 else "models"
    train_wav2vec2(manifest, out_dir)