Spaces:
Sleeping
Sleeping
feat: implement parallel processing for batch feature extraction to improve performance
Browse files
app/training/extract_features_batch.py
CHANGED
|
@@ -208,41 +208,64 @@ def extract_batch(
|
|
| 208 |
for row in reader:
|
| 209 |
samples.append(row)
|
| 210 |
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
|
| 214 |
success = 0
|
| 215 |
failed = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
with open(output_path, "w", newline="", encoding="utf-8") as f:
|
| 218 |
writer = csv.DictWriter(f, fieldnames=out_columns)
|
| 219 |
writer.writeheader()
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
print(
|
| 242 |
f"\nDone: {success} extracted, "
|
| 243 |
-
f"{failed} failed"
|
|
|
|
| 244 |
)
|
| 245 |
-
print(f"Output: {output_path}")
|
| 246 |
|
| 247 |
return output_path
|
| 248 |
|
|
|
|
| 208 |
for row in reader:
|
| 209 |
samples.append(row)
|
| 210 |
|
| 211 |
+
# Parallel processing via multiprocessing.Pool
|
| 212 |
+
import multiprocessing as mp
|
| 213 |
+
import os as _os
|
| 214 |
+
import time as _time
|
| 215 |
+
|
| 216 |
+
n_workers = max(1, (_os.cpu_count() or 4) - 1)
|
| 217 |
+
print(f"Extracting features from {len(samples)} samples using {n_workers} workers...", flush=True)
|
| 218 |
|
| 219 |
out_columns = ["file_path", "label_int"] + FEATURE_COLUMNS
|
| 220 |
success = 0
|
| 221 |
failed = 0
|
| 222 |
+
t_start = _time.time()
|
| 223 |
+
|
| 224 |
+
tasks = [(s["file_path"], int(s["label_int"])) for s in samples]
|
| 225 |
+
|
| 226 |
+
def _worker(args):
    """Extract features for a single ``(file_path, label_int)`` task.

    Args:
        args: Two-item sequence holding the audio file path and its
            integer label, as built for the ``tasks`` list.

    Returns:
        The object returned by ``extract_sample_features`` with the
        ``file_path`` and ``label_int`` keys filled in, or ``None`` when
        extraction returned ``None``.
    """
    # NOTE(review): this def appears to live inside extract_batch and is
    # handed to mp.Pool.imap_unordered. Pool pickles the callable, and
    # pickle cannot serialize a function defined inside another function,
    # so this likely needs to move to module level -- confirm.
    path, label = args
    row = extract_sample_features(path)
    if row is not None:
        # Attach the identifying columns the CSV writer expects.
        row["file_path"] = path
        row["label_int"] = label
    return row
|
| 234 |
|
| 235 |
with open(output_path, "w", newline="", encoding="utf-8") as f:
|
| 236 |
writer = csv.DictWriter(f, fieldnames=out_columns)
|
| 237 |
writer.writeheader()
|
| 238 |
+
f.flush()
|
| 239 |
+
|
| 240 |
+
with mp.Pool(processes=n_workers) as pool:
|
| 241 |
+
for i, result in enumerate(
|
| 242 |
+
pool.imap_unordered(_worker, tasks, chunksize=4), 1
|
| 243 |
+
):
|
| 244 |
+
if result is None:
|
| 245 |
+
failed += 1
|
| 246 |
+
continue
|
| 247 |
+
writer.writerow(result)
|
| 248 |
+
success += 1
|
| 249 |
+
|
| 250 |
+
if i % 25 == 0:
|
| 251 |
+
f.flush()
|
| 252 |
+
elapsed = _time.time() - t_start
|
| 253 |
+
rate = i / elapsed if elapsed > 0 else 0
|
| 254 |
+
eta = (len(samples) - i) / rate if rate > 0 else 0
|
| 255 |
+
print(
|
| 256 |
+
f" [{i}/{len(samples)}] "
|
| 257 |
+
f"ok={success} fail={failed} "
|
| 258 |
+
f"rate={rate:.1f}/s eta={eta / 60:.1f}m",
|
| 259 |
+
flush=True,
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
elapsed = _time.time() - t_start
|
| 263 |
print(
|
| 264 |
f"\nDone: {success} extracted, "
|
| 265 |
+
f"{failed} failed in {elapsed / 60:.1f}m",
|
| 266 |
+
flush=True,
|
| 267 |
)
|
| 268 |
+
print(f"Output: {output_path}", flush=True)
|
| 269 |
|
| 270 |
return output_path
|
| 271 |
|