#!/usr/bin/env python3
"""拆分 unified_spatial_foa_fsd63_all/train.jsonl 按 data_source 字段分成三份。
供 v13_C 实验使用:real replication 需要把 dcase_real 单独 manifest,以便
train_manifest_replication=(1, 1, 6) 让 real 占比 6% → ~25%。
用法:
python scripts/split_unified_train_by_source.py \\
--input /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all/train.jsonl \\
--output-dir /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all
# dry-run 只统计不写文件
python scripts/split_unified_train_by_source.py \\
--input /apdcephfs_cq12/.../train.jsonl --dry-run
"""
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, List
# Manifest path used when --input is not given on the command line.
DEFAULT_INPUT = (
    "/apdcephfs_cq12/share_302080740/user/schmittzhu/data/"
    "unified_spatial_foa_fsd63_all/train.jsonl"
)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the split script.

    Returns:
        Namespace with ``input`` (manifest path, DEFAULT_INPUT when omitted),
        ``output_dir`` (None unless --output-dir is given) and ``dry_run``
        (True only when --dry-run is passed).
    """
    # One (flag, keyword-arguments) entry per option, registered in order.
    option_specs = (
        ("--input", {"default": DEFAULT_INPUT}),
        (
            "--output-dir",
            # Help text: "default = directory containing the input file".
            {"default": None, "help": "默认 = 输入文件所在目录"},
        ),
        (
            "--dry-run",
            # Help text: "only print statistics, write no files".
            {"action": "store_true", "help": "只统计不写文件"},
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _data_source_of(line: str) -> str | None:
    """Return the ``data_source`` of one JSONL line, or None if it is unusable.

    A line is unusable when it is not valid JSON or when the decoded value is
    not a JSON object. A record that lacks the field yields the empty string.
    """
    try:
        record = json.loads(line)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return None
    if not isinstance(record, dict):
        # e.g. a bare list or number: there is no field to look up.
        return None
    return str(record.get("data_source", ""))


def main() -> None:
    """Split the input manifest into one JSONL file per known data_source.

    The input is read twice: a first pass prints per-source statistics, and,
    unless --dry-run is given, a second pass copies every line whose
    ``data_source`` is ``sim_static``, ``qa_sim`` or ``dcase_real`` verbatim
    into ``train_<source>.jsonl`` inside the output directory. Lines with any
    other source are reported and dropped; blank lines are ignored; lines that
    are not JSON objects are counted and skipped.

    Raises:
        FileNotFoundError: if the input manifest does not exist.
        ValueError: if an output file would overwrite the input manifest.
    """
    # Local import keeps this block self-contained.
    from contextlib import ExitStack

    args = parse_args()
    input_path = Path(args.input)
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not input_path.exists():
        raise FileNotFoundError(f"Input not found: {input_path}")
    output_dir = Path(args.output_dir) if args.output_dir else input_path.parent

    # The three output manifests, keyed by data_source value.
    out_paths: Dict[str, Path] = {
        "sim_static": output_dir / "train_sim_static.jsonl",
        "qa_sim": output_dir / "train_qa_sim.jsonl",
        "dcase_real": output_dir / "train_dcase_real.jsonl",
    }
    if not args.dry_run:
        # Opening an output in "w" mode truncates it, so refuse to clobber the input.
        resolved_input = input_path.resolve()
        for path in out_paths.values():
            if path.resolve() == resolved_input:
                raise ValueError(f"Output would overwrite input: {path}")

    # Pass 1: statistics.
    counter: Counter = Counter()
    unknown: Counter = Counter()
    malformed = 0
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            src = _data_source_of(line)
            if src is None:
                malformed += 1
            elif src in out_paths:
                counter[src] += 1
            else:
                unknown[src] += 1

    total = sum(counter.values()) + sum(unknown.values())
    print("=" * 60)
    print(f" Input: {input_path}")
    print(f" Total lines: {total}")
    print(f" Known data_sources:")
    for k, v in counter.most_common():
        print(f" {k:<14s}: {v:>7d} ({100*v/total:.2f}%)")
    if unknown:
        print(f" UNKNOWN data_sources (will be DROPPED):")
        for k, v in unknown.most_common():
            print(f" {repr(k):<20s}: {v:>7d}")
    if malformed:
        print(f" Skipped unparseable lines: {malformed}")
    print("=" * 60)

    if args.dry_run:
        print("[dry-run] 不写文件")
        return

    # Pass 2: write the three manifests. The directory is created only here so
    # that a dry run leaves the filesystem untouched.
    output_dir.mkdir(parents=True, exist_ok=True)
    written: Counter = Counter()
    with ExitStack() as stack:
        # ExitStack closes every writer already opened even if a later open fails.
        writers = {
            key: stack.enter_context(open(path, "w", encoding="utf-8"))
            for key, path in out_paths.items()
        }
        with open(input_path, encoding="utf-8") as f:
            for line in f:
                src = _data_source_of(line)
                if src in writers:
                    writers[src].write(line)
                    written[src] += 1

    print("Written:")
    for k, p in out_paths.items():
        print(f" {k:<14s} → {p} ({written[k]} lines)")
    print("Done.")
# Run the split only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|