Spatial-BEATs / scripts /split_unified_train_by_source.py
dieKarotte's picture
Add files using upload-large-folder tool
29615e9 verified
#!/usr/bin/env python3
"""Split unified_spatial_foa_fsd63_all/train.jsonl into three manifests by the data_source field.

Used by the v13_C experiment: real-data replication requires dcase_real to be
in its own manifest so that train_manifest_replication=(1, 1, 6) raises the
share of real data from 6% to ~25%.

Usage:
    python scripts/split_unified_train_by_source.py \\
        --input /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all/train.jsonl \\
        --output-dir /apdcephfs_cq12/.../unified_spatial_foa_fsd63_all

    # dry-run: only report statistics, do not write any files
    python scripts/split_unified_train_by_source.py \\
        --input /apdcephfs_cq12/.../train.jsonl --dry-run
"""
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Dict, List
# Manifest path used as the default for the --input command-line option.
# The two adjacent string literals are concatenated into a single path.
DEFAULT_INPUT = (
"/apdcephfs_cq12/share_302080740/user/schmittzhu/data/"
"unified_spatial_foa_fsd63_all/train.jsonl"
)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options:
        --input: path of the JSONL manifest to split; defaults to
            ``DEFAULT_INPUT``.
        --output-dir: directory for the split manifests; ``None`` when not
            given (the help text states it then defaults to the directory
            of the input file).
        --dry-run: boolean flag; per its help text, only statistics are
            reported and no files are written.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    # Each entry is (flag, keyword arguments for add_argument), registered in order.
    option_table = (
        ("--input", {"default": DEFAULT_INPUT}),
        ("--output-dir", {"default": None, "help": "默认 = 输入文件所在目录"}),
        ("--dry-run", {"action": "store_true", "help": "只统计不写文件"}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    namespace = parser.parse_args()
    return namespace
def _read_data_source(line: str) -> str | None:
    """Return the ``data_source`` of one JSONL line as a string.

    Returns ``None`` when the line is not valid JSON or is valid JSON but not
    an object, so the caller can count it instead of crashing on ``.get``.
    A missing ``data_source`` key yields the empty string.
    """
    try:
        record = json.loads(line)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Either way the line is unusable.
        return None
    if not isinstance(record, dict):
        return None
    return str(record.get("data_source", ""))


def main() -> None:
    """Split the input JSONL manifest into one manifest per known data_source.

    The input is streamed twice: the first pass counts lines per
    ``data_source`` and prints a report, the second pass (skipped with
    ``--dry-run``) copies each line verbatim into ``train_<source>.jsonl``
    inside the output directory. Lines whose ``data_source`` is not one of
    the three known values are reported and dropped. Blank lines are ignored;
    lines that are not a JSON object are counted and reported as skipped.

    Raises:
        FileNotFoundError: if the input path is not an existing file.
        ValueError: if one of the output paths resolves to the input file,
            which would truncate the input before it is read.
    """
    # Local import keeps this function self-contained with respect to the
    # file's import block.
    from contextlib import ExitStack

    args = parse_args()
    input_path = Path(args.input)
    # Explicit raise instead of ``assert``: asserts are removed under ``python -O``.
    if not input_path.is_file():
        raise FileNotFoundError(f"Input not found: {input_path}")
    output_dir = Path(args.output_dir) if args.output_dir else input_path.parent

    # The three output manifests, keyed by the data_source value they receive.
    out_paths: Dict[str, Path] = {
        "sim_static": output_dir / "train_sim_static.jsonl",
        "qa_sim": output_dir / "train_qa_sim.jsonl",
        "dcase_real": output_dir / "train_dcase_real.jsonl",
    }

    # Pass 1: statistics only.
    counter: Counter = Counter()
    unknown: Counter = Counter()
    skipped = 0
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            src = _read_data_source(line)
            if src is None:
                skipped += 1
            elif src in out_paths:
                counter[src] += 1
            else:
                unknown[src] += 1

    total_known = sum(counter.values())
    total_unknown = sum(unknown.values())
    total_records = total_known + total_unknown
    print("=" * 60)
    print(f" Input: {input_path}")
    print(f" Total lines: {total_records}")
    print(" Known data_sources:")
    # total_records > 0 whenever this loop body runs, so the division is safe.
    for k, v in counter.most_common():
        print(f" {k:<14s}: {v:>7d} ({100*v/total_records:.2f}%)")
    if unknown:
        print(" UNKNOWN data_sources (will be DROPPED):")
        for k, v in unknown.most_common():
            print(f" {repr(k):<20s}: {v:>7d}")
    if skipped:
        print(f" Skipped lines (invalid JSON or not a JSON object): {skipped}")
    print("=" * 60)

    if args.dry_run:
        print("[dry-run] 不写文件")
        return

    # Opening an output in "w" mode truncates it immediately; refuse to do
    # that to the file we are about to read.
    resolved_input = input_path.resolve()
    for p in out_paths.values():
        if p.resolve() == resolved_input:
            raise ValueError(f"Output path would overwrite the input: {p}")

    # Created only now so that --dry-run leaves the filesystem untouched.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pass 2: write the three manifests. ExitStack closes every handle that
    # was opened, even if a later open() or a write fails.
    written: Counter = Counter()
    with ExitStack() as stack:
        f = stack.enter_context(open(input_path, encoding="utf-8"))
        writers = {
            k: stack.enter_context(open(p, "w", encoding="utf-8"))
            for k, p in out_paths.items()
        }
        for line in f:
            if not line.strip():
                continue
            src = _read_data_source(line)
            if src is None or src not in writers:
                continue
            # The last input line may lack a newline; keep outputs valid JSONL.
            writers[src].write(line if line.endswith("\n") else line + "\n")
            written[src] += 1

    print("Written:")
    for k, p in out_paths.items():
        print(f" {k:<14s}{p} ({written[k]} lines)")
    print("Done.")
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()