File size: 3,571 Bytes
29615e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
"""拆分 unified_spatial_foa_fsd63_all/train.jsonl 按 data_source 字段分成三份。

供 v13_C 实验使用:real replication 需要把 dcase_real 单独 manifest,以便
train_manifest_replication=(1, 1, 6) 让 real 占比 6% → ~25%。

用法:
    python scripts/split_unified_train_by_source.py \\
        --input /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all/train.jsonl \\
        --output-dir /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all

    # dry-run 只统计不写文件
    python scripts/split_unified_train_by_source.py \\
        --input /apdcephfs_cq12/.../train.jsonl --dry-run
"""
from __future__ import annotations

import argparse
import json
from collections import Counter
from contextlib import ExitStack
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Sequence, Tuple


# Manifest that is split when --input is not given.
DEFAULT_INPUT = (
    "/apdcephfs_cq12/share_302080740/user/schmittzhu/data/"
    "unified_spatial_foa_fsd63_all/train.jsonl"
)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            zero-argument calls behave exactly as before.

    Returns:
        Namespace with ``input`` (str), ``output_dir`` (str or ``None``;
        ``None`` means "use the input file's directory") and ``dry_run``
        (bool).
    """
    p = argparse.ArgumentParser(
        # Reuse the module docstring (usage examples) as --help text;
        # RawDescriptionHelpFormatter keeps its line breaks intact.
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument(
        "--input",
        default=DEFAULT_INPUT,
        help="unified train.jsonl to split (default: %(default)s)",
    )
    p.add_argument(
        "--output-dir",
        default=None,
        help="默认 = 输入文件所在目录",
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="只统计不写文件",
    )
    return p.parse_args(argv)


def _iter_sources(path: Path) -> Iterator[Tuple[Optional[str], str]]:
    """Yield ``(data_source, line)`` for every non-blank line of *path*.

    ``data_source`` is ``str()`` of the record's ``data_source`` field (``""``
    when the field is missing), or ``None`` when the line is not valid JSON
    or does not decode to a JSON object. ``line`` is the text line as read,
    so it can be copied into an output manifest without re-serialising.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) carry no record.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                yield None, line
                continue
            if not isinstance(record, dict):
                # A bare JSON list/number/string has no .get(); report it as
                # malformed instead of crashing.
                yield None, line
                continue
            yield str(record.get("data_source", "")), line


def _print_summary(
    input_path: Path,
    counter: Counter,
    unknown: Counter,
    n_malformed: int,
) -> None:
    """Print per-data_source line counts gathered from *input_path*."""
    total = sum(counter.values()) + sum(unknown.values())
    print("=" * 60)
    print(f"  Input: {input_path}")
    print(f"  Total lines: {total}")
    print("  Known data_sources:")
    for k, v in counter.most_common():
        # total > 0 here: this loop only runs when counter is non-empty.
        print(f"    {k:<14s}: {v:>7d} ({100*v/total:.2f}%)")
    if unknown:
        print("  UNKNOWN data_sources (will be DROPPED):")
        for k, v in unknown.most_common():
            print(f"    {repr(k):<20s}: {v:>7d}")
    if n_malformed:
        print(f"  Malformed / non-object lines (will be SKIPPED): {n_malformed}")
    print("=" * 60)


def _write_splits(input_path: Path, out_paths: Dict[str, Path]) -> Counter:
    """Copy each record of *input_path* into the file for its data_source.

    Records whose source is not a key of *out_paths*, and malformed lines,
    are dropped. Returns the number of lines written per source.
    """
    written: Counter = Counter()
    # ExitStack closes writers that were already opened if a later open()
    # fails, as well as on normal exit or an error while copying.
    with ExitStack() as stack:
        writers = {
            src: stack.enter_context(open(p, "w", encoding="utf-8"))
            for src, p in out_paths.items()
        }
        for src, line in _iter_sources(input_path):
            if src is None or src not in writers:
                continue
            # Keep one record per line even if the input's last line has no
            # trailing newline.
            writers[src].write(line if line.endswith("\n") else line + "\n")
            written[src] += 1
    return written


def main() -> None:
    """Split the unified train manifest into one JSONL file per data_source.

    Lines whose ``data_source`` is ``sim_static``, ``qa_sim`` or
    ``dcase_real`` are copied to ``train_<source>.jsonl`` in the output
    directory (default: the input file's directory). Other sources and
    malformed lines are counted, reported and dropped. With ``--dry-run``
    only the summary is printed and nothing is created on disk.

    Raises:
        SystemExit: if the input file does not exist, or an output path
            would overwrite the input file.
    """
    args = parse_args()
    input_path = Path(args.input)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not input_path.is_file():
        raise SystemExit(f"Input not found: {input_path}")

    output_dir = Path(args.output_dir) if args.output_dir else input_path.parent

    # One output manifest per known data_source.
    out_paths: Dict[str, Path] = {
        "sim_static": output_dir / "train_sim_static.jsonl",
        "qa_sim": output_dir / "train_qa_sim.jsonl",
        "dcase_real": output_dir / "train_dcase_real.jsonl",
    }
    # Opening an output with "w" truncates it, so never let one alias the input.
    resolved_input = input_path.resolve()
    for path in out_paths.values():
        if path.resolve() == resolved_input:
            raise SystemExit(f"Output {path} would overwrite input {input_path}")

    # Pass 1: statistics only.
    counter: Counter = Counter()
    unknown: Counter = Counter()
    n_malformed = 0
    for src, _line in _iter_sources(input_path):
        if src is None:
            n_malformed += 1
        elif src in out_paths:
            counter[src] += 1
        else:
            unknown[src] += 1

    _print_summary(input_path, counter, unknown, n_malformed)

    if args.dry_run:
        print("[dry-run] 不写文件")
        return

    # Create the output directory only once we know we are going to write.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pass 2: write the per-source manifests.
    written = _write_splits(input_path, out_paths)

    print("Written:")
    for k, p in out_paths.items():
        print(f"  {k:<14s}{p}  ({written[k]} lines)")
    print("Done.")


if __name__ == "__main__":
    # main() returns None, and SystemExit(None) exits with status 0, the same
    # as simply reaching the end of the script.
    raise SystemExit(main())