VLAlert / tools /build_v6_training_data.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
5.79 kB
#!/usr/bin/env python
"""Build v6 training data: [Analysis] → [Safety Assessment] format.
Reads v5_sft_{train,val}.jsonl and produces v6 versions with:
1. [Analysis] reasoning block (per-frame safety analysis)
2. [Safety Assessment] belief+action block (structured <|BELIEF|> tokens)
3. Mixed 1-frame and 8-frame samples
Usage:
python tools/build_v6_training_data.py
"""
from __future__ import annotations
import json, random, logging
from pathlib import Path
from collections import Counter
# Base directory that main() joins the data paths onto.
ROOT = Path("PROJECT_ROOT")

# Script-wide logging: INFO level, each message prefixed with a timestamp.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
log = logging.getLogger("v6")

# Delimiters placed around each per-frame belief in the assessment block.
BELIEF_OPEN, BELIEF_CLOSE = "<|BELIEF|>", "</|BELIEF|>"

# Action label -> token appended after the belief in the assessment block.
ACTION_MAP = dict(
    SILENT="<|SILENT|>",
    OBSERVE="<|OBSERVE|>",
    ALERT="<|ALERT|>",
)

# Probability with which process_split also derives a 1-frame sample
# from a record.
SINGLE_FRAME_RATIO = 0.2
def build_analysis_block(record: dict, n_frames: int = 8) -> str:
    """Build the [Analysis] reasoning block.

    The block starts with the ``[Analysis]`` header, followed (when the
    record has one) by its one-sentence rationale and a blank line, and then
    one ``Frame N: ...`` line per frame.

    Args:
        record: Sample dict. Reads ``beliefs_per_frame``,
            ``actions_per_frame`` and ``one_sentence_rationale``; a missing
            key is treated as empty.
        n_frames: Upper bound on the number of frame lines. Fewer lines are
            produced when the record holds fewer beliefs.

    Returns:
        The block as one newline-joined string.
    """
    beliefs = record.get("beliefs_per_frame", [])
    actions = record.get("actions_per_frame", [])
    rationale = record.get("one_sentence_rationale", "")

    lines = ["[Analysis]"]
    if rationale:
        # NOTE(review): the blank separator is emitted only together with a
        # rationale -- confirm against the original file's indentation.
        lines.append(rationale)
        lines.append("")

    for i in range(min(n_frames, len(beliefs))):
        # Beliefs may be None or span several lines; flatten to one line.
        belief = (beliefs[i] or "").strip().replace("\n", " ")
        # A frame without a recorded action is treated as SILENT.
        action = actions[i] if i < len(actions) else "SILENT"
        if not belief:
            belief = f"No notable safety cue at frame {i+1}"

        # Only ALERT and OBSERVE frames get a severity prefix.
        if action == "ALERT":
            prefix = "DANGER: "
        elif action == "OBSERVE":
            prefix = "CAUTION: "
        else:
            prefix = ""
        lines.append(f"Frame {i+1}: {prefix}{belief}")

    return "\n".join(lines)
def build_assessment_block(record: dict, n_frames: int = 8) -> str:
    """Render the [Safety Assessment] section of the assistant response.

    After an empty first line and the ``[Safety Assessment]`` header, one
    line is emitted per frame: the belief (flattened to a single line and
    cut to its first 25 whitespace-separated words) wrapped in the belief
    delimiters, followed by the token for that frame's action.

    Args:
        record: Sample dict; reads ``beliefs_per_frame`` and
            ``actions_per_frame`` (missing keys count as empty lists).
        n_frames: Upper bound on the number of frame lines.

    Returns:
        The section as one newline-joined string; it begins with a newline
        because the first element is the empty string.
    """
    beliefs = record.get("beliefs_per_frame", [])
    actions = record.get("actions_per_frame", [])
    frame_count = min(n_frames, len(beliefs))

    out = ["", "[Safety Assessment]"]
    for idx in range(frame_count):
        flattened = (beliefs[idx] or "").strip().replace("\n", " ")
        belief_text = " ".join(flattened.split()[:25])

        # Frames without a recorded action fall back to SILENT, as do
        # action labels that have no entry in ACTION_MAP.
        if idx < len(actions):
            action = actions[idx]
        else:
            action = "SILENT"
        token = ACTION_MAP.get(action, ACTION_MAP["SILENT"])

        out.append(f"{BELIEF_OPEN} {belief_text} {BELIEF_CLOSE} {token}")
    return "\n".join(out)
