Spaces:
Sleeping
Sleeping
| """ | |
| Batch feature extraction for AURIS training pipeline. | |
| Runs feature_extractor and vocal_analyzer on every sample | |
| in a manifest CSV, collecting RAW features (not heuristic | |
| scores) into a single CSV for classifier training. | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import io | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
| import logging | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| # Put the directory two levels above this file's folder on sys.path so the `app.*` imports resolve | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) | |
| from app.services.feature_extractor import extract_features | |
| from app.services.vocal_analyzer import analyze_vocals | |
# Ordered names of every per-sample feature (the comprehensive set used for
# the paper). The output CSV lists these columns, in this order, right after
# "file_path" and "label_int".
FEATURE_COLUMNS = [
    # Basic metadata
    "duration_sec", "sample_rate",
    # Spectral features
    "rms_energy", "rms_std",
    "spectral_centroid_mean", "spectral_centroid_std",
    "spectral_flatness_mean", "spectral_flatness_std",
    "spectral_bandwidth_mean", "spectral_bandwidth_std",
    "spectral_rolloff_mean", "spectral_rolloff_std",
    "spectral_contrast_mean", "spectral_contrast_std",
    "mfcc_variance", "mfcc_delta_var", "mfcc_delta2_var",
    "mel_flatness",
    # Temporal / rhythm features
    "tempo_bpm", "tempo_stability", "tempo_cv",
    "zero_crossing_rate", "zero_crossing_std",
    "onset_strength_mean", "onset_strength_std",
    "rms_dynamic_range", "beat_count",
    # Harmonic / tonal features
    "chroma_entropy", "chroma_std", "chroma_transition_rate",
    "harmonic_ratio", "tonnetz_std",
    # Heuristic composite scores (kept as features)
    "spectral_regularity", "temporal_patterns", "harmonic_structure",
    # Vocal analysis features
    "has_vocals", "vocal_confidence", "vocal_ai_score",
    "pitch_stability_score", "vibrato_regularity_score",
    "formant_consistency_score", "breath_pattern_score",
    "vocal_texture_score",
    "pitch_mean_hz", "pitch_std_cents",
    "vibrato_rate_hz", "vibrato_extent_cents",
    "vocal_harmonic_ratio", "vocal_energy_ratio",
]
def extract_sample_features(audio_path: str) -> dict | None:
    """
    Build the raw feature row for one audio file.

    Runs ``extract_features`` and then ``analyze_vocals`` on the file. When
    only the vocal analysis raises, the vocal columns are filled with 0.0 and
    the row is still returned.

    Args:
        audio_path: Location of the audio file to analyse.

    Returns:
        Dict of feature_name -> value, or None when extraction fails (the
        failure is printed to stdout).
    """
    # Attributes copied one-to-one from the feature extractor's result.
    audio_fields = (
        "duration_sec",
        "sample_rate",
        "rms_energy",
        "rms_std",
        "spectral_centroid_mean",
        "spectral_centroid_std",
        "spectral_flatness_mean",
        "spectral_flatness_std",
        "spectral_bandwidth_mean",
        "spectral_bandwidth_std",
        "spectral_rolloff_mean",
        "spectral_rolloff_std",
        "spectral_contrast_mean",
        "spectral_contrast_std",
        "mfcc_variance",
        "mfcc_delta_var",
        "mfcc_delta2_var",
        "mel_flatness",
        "tempo_bpm",
        "tempo_stability",
        "tempo_cv",
        "zero_crossing_rate",
        "zero_crossing_std",
        "onset_strength_mean",
        "onset_strength_std",
        "rms_dynamic_range",
        "beat_count",
        "chroma_entropy",
        "chroma_std",
        "chroma_transition_rate",
        "harmonic_ratio",
        "tonnetz_std",
        "spectral_regularity",
        "temporal_patterns",
        "harmonic_structure",
    )
    # Attributes copied one-to-one from the vocal analysis result
    # ("has_vocals" is handled separately because it is coerced to 1.0/0.0).
    vocal_fields = (
        "vocal_confidence",
        "vocal_ai_score",
        "pitch_stability_score",
        "vibrato_regularity_score",
        "formant_consistency_score",
        "breath_pattern_score",
        "vocal_texture_score",
        "pitch_mean_hz",
        "pitch_std_cents",
        "vibrato_rate_hz",
        "vibrato_extent_cents",
        "vocal_harmonic_ratio",
        "vocal_energy_ratio",
    )

    try:
        path = Path(audio_path)
        feat = extract_features(path)
        row = {name: getattr(feat, name) for name in audio_fields}

        try:
            vocals = analyze_vocals(path)
            # Assemble the vocal columns completely before merging so a
            # failure part-way leaves ``row`` untouched for the fallback.
            vocal_row = {"has_vocals": 1.0 if vocals.has_vocals else 0.0}
            vocal_row.update((name, getattr(vocals, name)) for name in vocal_fields)
            row.update(vocal_row)
        except Exception as vocal_exc:  # noqa: BLE001
            logger.debug("Vocal extraction failed: %s", vocal_exc)
            # Vocal analysis is best-effort: fall back to all-zero columns.
            row.update(dict.fromkeys(("has_vocals", *vocal_fields), 0.0))

        return row
    except Exception as exc:
        print(f" FAILED: {audio_path}: {exc}")
        return None
def _extract_worker(args: tuple[str, int]) -> dict | None:
    """Pool worker: extract one sample and tag the row with its path and label.

    Kept at module level so it can be pickled for multiprocessing.

    Args:
        args: ``(audio_path, label_int)`` pair for a single sample.

    Returns:
        The feature row with ``file_path`` and ``label_int`` added, or None
        when feature extraction failed.
    """
    sample_path, sample_label = args
    row = extract_sample_features(sample_path)
    if row is not None:
        row["file_path"] = sample_path
        row["label_int"] = sample_label
    return row
def _drop_partial_tail(csv_path: Path) -> None:
    """Cut off a trailing fragment that does not end in a newline.

    An interrupted run can leave a half-written last row. Removing it keeps
    appended rows from fusing onto it and lets that sample be re-processed.
    """
    with open(csv_path, "r+b") as fh:
        data = fh.read()
        if data and not data.endswith(b"\n"):
            # rfind returns -1 when there is no newline at all, which
            # truncates the file to empty (even the header was incomplete).
            fh.truncate(data.rfind(b"\n") + 1)


def extract_batch(
    manifest_path: str | Path,
    output_path: str | Path | None = None,
) -> Path:
    """
    Extract features for all samples in a manifest.

    Samples are processed in a multiprocessing pool and rows are written to
    the output CSV as they complete (in completion order, not manifest
    order). If the output file already exists and is non-empty, the run
    resumes: paths already present in it are skipped and new rows are
    appended. A partial last row left by an interrupted run is removed first.

    Args:
        manifest_path: Path to manifest CSV with file_path, label_int.
        output_path: Path for output CSV. Default: same dir, features.csv.

    Returns:
        Path to the output features CSV.

    Raises:
        KeyError: If a manifest row lacks ``file_path`` or ``label_int``, or
            an existing output file has no ``file_path`` column.
        ValueError: If a ``label_int`` value cannot be parsed as an integer.
    """
    import multiprocessing as mp
    import os as _os
    import time as _time

    manifest_path = Path(manifest_path)
    if output_path is None:
        output_path = manifest_path.parent / "features.csv"
    output_path = Path(output_path)

    # Read manifest (newline="" as recommended for the csv module).
    with open(manifest_path, "r", encoding="utf-8", newline="") as f:
        samples = list(csv.DictReader(f))

    n_workers = max(1, (_os.cpu_count() or 4) - 1)
    print(f"Extracting features from {len(samples)} samples using {n_workers} workers...", flush=True)

    out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
    success = 0
    failed = 0
    # Monotonic clock: elapsed/ETA must not jump if the wall clock changes.
    t_start = _time.monotonic()

    # Repair the tail BEFORE deciding whether to resume, so a truncated row
    # is neither counted as done nor glued to the next appended row.
    if output_path.is_file():
        _drop_partial_tail(output_path)

    done_paths: set[str] = set()
    resume = output_path.exists() and output_path.stat().st_size > 0
    if resume:
        with open(output_path, "r", encoding="utf-8", newline="") as f_prev:
            done_paths = {r["file_path"] for r in csv.DictReader(f_prev)}
        print(f" Resuming: {len(done_paths)} samples already processed, skipping", flush=True)

    tasks = [
        (s["file_path"], int(s["label_int"]))
        for s in samples
        if s["file_path"] not in done_paths
    ]
    total_remaining = len(tasks)
    print(f" Remaining: {total_remaining} samples to process", flush=True)

    file_mode = "a" if resume else "w"
    with open(output_path, file_mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_columns)
        if not resume:
            writer.writeheader()
        f.flush()

        # Skip the pool entirely when nothing is left to do.
        if tasks:
            with mp.Pool(processes=n_workers) as pool:
                results = pool.imap_unordered(_extract_worker, tasks, chunksize=4)
                for i, result in enumerate(results, 1):
                    if result is None:
                        failed += 1
                    else:
                        writer.writerow(result)
                        success += 1

                    # Checkpoint on every 25th result, including failures,
                    # so the flush and progress line are never skipped.
                    if i % 25 == 0:
                        f.flush()
                        elapsed = _time.monotonic() - t_start
                        rate = i / elapsed if elapsed > 0 else 0
                        eta = (total_remaining - i) / rate if rate > 0 else 0
                        print(
                            f" [{i}/{total_remaining}] "
                            f"ok={success} fail={failed} "
                            f"rate={rate:.1f}/s eta={eta / 60:.1f}m",
                            flush=True,
                        )

    elapsed = _time.monotonic() - t_start
    print(
        f"\nDone: {success} extracted, "
        f"{failed} failed in {elapsed / 60:.1f}m",
        flush=True,
    )
    print(f"Output: {output_path}", flush=True)
    return output_path
if __name__ == "__main__":
    # Usage: <script> [manifest_csv] [output_csv]
    # Without arguments the default manifest is used and the output path is
    # left for extract_batch to choose.
    cli_args = sys.argv[1:]
    manifest = cli_args[0] if cli_args else "data/sonics/manifest.csv"
    out = cli_args[1] if len(cli_args) > 1 else None
    extract_batch(manifest, out)