| |
| """拆分 unified_spatial_foa_fsd63_all/train.jsonl 按 data_source 字段分成三份。 |
| |
| 供 v13_C 实验使用:real replication 需要把 dcase_real 单独 manifest,以便 |
| train_manifest_replication=(1, 1, 6) 让 real 占比 6% → ~25%。 |
| |
| 用法: |
| python scripts/split_unified_train_by_source.py \\ |
| --input /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all/train.jsonl \\ |
| --output-dir /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all |
| |
| # dry-run 只统计不写文件 |
| python scripts/split_unified_train_by_source.py \\ |
| --input /apdcephfs_cq12/.../train.jsonl --dry-run |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from collections import Counter |
| from pathlib import Path |
| from typing import Dict, List |
|
|
|
|
# Manifest that is split when --input is not given on the command line.
DEFAULT_INPUT = (
    "/apdcephfs_cq12/share_302080740/user/schmittzhu"
    "/data/unified_spatial_foa_fsd63_all"
    "/train.jsonl"
)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the split script.

    Returns:
        Namespace with three attributes:
        ``input`` (manifest path, ``DEFAULT_INPUT`` unless overridden),
        ``output_dir`` (``None`` unless ``--output-dir`` is given) and
        ``dry_run`` (``True`` only when ``--dry-run`` is passed).
    """
    # One (flag, add_argument-kwargs) entry per supported option.
    option_specs = (
        ("--input", {"default": DEFAULT_INPUT}),
        # Help text: "defaults to the directory containing the input file".
        ("--output-dir", {"default": None, "help": "默认 = 输入文件所在目录"}),
        # Help text: "only print statistics, do not write files".
        ("--dry-run", {"action": "store_true", "help": "只统计不写文件"}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def _iter_records(input_path: Path):
    """Yield ``(raw_line, data_source)`` for every non-blank manifest line.

    ``data_source`` is ``None`` when the line is not valid JSON or does not
    decode to a JSON object; otherwise it is ``str()`` of the record's
    ``"data_source"`` value (``""`` when the key is absent). ``raw_line`` is
    the line exactly as read, including its line terminator if it has one.
    """
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError
                yield line, None
                continue
            if isinstance(record, dict):
                yield line, str(record.get("data_source", ""))
            else:
                # Valid JSON but not an object (list, number, ...): it has
                # no "data_source" field, so treat it like a malformed line.
                yield line, None


def main() -> None:
    """Split a JSONL manifest into one file per known ``data_source``.

    The manifest is scanned once to print per-source statistics. Unless
    ``--dry-run`` is given it is then scanned a second time, and every line
    whose ``data_source`` is ``sim_static``, ``qa_sim`` or ``dcase_real`` is
    copied verbatim to ``train_<source>.jsonl`` in the output directory.
    Lines with any other ``data_source`` and lines that are not JSON objects
    are dropped.

    Raises:
        FileNotFoundError: the input manifest does not exist.
        ValueError: one of the output paths is the input file itself.
    """
    # Local import: ExitStack is not among the file-level imports.
    from contextlib import ExitStack

    args = parse_args()
    input_path = Path(args.input)
    # An explicit raise (rather than ``assert``) survives ``python -O``.
    if not input_path.exists():
        raise FileNotFoundError(f"Input not found: {input_path}")

    output_dir = Path(args.output_dir) if args.output_dir else input_path.parent

    out_paths: Dict[str, Path] = {
        "sim_static": output_dir / "train_sim_static.jsonl",
        "qa_sim": output_dir / "train_qa_sim.jsonl",
        "dcase_real": output_dir / "train_dcase_real.jsonl",
    }

    # Pass 1: statistics only.
    counter: Counter = Counter()
    unknown: Counter = Counter()
    malformed = 0
    for _, src in _iter_records(input_path):
        if src is None:
            malformed += 1
        elif src in out_paths:
            counter[src] += 1
        else:
            unknown[src] += 1
    total = sum(counter.values()) + sum(unknown.values())

    print("=" * 60)
    print(f" Input: {input_path}")
    print(f" Total lines: {total}")
    print(" Known data_sources:")
    for k, v in counter.most_common():
        print(f" {k:<14s}: {v:>7d} ({100*v/total:.2f}%)")
    if unknown:
        print(" UNKNOWN data_sources (will be DROPPED):")
        for k, v in unknown.most_common():
            print(f" {repr(k):<20s}: {v:>7d}")
    if malformed:
        # Reported separately so a clean input prints exactly what it did before.
        print(f" MALFORMED lines (not JSON objects, will be DROPPED): {malformed}")
    print("=" * 60)

    if args.dry_run:
        print("[dry-run] 不写文件")
        return

    # Opening an output in "w" mode truncates it, so it must never be the input.
    resolved_input = input_path.resolve()
    for p in out_paths.values():
        if p.resolve() == resolved_input:
            raise ValueError(f"Output path would overwrite the input: {p}")

    # Created only now so that --dry-run leaves the filesystem untouched.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pass 2: copy each known line to its per-source file.
    written: Counter = Counter()
    with ExitStack() as stack:
        # ExitStack closes every file opened so far, even if a later open() fails.
        writers = {
            k: stack.enter_context(open(p, "w", encoding="utf-8"))
            for k, p in out_paths.items()
        }
        for line, src in _iter_records(input_path):
            writer = writers.get(src)
            if writer is None:
                continue
            writer.write(line)
            written[src] += 1

    print("Written:")
    for k, p in out_paths.items():
        print(f" {k:<14s} → {p} ({written[k]} lines)")
    print("Done.")
|
|
|
|
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|