Spaces:
Sleeping
Sleeping
feat: implement wav2vec2 model loading and inference functions for audio processing
Browse files- local_demo.py +41 -0
local_demo.py
CHANGED
|
@@ -166,6 +166,47 @@ def _is_model_compatible(model: Any, n_features: int) -> bool:
|
|
| 166 |
return expected in (None, n_features)
|
| 167 |
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
def _load_artifacts() -> DemoArtifacts:
|
| 170 |
scaler_path = MODELS_DIR / "feature_scaler_v1.pkl"
|
| 171 |
columns_path = MODELS_DIR / "feature_columns_v1.json"
|
|
|
|
| 166 |
return expected in (None, n_features)
|
| 167 |
|
| 168 |
|
| 169 |
+
def _load_wav2vec2() -> Any:
    """Load the trained wav2vec2 classifier from its ``.pt`` checkpoint.

    The checkpoint is looked up at ``MODELS_DIR / "wav2vec2_auris_v1.pt"``.
    Loading is best-effort: a missing checkpoint, a missing ``torch``
    installation, or any error while building the model or restoring its
    weights yields ``None`` instead of an exception.

    Returns:
        A ``Wav2Vec2MusicClassifier`` switched to eval mode with the
        checkpoint's state dict loaded onto the CPU, or ``None`` if the model
        is unavailable.
    """
    pt_path = MODELS_DIR / "wav2vec2_auris_v1.pt"
    # Cheap existence check first: no need to import torch when there is
    # nothing to load.
    if not pt_path.exists():
        return None
    try:
        # Imported inside the try so that a missing torch install degrades to
        # "model unavailable" rather than raising out of this loader.
        import torch

        config = Wav2Vec2Config()
        model = Wav2Vec2MusicClassifier(config)
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from.
        state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
        model.load_state_dict(state)
        model.eval()
        print(f"wav2vec2 loaded: {pt_path.name}")
        return model
    except Exception as exc:  # noqa: BLE001
        # Deliberate catch-all: the wav2vec2 model is optional, so any failure
        # is reported and the caller simply gets None.
        print(f"wav2vec2 skipped ({exc})")
        return None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _wav2vec2_predict(model: Any, audio_path: str) -> float | None:
|
| 189 |
+
"""Run wav2vec2 inference on a raw audio file. Returns AI probability or None."""
|
| 190 |
+
import torch
|
| 191 |
+
try:
|
| 192 |
+
import librosa
|
| 193 |
+
config: Wav2Vec2Config = model.config
|
| 194 |
+
y, _ = librosa.load(audio_path, sr=config.sample_rate, mono=True)
|
| 195 |
+
max_samples = int(config.max_audio_sec * config.sample_rate)
|
| 196 |
+
if len(y) > max_samples:
|
| 197 |
+
y = y[:max_samples]
|
| 198 |
+
elif len(y) < max_samples:
|
| 199 |
+
import numpy as _np
|
| 200 |
+
y = _np.pad(y, (0, max_samples - len(y)))
|
| 201 |
+
tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(0) # (1, samples)
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
probs = model.predict_proba(tensor)
|
| 204 |
+
return float(probs[0])
|
| 205 |
+
except Exception as exc: # noqa: BLE001
|
| 206 |
+
print(f"wav2vec2 inference failed: {exc}")
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
|
| 210 |
def _load_artifacts() -> DemoArtifacts:
|
| 211 |
scaler_path = MODELS_DIR / "feature_scaler_v1.pkl"
|
| 212 |
columns_path = MODELS_DIR / "feature_columns_v1.json"
|