| """Phase D-experimental (C) — Cache extractor that FILLS assistant_text with |
| GT BELIEF descriptions instead of empty placeholders. |
| |
| Original v3 cache extracts hidden states with assistant_text = |
| <|BELIEF|> </|BELIEF|>\n × 8 frames ← empty placeholders |
| |
| This version fills each block with the GT description from |
| manifest's beliefs_per_frame field: |
| <|BELIEF|> lead vehicle drifting </|BELIEF|>\n |
| <|BELIEF|> side-street vehicle approaching </|BELIEF|>\n ... |
| |
It then mean-pools the hidden states inside each BELIEF span (which now
holds real descriptive tokens), so the resulting features vary per frame
with the GT scene description instead of coming from a fixed placeholder.
| |
| Output schema matches make_cache_x_v2.py. |
| |
| Usage: |
| python tools/make_cache_gt_belief.py \ |
| --split train_9k_gtb \ |
| --manifest data/cot_corpus_v2/vlalert_x_perframe_v2_train.jsonl |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
|
|
# Repository root: two levels up from this script's resolved path. It is
# put at the front of sys.path so the project-local ``tools`` and
# ``training`` packages imported below resolve when run as a script.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, f"{ROOT}")
|
|
| |
| from tools import run_train_cot_belief_fast |
|
|
| import torch |
| from tqdm import tqdm |
| from transformers import AutoProcessor |
| from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration |
| from peft import PeftModel |
|
|
| from training.VLA.cot_belief_dataset import ( |
| BELIEF_OPEN, BELIEF_CLOSE, SYSTEM_PROMPT, USER_PROMPT |
| ) |
| from training.VLA.frame_utils import sample_frames |
|
|
| logging.basicConfig(level=logging.INFO, |
| format="%(asctime)s %(levelname)s %(message)s") |
| logger = logging.getLogger("gtb_cache") |
|
|
| BELIEF_LAYERS = (20, 24, 28, 32) |
| POLICY_LAYER = 33 |
|
|
|
|
@torch.no_grad()
def extract_one(model, proc, frames, beliefs, device,
                belief_layers=BELIEF_LAYERS, policy_layer=POLICY_LAYER):
    """Extract per-frame belief / policy features for one clip.

    Uses the same extraction logic as make_cache_x_v2.py, except that each
    BELIEF block in the assistant turn is filled with the per-frame GT
    description instead of an empty placeholder. A single forward pass is
    run with ``output_hidden_states=True`` and features are pooled per block.

    Args:
        model: called as ``model(**inputs, output_hidden_states=True,
            return_dict=True)``; its output must expose ``hidden_states``.
        proc: processor providing ``apply_chat_template``, ``__call__`` and
            ``tokenizer``.
        frames: the frame images, used both as chat images and as the
            ``images`` input to ``proc`` (presumably 8 PIL images - confirm).
        beliefs: exactly 8 description strings, one per frame.
        device: device the processor outputs are moved to.
        belief_layers: ``hidden_states`` indices whose span means are
            concatenated into the belief feature.
        policy_layer: ``hidden_states`` index read at the close-tag position.

    Returns:
        ``(belief_feat, policy_feat, valid)`` on CPU:
        float16 ``[8, D * len(belief_layers)]``, float16 ``[8, D]`` and bool
        ``[8]``, where ``D`` is the width of the last hidden state
        (``[8, 10240]`` / ``[8, 2560]`` for the default 4B model). Rows for
        blocks that are missing, or whose span is empty, stay zero with
        ``valid=False``.

    Raises:
        ValueError: if ``beliefs`` does not hold 8 entries, or if a BELIEF
            tag does not resolve to a token id in ``proc.tokenizer``.
    """
    if len(beliefs) != 8:
        # Explicit raise rather than ``assert`` so the check survives -O.
        raise ValueError(f"need 8 belief strings, got {len(beliefs)}")

    # Resolve tag ids before the (expensive) forward pass. For a token that
    # is not in the vocabulary, HF tokenizers typically return None or the
    # unk id; the tags would then never be found and every frame would come
    # back silently invalid.
    tokenizer = proc.tokenizer
    open_id = tokenizer.convert_tokens_to_ids(BELIEF_OPEN)
    close_id = tokenizer.convert_tokens_to_ids(BELIEF_CLOSE)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    for tag, tag_id in ((BELIEF_OPEN, open_id), (BELIEF_CLOSE, close_id)):
        if tag_id is None or (unk_id is not None and tag_id == unk_id):
            raise ValueError(
                f"{tag!r} does not map to a token id in the checkpoint tokenizer")

    # One "<open> description <close>" block per frame, newline-separated.
    assistant_text = "\n".join(
        f"{BELIEF_OPEN} {b.strip()} {BELIEF_CLOSE}" for b in beliefs)
    user_content = [{"type": "image", "image": img} for img in frames]
    user_content.append({"type": "text", "text": USER_PROMPT})
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
    ]
    text = proc.apply_chat_template(messages, tokenize=False,
                                    add_generation_prompt=False)
    inputs = proc(text=[text], images=[frames], return_tensors="pt",
                  padding=True, truncation=False, max_length=8192)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    out = model(**inputs, output_hidden_states=True, return_dict=True)
    hs_tuple = out.hidden_states
    ids = inputs["input_ids"][0]
    attn = inputs["attention_mask"][0].bool()

    # Positions of every non-padding open / close tag in the sequence.
    # NOTE(review): tags are paired in order of appearance; this assumes they
    # occur only in the assistant turn (not in SYSTEM_PROMPT / USER_PROMPT).
    open_pos = ((ids == open_id) & attn).nonzero(as_tuple=False).flatten().tolist()
    close_pos = ((ids == close_id) & attn).nonzero(as_tuple=False).flatten().tolist()
    n_blocks = min(len(open_pos), len(close_pos), 8)

    D = hs_tuple[-1].shape[-1]
    belief_dim = D * len(belief_layers)
    belief_feat = torch.zeros(8, belief_dim, dtype=torch.float16, device=device)
    policy_feat = torch.zeros(8, D, dtype=torch.float16, device=device)
    valid = torch.zeros(8, dtype=torch.bool, device=device)

    for f, (o, c) in enumerate(zip(open_pos[:n_blocks], close_pos[:n_blocks])):
        if c <= o + 1:
            # Empty (or mis-ordered) span: leave this frame zero / invalid.
            continue
        # Belief feature: per layer, mean over tokens strictly between tags.
        parts = [hs_tuple[layer][0, o + 1:c].mean(dim=0) for layer in belief_layers]
        belief_feat[f] = torch.cat(parts, dim=-1).to(torch.float16)
        # Policy feature: hidden state at the close-tag position itself.
        policy_feat[f] = hs_tuple[policy_layer][0, c].to(torch.float16)
        valid[f] = True
    return belief_feat.cpu(), policy_feat.cpu(), valid.cpu()
|
|
|
|
def _frame_index_field(window):
    """Return the manifest key that holds the frame indices for ``window``.

    ``"legacy"`` maps to ``"frame_indices"``; any other window maps to its
    prefix before the first underscore, e.g. ``"sil_wide"`` ->
    ``"sil_frame_indices"``.
    """
    if window == "legacy":
        return "frame_indices"
    return f"{window.split('_')[0]}_frame_indices"


def _load_records(manifest, fi_field, limit):
    """Return the usable JSONL rows of ``manifest``.

    A row is kept when it has exactly 8 ``beliefs_per_frame`` entries and a
    ``fi_field`` key; blank lines are skipped. When ``limit > 0`` only the
    first ``limit`` usable rows are returned.
    """
    records = []
    with manifest.open(encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            obj = json.loads(line)
            beliefs = obj.get("beliefs_per_frame")
            if not beliefs or len(beliefs) != 8 or fi_field not in obj:
                continue
            records.append(obj)
            if limit > 0 and len(records) >= limit:
                break  # rows past the limit would be discarded anyway
    return records


def _load_model(ckpt, base_model, device):
    """Load processor and LoRA-wrapped model; return (proc, model, hidden_size)."""
    proc = AutoProcessor.from_pretrained(str(ckpt))
    base = Qwen3VLForConditionalGeneration.from_pretrained(
        str(base_model), dtype=torch.bfloat16, device_map={"": device},
        attn_implementation="sdpa")
    # Size the embedding table to the checkpoint tokenizer (presumably it
    # carries added tokens such as the BELIEF tags) before loading adapters.
    base.resize_token_embeddings(len(proc.tokenizer))
    model = PeftModel.from_pretrained(base, str(ckpt)).eval()
    # Read the text hidden width from the config rather than hard-coding
    # 2560, so a different --base_model does not break the cache buffers.
    cfg = getattr(base.config, "text_config", base.config)
    return proc, model, cfg.hidden_size


def main():
    """CLI entry point: build the GT-belief-filled feature cache for a split.

    Loads the LoRA checkpoint, reads manifest rows with 8 GT beliefs,
    extracts per-frame features with ``extract_one`` and saves the cache to
    ``{out_dir}/{tag}__{split}.pt``. Rows whose video cannot be sampled are
    dropped (and counted in ``failed``), so every tensor row lines up with
    the ``ids`` / ``category`` / ``source`` / ``video_id`` lists.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--split", required=True)
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v3/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--tag", default="sft_x_v3")
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v3")
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--window",
                    choices=["legacy", "sil_wide", "obs_mid", "alr_narrow"],
                    default="legacy",
                    help="v4: pick which frame-index array to read from the "
                         "manifest ({prefix}_frame_indices, where prefix is "
                         "the window name before its first '_', e.g. "
                         "sil_wide -> sil_frame_indices). legacy uses the "
                         "original 'frame_indices' field (v3 behaviour).")
    args = ap.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("[load] ckpt=%s", args.ckpt)
    proc, model, hidden = _load_model(args.ckpt, args.base_model, device)

    logger.info("[load] manifest=%s window=%s", args.manifest, args.window)
    fi_field = _frame_index_field(args.window)
    logger.info("  reading frame indices from field: %s", fi_field)
    records = _load_records(args.manifest, fi_field, args.limit)
    N = len(records)
    logger.info("  N=%d (with GT beliefs_per_frame + %s)", N, fi_field)

    belief_dim = hidden * len(BELIEF_LAYERS)
    out_belief = torch.zeros(N, 8, belief_dim, dtype=torch.float16)
    out_policy = torch.zeros(N, 8, hidden, dtype=torch.float16)
    out_valid = torch.zeros(N, 8, dtype=torch.bool)
    out_actions = torch.zeros(N, 8, dtype=torch.long)
    out_danger = torch.zeros(N, 8, dtype=torch.float32)
    out_tta = torch.zeros(N, 8, dtype=torch.float32)
    out_tick_action = torch.zeros(N, dtype=torch.long)
    out_tick_tta = torch.full((N,), -1.0)
    # 3 lies outside action_map's 0-2 ids; presumably a "no previous action"
    # sentinel - confirm against the consumer of this cache.
    out_prev_action = torch.full((N,), 3, dtype=torch.long)
    out_oracle_window = torch.zeros(N, dtype=torch.long)
    out_boundary = torch.zeros(N, dtype=torch.bool)
    out_category, out_source, out_video_id, out_ids = [], [], [], []
    action_map = {"SILENT": 0, "OBSERVE": 1, "ALERT": 2}
    failed = 0
    n_ok = 0  # next row to write; failed videos do not consume a row

    for r in tqdm(records, desc="gtb_cache", ncols=80):
        try:
            frames = sample_frames(Path(r["video_path"]),
                                   frame_indices=r[fi_field],
                                   resize_short=336)
        except Exception as exc:  # best-effort: skip unreadable videos
            failed += 1
            logger.warning("[skip] %s: %s", r.get("video_path"), exc)
            continue
        i = n_ok
        bf, pf, v = extract_one(model, proc, frames,
                                r["beliefs_per_frame"], device)
        out_belief[i] = bf
        out_policy[i] = pf
        out_valid[i] = v
        actions_pf = r.get("actions_per_frame", ["SILENT"] * 8)
        out_actions[i] = torch.tensor(
            [action_map.get(a, 0) for a in actions_pf], dtype=torch.long)
        out_danger[i] = torch.tensor(r.get("danger_per_frame", [0.0] * 8))
        out_tta[i] = torch.tensor(r.get("tta_per_frame", [-1.0] * 8))
        out_tick_action[i] = action_map.get(r.get("tick_action", "SILENT"), 0)
        out_tick_tta[i] = float(r.get("tick_tta_raw", -1.0))
        out_prev_action[i] = int(r.get("prev_action", 3))
        out_oracle_window[i] = int(r.get("oracle_window", 1))
        out_boundary[i] = bool(r.get("boundary", False))
        out_category.append(r.get("category", ""))
        out_source.append(r.get("source", ""))
        out_video_id.append(r.get("video_id", ""))
        out_ids.append(r.get("id", r.get("video_id", "")))
        n_ok += 1

    def _trim(t):
        # Drop unused tail rows. Clone, because torch.save would otherwise
        # serialize the whole pre-allocated storage behind the slice view.
        return t if n_ok == N else t[:n_ok].clone()

    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    cache = {
        "ids": out_ids,
        "belief_content": _trim(out_belief),
        "policy_position": _trim(out_policy),
        "valid_frames": _trim(out_valid),
        "actions_pf": _trim(out_actions),
        "danger_pf": _trim(out_danger),
        "tta_pf": _trim(out_tta),
        "tick_action": _trim(out_tick_action),
        "tick_tta_raw": _trim(out_tick_tta),
        "prev_action": _trim(out_prev_action),
        "oracle_window": _trim(out_oracle_window),
        "boundary": _trim(out_boundary),
        "window": args.window,
        "category": out_category,
        "source": out_source,
        "video_id": out_video_id,
        "schema": "vlalert_x_v4_gt_belief_fill",
        "belief_layers": list(BELIEF_LAYERS),
        "policy_layer": POLICY_LAYER,
        "ckpt": str(args.ckpt),
    }
    torch.save(cache, out_path)
    logger.info("[save] %s failed=%s", out_path, failed)
|
|
|
|
if __name__ == "__main__":
    # Script entry: main() returns None, so SystemExit(None) exits with 0.
    raise SystemExit(main())
|
|