#!/usr/bin/env python3 """拆分 unified_spatial_foa_fsd63_all/train.jsonl 按 data_source 字段分成三份。 供 v13_C 实验使用:real replication 需要把 dcase_real 单独 manifest,以便 train_manifest_replication=(1, 1, 6) 让 real 占比 6% → ~25%。 用法: python scripts/split_unified_train_by_source.py \\ --input /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all/train.jsonl \\ --output-dir /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all # dry-run 只统计不写文件 python scripts/split_unified_train_by_source.py \\ --input /apdcephfs_cq12/.../train.jsonl --dry-run """ from __future__ import annotations import argparse import json from collections import Counter from pathlib import Path from typing import Dict, List DEFAULT_INPUT = ( "/apdcephfs_cq12/share_302080740/user/schmittzhu/data/" "unified_spatial_foa_fsd63_all/train.jsonl" ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--input", default=DEFAULT_INPUT) p.add_argument( "--output-dir", default=None, help="默认 = 输入文件所在目录", ) p.add_argument( "--dry-run", action="store_true", help="只统计不写文件", ) return p.parse_args() def main() -> None: args = parse_args() input_path = Path(args.input) assert input_path.exists(), f"Input not found: {input_path}" output_dir = Path(args.output_dir) if args.output_dir else input_path.parent output_dir.mkdir(parents=True, exist_ok=True) # 三个输出 manifest out_paths: Dict[str, Path] = { "sim_static": output_dir / "train_sim_static.jsonl", "qa_sim": output_dir / "train_qa_sim.jsonl", "dcase_real": output_dir / "train_dcase_real.jsonl", } counter = Counter() unknown = Counter() # 统计 with open(input_path) as f: for line in f: try: d = json.loads(line) except Exception: continue src = str(d.get("data_source", "")) if src in out_paths: counter[src] += 1 else: unknown[src] += 1 total_known = sum(counter.values()) total_unknown = sum(unknown.values()) print("=" * 60) print(f" Input: {input_path}") print(f" Total lines: {total_known + total_unknown}") print(f" Known data_sources:") for k, v in counter.most_common(): print(f" {k:<14s}: {v:>7d} ({100*v/(total_known+total_unknown):.2f}%)") if unknown: print(f" UNKNOWN data_sources (will be DROPPED):") for k, v in unknown.most_common(): print(f" {repr(k):<20s}: {v:>7d}") print("=" * 60) if args.dry_run: print("[dry-run] 不写文件") return # 写三份 writers = {k: open(p, "w") for k, p in out_paths.items()} try: written = Counter() with open(input_path) as f: for line in f: try: d = json.loads(line) except Exception: continue src = str(d.get("data_source", "")) if src in writers: writers[src].write(line) written[src] += 1 finally: for w in writers.values(): w.close() print("Written:") for k, p in out_paths.items(): print(f" {k:<14s} → {p} ({written[k]} lines)") print("Done.") if __name__ == "__main__": main()