# Source: crowncode-backend/app/training/extract_features_batch.py
# Author: Rthur2003
# Commit 8f111ee — feat: implement resume functionality in batch feature
# extraction to skip processed samples
"""
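Batch feature extraction for the AURIS training pipeline.
Runs feature_extractor and vocal_analyzer on every sample
in a manifest CSV and collects the per-sample features (raw
measurements plus the heuristic composite scores, which are
kept as features) into a single CSV for classifier training.
If the output CSV already exists, samples listed in it are skipped.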
"""
from __future__ import annotations
import csv
import io
import sys
import traceback
from pathlib import Path
import logging
import numpy as np
logger = logging.getLogger(__name__)
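# Put the directory two levels above this file's folder on sys.path so the
# `app.*` imports below resolve when this file is run directly as a script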
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from app.services.feature_extractor import extract_features
from app.services.vocal_analyzer import analyze_vocals
# Ordered list of every per-sample feature column written to the output CSV.
# The order here is the column order of the file (after file_path/label_int).
FEATURE_COLUMNS = [
    # Basic metadata
    "duration_sec",
    "sample_rate",

    # Spectral features
    "rms_energy",
    "rms_std",
    "spectral_centroid_mean",
    "spectral_centroid_std",
    "spectral_flatness_mean",
    "spectral_flatness_std",
    "spectral_bandwidth_mean",
    "spectral_bandwidth_std",
    "spectral_rolloff_mean",
    "spectral_rolloff_std",
    "spectral_contrast_mean",
    "spectral_contrast_std",
    "mfcc_variance",
    "mfcc_delta_var",
    "mfcc_delta2_var",
    "mel_flatness",

    # Temporal / rhythm features
    "tempo_bpm",
    "tempo_stability",
    "tempo_cv",
    "zero_crossing_rate",
    "zero_crossing_std",
    "onset_strength_mean",
    "onset_strength_std",
    "rms_dynamic_range",
    "beat_count",

    # Harmonic / tonal features
    "chroma_entropy",
    "chroma_std",
    "chroma_transition_rate",
    "harmonic_ratio",
    "tonnetz_std",

    # Heuristic composite scores (kept as features)
    "spectral_regularity",
    "temporal_patterns",
    "harmonic_structure",

    # Vocal analysis features
    "has_vocals",
    "vocal_confidence",
    "vocal_ai_score",
    "pitch_stability_score",
    "vibrato_regularity_score",
    "formant_consistency_score",
    "breath_pattern_score",
    "vocal_texture_score",
    "pitch_mean_hz",
    "pitch_std_cents",
    "vibrato_rate_hz",
    "vibrato_extent_cents",
    "vocal_harmonic_ratio",
    "vocal_energy_ratio",
]
def extract_sample_features(audio_path: str) -> dict | None:
    """
    Build the feature row for one audio file.

    The row holds the listed attributes read from the result of
    ``extract_features`` followed by those read from the result of
    ``analyze_vocals``. If only the vocal analysis raises, the problem is
    logged at debug level and every vocal column is set to 0.0, so the row
    is still returned.

    Args:
        audio_path: Location of the audio file to analyse.

    Returns:
        Dict mapping feature name to value, or None when extraction raised
        (the failure is printed to stdout).
    """
    # Attributes copied verbatim from the extract_features() result.
    audio_attrs = (
        "duration_sec",
        "sample_rate",
        "rms_energy",
        "rms_std",
        "spectral_centroid_mean",
        "spectral_centroid_std",
        "spectral_flatness_mean",
        "spectral_flatness_std",
        "spectral_bandwidth_mean",
        "spectral_bandwidth_std",
        "spectral_rolloff_mean",
        "spectral_rolloff_std",
        "spectral_contrast_mean",
        "spectral_contrast_std",
        "mfcc_variance",
        "mfcc_delta_var",
        "mfcc_delta2_var",
        "mel_flatness",
        "tempo_bpm",
        "tempo_stability",
        "tempo_cv",
        "zero_crossing_rate",
        "zero_crossing_std",
        "onset_strength_mean",
        "onset_strength_std",
        "rms_dynamic_range",
        "beat_count",
        "chroma_entropy",
        "chroma_std",
        "chroma_transition_rate",
        "harmonic_ratio",
        "tonnetz_std",
        "spectral_regularity",
        "temporal_patterns",
        "harmonic_structure",
    )
    # Attributes copied verbatim from the analyze_vocals() result;
    # has_vocals is handled separately because it is mapped to 1.0/0.0.
    vocal_attrs = (
        "vocal_confidence",
        "vocal_ai_score",
        "pitch_stability_score",
        "vibrato_regularity_score",
        "formant_consistency_score",
        "breath_pattern_score",
        "vocal_texture_score",
        "pitch_mean_hz",
        "pitch_std_cents",
        "vibrato_rate_hz",
        "vibrato_extent_cents",
        "vocal_harmonic_ratio",
        "vocal_energy_ratio",
    )

    try:
        source = Path(audio_path)
        audio = extract_features(source)
        row = {attr: getattr(audio, attr) for attr in audio_attrs}

        try:
            vocal = analyze_vocals(source)
            vocal_row = {"has_vocals": 1.0 if vocal.has_vocals else 0.0}
            vocal_row.update((attr, getattr(vocal, attr)) for attr in vocal_attrs)
        except Exception as e:  # noqa: BLE001
            logger.debug("Vocal extraction failed: %s", e)
            # Vocal analysis is best-effort: fall back to all-zero columns.
            vocal_row = dict.fromkeys(("has_vocals", *vocal_attrs), 0.0)

        row.update(vocal_row)
        return row
    except Exception as e:
        print(f" FAILED: {audio_path}: {e}")
        return None
def _extract_worker(args: tuple[str, int]) -> dict | None:
    """
    Pool worker: extract one sample and tag the row with its manifest data.

    Defined at module level so multiprocessing can pickle it.

    Args:
        args: ``(audio_path, label_int)`` pair for one sample.

    Returns:
        The feature dict from ``extract_sample_features`` with ``file_path``
        and ``label_int`` added, or None if extraction failed.
    """
    audio_path, label_int = args
    row = extract_sample_features(audio_path)
    if row is not None:
        row["file_path"] = audio_path
        row["label_int"] = label_int
    return row
def _discard_partial_last_row(csv_path: Path) -> bool:
    """
    Cut an unterminated final line off *csv_path*.

    A run that is killed while the write buffer is being flushed can leave a
    half-written last row. Appending to such a file would glue the next row
    onto the fragment, so the fragment is removed before resuming. The check
    is line-based and does not account for newlines inside quoted fields.

    Returns:
        True if bytes were removed, False if the file was empty or already
        ended with a newline.
    """
    with open(csv_path, "rb+") as fh:
        size = fh.seek(0, io.SEEK_END)
        if size == 0:
            return False
        fh.seek(size - 1)
        if fh.read(1) == b"\n":
            return False
        fh.seek(0)
        # Keep everything up to and including the last newline (0 if none).
        keep = fh.read().rfind(b"\n") + 1
        fh.truncate(keep)
        return True


def _read_done_paths(csv_path: Path) -> set[str]:
    """Return the ``file_path`` value of every row already in *csv_path*."""
    with open(csv_path, "r", newline="", encoding="utf-8") as fh:
        return {r["file_path"] for r in csv.DictReader(fh)}


def extract_batch(
    manifest_path: str | Path,
    output_path: str | Path | None = None,
) -> Path:
    """
    Extract features for all samples in a manifest.

    Samples are processed in a multiprocessing pool and rows are written to
    the output CSV as they complete (in completion order, not manifest
    order). If the output file already exists and is non-empty, the run
    resumes: samples whose ``file_path`` is already present are skipped and
    new rows are appended. Failed samples are not written, so they are
    retried on the next resume.

    Args:
        manifest_path: Path to manifest CSV with file_path, label_int.
        output_path: Path for output CSV. Default: same dir, features.csv.

    Returns:
        Path to the output features CSV.
    """
    import multiprocessing as mp
    import os as _os
    import time as _time

    manifest_path = Path(manifest_path)
    if output_path is None:
        output_path = manifest_path.parent / "features.csv"
    output_path = Path(output_path)

    # Read manifest (newline="" is what the csv module expects for file objects)
    with open(manifest_path, "r", newline="", encoding="utf-8") as f:
        samples = list(csv.DictReader(f))

    # Leave one core free for the parent process that writes the CSV.
    n_workers = max(1, (_os.cpu_count() or 4) - 1)
    print(f"Extracting features from {len(samples)} samples using {n_workers} workers...", flush=True)

    out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
    success = 0
    failed = 0
    t_start = _time.time()

    # Repair a file left behind by an interrupted run BEFORE deciding whether
    # to resume: a truncated row must neither count as done nor be appended to.
    if output_path.is_file() and _discard_partial_last_row(output_path):
        print(" Dropped incomplete trailing row from previous run", flush=True)

    done_paths: set[str] = set()
    resume = output_path.exists() and output_path.stat().st_size > 0
    if resume:
        done_paths = _read_done_paths(output_path)
        print(f" Resuming: {len(done_paths)} samples already processed, skipping", flush=True)

    tasks = [
        (s["file_path"], int(s["label_int"]))
        for s in samples
        if s["file_path"] not in done_paths
    ]
    total_remaining = len(tasks)
    print(f" Remaining: {total_remaining} samples to process", flush=True)

    file_mode = "a" if resume else "w"
    with open(output_path, file_mode, newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=out_columns)
        if not resume:
            writer.writeheader()
            f.flush()

        if tasks:  # nothing to do -> do not spin up worker processes
            with mp.Pool(processes=n_workers) as pool:
                for i, result in enumerate(
                    pool.imap_unordered(_extract_worker, tasks, chunksize=4), 1
                ):
                    if result is None:
                        failed += 1
                    else:
                        writer.writerow(result)
                        success += 1

                    # Checkpoint on every 25th sample, whether or not that
                    # particular sample succeeded.
                    if i % 25 == 0:
                        f.flush()
                        elapsed = _time.time() - t_start
                        rate = i / elapsed if elapsed > 0 else 0
                        eta = (total_remaining - i) / rate if rate > 0 else 0
                        print(
                            f" [{i}/{total_remaining}] "
                            f"ok={success} fail={failed} "
                            f"rate={rate:.1f}/s eta={eta / 60:.1f}m",
                            flush=True,
                        )

    elapsed = _time.time() - t_start
    print(
        f"\nDone: {success} extracted, "
        f"{failed} failed in {elapsed / 60:.1f}m",
        flush=True,
    )
    print(f"Output: {output_path}", flush=True)
    return output_path
# Script entry point: extract_features_batch.py [manifest_csv] [output_csv]
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    manifest = cli_args[0] if cli_args else "data/sonics/manifest.csv"
    out = cli_args[1] if len(cli_args) > 1 else None
    extract_batch(manifest, out)