def build_assistant_v6(record: dict, n_frames: int = 8) -> str:
    """Assemble the full v6 assistant response for one record.

    The [Analysis] block comes first and the [Safety Assessment] block is
    appended directly after it; no extra separator is inserted here.

    Args:
        record: Sample dict forwarded unchanged to both block builders.
        n_frames: Frame limit forwarded to both block builders.

    Returns:
        The concatenation of the two blocks.
    """
    parts = (
        build_analysis_block(record, n_frames),
        build_assessment_block(record, n_frames),
    )
    return "".join(parts)
def make_single_frame_record(record: dict) -> dict | None:
"""Create a 1-frame version by sampling one frame from the 8-frame record."""
beliefs = record.get("beliefs_per_frame", [])
actions = record.get("actions_per_frame", [])
frames = record.get("frame_indices", [])
if len(beliefs) < 1 or len(frames) < 1:
return None
# Prefer frames with non-SILENT action for training diversity
non_silent = [i for i, a in enumerate(actions) if a != "SILENT"]
if non_silent and random.random() < 0.5:
idx = random.choice(non_silent)
else:
idx = random.randint(0, min(len(beliefs), len(frames)) - 1)
new = dict(record)
new["id"] = record["id"] + f"_1f{idx}"
new["frame_indices"] = [frames[idx]]
new["beliefs_per_frame"] = [beliefs[idx]]
new["actions_per_frame"] = [actions[idx]]
new["danger_per_frame"] = [record.get("danger_per_frame", [0.0] * 8)[idx]]
new["tta_per_frame"] = [record.get("tta_per_frame", [10.0] * 8)[idx]]
new["n_frames"] = 1
return new
def process_split(input_path: Path, output_path: Path, add_single_frame: bool = True) -> None:
    """Process one split (train or val).

    Every input record is tagged with ``n_frames = 8`` and gets an
    ``assistant_v6`` field. When ``add_single_frame`` is true, each record
    additionally yields a 1-frame variant with probability
    ``SINGLE_FRAME_RATIO``. The combined records are shuffled and written as
    JSON Lines, and a short preview of the first three is logged.

    Args:
        input_path: JSONL file to read (UTF-8, one JSON object per line).
            Blank lines are ignored.
        output_path: JSONL file to write (UTF-8); overwritten if it exists.
        add_single_frame: Whether to derive 1-frame samples as well.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    # Split on "\n" only: str.splitlines() would also break on U+2028/U+2029,
    # which may occur unescaped inside JSON strings. Dropping blank lines
    # keeps an empty file or a stray empty line from crashing json.loads.
    raw_lines = input_path.read_text(encoding="utf-8").split("\n")
    lines = [raw for raw in raw_lines if raw.strip()]
    log.info(f"Input: {input_path.name} ({len(lines)} records)")

    output_records = []
    stats = Counter()
    for raw in lines:
        record = json.loads(raw)

        # 8-frame record
        record["n_frames"] = 8
        record["assistant_v6"] = build_assistant_v6(record, 8)
        output_records.append(record)
        stats["8frame"] += 1
        bsrc = record.get("belief_source", "auto_generated")
        stats[f"src_{bsrc}"] += 1

        # 1-frame record (sampled subset)
        if add_single_frame and random.random() < SINGLE_FRAME_RATIO:
            single = make_single_frame_record(record)
            if single:
                single["assistant_v6"] = build_assistant_v6(single, 1)
                output_records.append(single)
                stats["1frame"] += 1

    random.shuffle(output_records)

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters,
    # which a platform-default encoding may not be able to represent.
    with open(output_path, "w", encoding="utf-8") as f:
        for r in output_records:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    log.info(f"Output: {output_path.name} ({len(output_records)} records)")
    log.info(f" Stats: {dict(stats)}")

    # Show examples
    for r in output_records[:3]:
        n = r.get("n_frames", 8)
        # .get() so a record without "source" cannot abort the preview
        # after the output file has already been written.
        log.info(f"\n Example ({n}-frame, {r.get('source', '?')}, {r.get('belief_source','?')}):")
        asst = r["assistant_v6"]
        for line in asst.split("\n")[:6]:
            log.info(f" {line[:80]}")
        log.info(" ...")
def main():
    """Convert the v5 train and val splits into their v6 counterparts.

    The random generator is seeded with 42 before any split is processed.
    A split whose input file does not exist is skipped with a warning.
    1-frame samples are requested for the train split only.
    """
    random.seed(42)
    for split in ("train", "val"):
        src_path = ROOT / f"data/cot_corpus_v3/v5_sft_{split}.jsonl"
        dst_path = ROOT / f"data/cot_corpus_v3/v6_sft_{split}.jsonl"
        if src_path.exists():
            process_split(src_path, dst_path, add_single_frame=(split == "train"))
        else:
            log.warning(f" {src_path} not found, skip")
    log.info("\nDone!")


if __name__ == "__main__":
    main()