Rthur2003 committed on
Commit
2a59987
·
1 Parent(s): 282a605

feat: refactor worker function for batch feature extraction to support multiprocessing

Browse files
app/training/extract_features_batch.py CHANGED
@@ -182,6 +182,17 @@ def extract_sample_features(audio_path: str) -> dict | None:
182
  return None
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
185
  def extract_batch(
186
  manifest_path: str | Path,
187
  output_path: str | Path | None = None,
@@ -223,15 +234,6 @@ def extract_batch(
223
 
224
  tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
225
 
226
- def _worker(args):
227
- audio_path, label_int = args
228
- features = extract_sample_features(audio_path)
229
- if features is None:
230
- return None
231
- features["file_path"] = audio_path
232
- features["label_int"] = label_int
233
- return features
234
-
235
  with open(output_path, "w", newline="", encoding="utf-8") as f:
236
  writer = csv.DictWriter(f, fieldnames=out_columns)
237
  writer.writeheader()
@@ -239,7 +241,7 @@ def extract_batch(
239
 
240
  with mp.Pool(processes=n_workers) as pool:
241
  for i, result in enumerate(
242
- pool.imap_unordered(_worker, tasks, chunksize=4), 1
243
  ):
244
  if result is None:
245
  failed += 1
 
182
  return None
183
 
184
 
185
+ def _extract_worker(args: tuple[str, int]) -> dict | None:
186
+ """Module-level worker for multiprocessing (must be picklable)."""
187
+ audio_path, label_int = args
188
+ features = extract_sample_features(audio_path)
189
+ if features is None:
190
+ return None
191
+ features["file_path"] = audio_path
192
+ features["label_int"] = label_int
193
+ return features
194
+
195
+
196
  def extract_batch(
197
  manifest_path: str | Path,
198
  output_path: str | Path | None = None,
 
234
 
235
  tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
236
 
 
 
 
 
 
 
 
 
 
237
  with open(output_path, "w", newline="", encoding="utf-8") as f:
238
  writer = csv.DictWriter(f, fieldnames=out_columns)
239
  writer.writeheader()
 
241
 
242
  with mp.Pool(processes=n_workers) as pool:
243
  for i, result in enumerate(
244
+ pool.imap_unordered(_extract_worker, tasks, chunksize=4), 1
245
  ):
246
  if result is None:
247
  failed += 1