Spaces:
Sleeping
Sleeping
feat: refactor worker function for batch feature extraction to support multiprocessing
Browse files
app/training/extract_features_batch.py
CHANGED
|
@@ -182,6 +182,17 @@ def extract_sample_features(audio_path: str) -> dict | None:
|
|
| 182 |
return None
|
| 183 |
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
def extract_batch(
|
| 186 |
manifest_path: str | Path,
|
| 187 |
output_path: str | Path | None = None,
|
|
@@ -223,15 +234,6 @@ def extract_batch(
|
|
| 223 |
|
| 224 |
tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
|
| 225 |
|
| 226 |
-
def _worker(args):
|
| 227 |
-
audio_path, label_int = args
|
| 228 |
-
features = extract_sample_features(audio_path)
|
| 229 |
-
if features is None:
|
| 230 |
-
return None
|
| 231 |
-
features["file_path"] = audio_path
|
| 232 |
-
features["label_int"] = label_int
|
| 233 |
-
return features
|
| 234 |
-
|
| 235 |
with open(output_path, "w", newline="", encoding="utf-8") as f:
|
| 236 |
writer = csv.DictWriter(f, fieldnames=out_columns)
|
| 237 |
writer.writeheader()
|
|
@@ -239,7 +241,7 @@ def extract_batch(
|
|
| 239 |
|
| 240 |
with mp.Pool(processes=n_workers) as pool:
|
| 241 |
for i, result in enumerate(
|
| 242 |
-
pool.imap_unordered(
|
| 243 |
):
|
| 244 |
if result is None:
|
| 245 |
failed += 1
|
|
|
|
| 182 |
return None
|
| 183 |
|
| 184 |
|
| 185 |
+
def _extract_worker(args: tuple[str, int]) -> dict | None:
|
| 186 |
+
"""Module-level worker for multiprocessing (must be picklable)."""
|
| 187 |
+
audio_path, label_int = args
|
| 188 |
+
features = extract_sample_features(audio_path)
|
| 189 |
+
if features is None:
|
| 190 |
+
return None
|
| 191 |
+
features["file_path"] = audio_path
|
| 192 |
+
features["label_int"] = label_int
|
| 193 |
+
return features
|
| 194 |
+
|
| 195 |
+
|
| 196 |
def extract_batch(
|
| 197 |
manifest_path: str | Path,
|
| 198 |
output_path: str | Path | None = None,
|
|
|
|
| 234 |
|
| 235 |
tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
with open(output_path, "w", newline="", encoding="utf-8") as f:
|
| 238 |
writer = csv.DictWriter(f, fieldnames=out_columns)
|
| 239 |
writer.writeheader()
|
|
|
|
| 241 |
|
| 242 |
with mp.Pool(processes=n_workers) as pool:
|
| 243 |
for i, result in enumerate(
|
| 244 |
+
pool.imap_unordered(_extract_worker, tasks, chunksize=4), 1
|
| 245 |
):
|
| 246 |
if result is None:
|
| 247 |
failed += 1
|