Rthur2003 commited on
Commit
3ad0b90
·
1 Parent(s): 618303c

feat: implement wav2vec2 model loading and inference functions for audio processing

Browse files
Files changed (1) hide show
  1. local_demo.py +41 -0
local_demo.py CHANGED
@@ -166,6 +166,47 @@ def _is_model_compatible(model: Any, n_features: int) -> bool:
166
  return expected in (None, n_features)
167
 
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  def _load_artifacts() -> DemoArtifacts:
170
  scaler_path = MODELS_DIR / "feature_scaler_v1.pkl"
171
  columns_path = MODELS_DIR / "feature_columns_v1.json"
 
166
  return expected in (None, n_features)
167
 
168
 
169
+ def _load_wav2vec2() -> Any:
170
+ """Load trained wav2vec2 model from .pt checkpoint. Returns None if unavailable."""
171
+ import torch
172
+ pt_path = MODELS_DIR / "wav2vec2_auris_v1.pt"
173
+ if not pt_path.exists():
174
+ return None
175
+ try:
176
+ config = Wav2Vec2Config()
177
+ model = Wav2Vec2MusicClassifier(config)
178
+ state = torch.load(str(pt_path), map_location="cpu", weights_only=True)
179
+ model.load_state_dict(state)
180
+ model.eval()
181
+ print(f"wav2vec2 loaded: {pt_path.name}")
182
+ return model
183
+ except Exception as exc: # noqa: BLE001
184
+ print(f"wav2vec2 skipped ({exc})")
185
+ return None
186
+
187
+
188
+ def _wav2vec2_predict(model: Any, audio_path: str) -> float | None:
189
+ """Run wav2vec2 inference on a raw audio file. Returns AI probability or None."""
190
+ import torch
191
+ try:
192
+ import librosa
193
+ config: Wav2Vec2Config = model.config
194
+ y, _ = librosa.load(audio_path, sr=config.sample_rate, mono=True)
195
+ max_samples = int(config.max_audio_sec * config.sample_rate)
196
+ if len(y) > max_samples:
197
+ y = y[:max_samples]
198
+ elif len(y) < max_samples:
199
+ import numpy as _np
200
+ y = _np.pad(y, (0, max_samples - len(y)))
201
+ tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(0) # (1, samples)
202
+ with torch.no_grad():
203
+ probs = model.predict_proba(tensor)
204
+ return float(probs[0])
205
+ except Exception as exc: # noqa: BLE001
206
+ print(f"wav2vec2 inference failed: {exc}")
207
+ return None
208
+
209
+
210
  def _load_artifacts() -> DemoArtifacts:
211
  scaler_path = MODELS_DIR / "feature_scaler_v1.pkl"
212
  columns_path = MODELS_DIR / "feature_columns_v1.json"