Rthur2003 committed on
Commit
282a605
·
1 Parent(s): f11faed

feat: implement parallel processing for batch feature extraction to improve performance

Browse files
app/training/extract_features_batch.py CHANGED
@@ -208,41 +208,64 @@ def extract_batch(
208
  for row in reader:
209
  samples.append(row)
210
 
211
- print(f"Extracting features from {len(samples)} samples...")
 
 
 
 
 
 
212
 
213
  out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
214
  success = 0
215
  failed = 0
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  with open(output_path, "w", newline="", encoding="utf-8") as f:
218
  writer = csv.DictWriter(f, fieldnames=out_columns)
219
  writer.writeheader()
220
-
221
- for i, sample in enumerate(samples):
222
- audio_path = sample["file_path"]
223
- label_int = int(sample["label_int"])
224
-
225
- features = extract_sample_features(audio_path)
226
- if features is None:
227
- failed += 1
228
- continue
229
-
230
- features["file_path"] = audio_path
231
- features["label_int"] = label_int
232
- writer.writerow(features)
233
- success += 1
234
-
235
- if (i + 1) % 50 == 0:
236
- print(
237
- f" [{i + 1}/{len(samples)}] "
238
- f"success={success}, failed={failed}"
239
- )
240
-
 
 
 
 
241
  print(
242
  f"\nDone: {success} extracted, "
243
- f"{failed} failed"
 
244
  )
245
- print(f"Output: {output_path}")
246
 
247
  return output_path
248
 
 
208
  for row in reader:
209
  samples.append(row)
210
 
211
+ # Parallel processing via multiprocessing.Pool
212
+ import multiprocessing as mp
213
+ import os as _os
214
+ import time as _time
215
+
216
+ n_workers = max(1, (_os.cpu_count() or 4) - 1)
217
+ print(f"Extracting features from {len(samples)} samples using {n_workers} workers...", flush=True)
218
 
219
  out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
220
  success = 0
221
  failed = 0
222
+ t_start = _time.time()
223
+
224
+ tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
225
+
226
+ def _worker(args):
227
+ audio_path, label_int = args
228
+ features = extract_sample_features(audio_path)
229
+ if features is None:
230
+ return None
231
+ features["file_path"] = audio_path
232
+ features["label_int"] = label_int
233
+ return features
234
 
235
  with open(output_path, "w", newline="", encoding="utf-8") as f:
236
  writer = csv.DictWriter(f, fieldnames=out_columns)
237
  writer.writeheader()
238
+ f.flush()
239
+
240
+ with mp.Pool(processes=n_workers) as pool:
241
+ for i, result in enumerate(
242
+ pool.imap_unordered(_worker, tasks, chunksize=4), 1
243
+ ):
244
+ if result is None:
245
+ failed += 1
246
+ continue
247
+ writer.writerow(result)
248
+ success += 1
249
+
250
+ if i % 25 == 0:
251
+ f.flush()
252
+ elapsed = _time.time() - t_start
253
+ rate = i / elapsed if elapsed > 0 else 0
254
+ eta = (len(samples) - i) / rate if rate > 0 else 0
255
+ print(
256
+ f" [{i}/{len(samples)}] "
257
+ f"ok={success} fail={failed} "
258
+ f"rate={rate:.1f}/s eta={eta / 60:.1f}m",
259
+ flush=True,
260
+ )
261
+
262
+ elapsed = _time.time() - t_start
263
  print(
264
  f"\nDone: {success} extracted, "
265
+ f"{failed} failed in {elapsed / 60:.1f}m",
266
+ flush=True,
267
  )
268
+ print(f"Output: {output_path}", flush=True)
269
 
270
  return output_path
271