"""Phase D-experimental (C) — Cache extractor that FILLS assistant_text with GT BELIEF descriptions instead of empty placeholders. Original v3 cache extracts hidden states with assistant_text = <|BELIEF|> \n × 8 frames ← empty placeholders This version fills each block with the GT description from manifest's beliefs_per_frame field: <|BELIEF|> lead vehicle drifting \n <|BELIEF|> side-street vehicle approaching \n ... Then range-pools the BELIEF span (now contains actual descriptive tokens) to get features that ARE visually-informed (because text content varies per-frame and reflects scene description). Output schema matches make_cache_x_v2.py. Usage: python tools/make_cache_gt_belief.py \ --split train_9k_gtb \ --manifest data/cot_corpus_v2/vlalert_x_perframe_v2_train.jsonl """ from __future__ import annotations import argparse import json import logging import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) # Conv3d→Linear patch from tools import run_train_cot_belief_fast # noqa: F401 import torch from tqdm import tqdm from transformers import AutoProcessor from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration from peft import PeftModel from training.VLA.cot_belief_dataset import ( BELIEF_OPEN, BELIEF_CLOSE, SYSTEM_PROMPT, USER_PROMPT ) from training.VLA.frame_utils import sample_frames logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("gtb_cache") BELIEF_LAYERS = (20, 24, 28, 32) POLICY_LAYER = 33 @torch.no_grad() def extract_one(model, proc, frames, beliefs, device, belief_layers=BELIEF_LAYERS, policy_layer=POLICY_LAYER): """Return (belief_feat [8, 10240], policy_feat [8, 2560], valid [8]). Uses the SAME extraction logic as make_cache_x_v2.py but with BELIEF placeholders FILLED with the per-frame GT descriptions. """ assert len(beliefs) == 8, f"need 8 belief strings, got {len(beliefs)}" # Fill the placeholder with GT text per frame assistant_text = "\n".join( f"{BELIEF_OPEN} {b.strip()} {BELIEF_CLOSE}" for b in beliefs) user_content = [{"type": "image", "image": img} for img in frames] user_content.append({"type": "text", "text": USER_PROMPT}) messages = [ {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}, {"role": "user", "content": user_content}, {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]}, ] text = proc.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) inputs = proc(text=[text], images=[frames], return_tensors="pt", padding=True, truncation=False, max_length=8192) inputs = {k: v.to(device) for k, v in inputs.items()} out = model(**inputs, output_hidden_states=True, return_dict=True) hs_tuple = out.hidden_states # tuple of [1, T, D] ids = inputs["input_ids"][0] attn = inputs["attention_mask"][0].bool() open_id = proc.tokenizer.convert_tokens_to_ids(BELIEF_OPEN) close_id = proc.tokenizer.convert_tokens_to_ids(BELIEF_CLOSE) open_pos = ((ids == open_id) & attn).nonzero(as_tuple=False).flatten().tolist() close_pos = ((ids == close_id) & attn).nonzero(as_tuple=False).flatten().tolist() n_blocks = min(len(open_pos), len(close_pos), 8) D = hs_tuple[-1].shape[-1] belief_dim = D * len(belief_layers) belief_feat = torch.zeros(8, belief_dim, dtype=torch.float16, device=device) policy_feat = torch.zeros(8, D, dtype=torch.float16, device=device) valid = torch.zeros(8, dtype=torch.bool, device=device) for f, (o, c) in enumerate(zip(open_pos[:n_blocks], close_pos[:n_blocks])): if c <= o + 1: continue # Range pool over BELIEF span content (now ACTUALLY has descriptive text) parts = [] for L in belief_layers: hs = hs_tuple[L][0, o+1:c] parts.append(hs.mean(dim=0)) belief_feat[f] = torch.cat(parts, dim=-1).to(torch.float16) # POLICY at closing token policy_feat[f] = hs_tuple[policy_layer][0, c].to(torch.float16) valid[f] = True return belief_feat.cpu(), policy_feat.cpu(), valid.cpu() def main(): ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--split", required=True) ap.add_argument("--manifest", type=Path, required=True) ap.add_argument("--ckpt", type=Path, default=ROOT / "checkpoints/sft_x_v3/best") ap.add_argument("--base_model", type=Path, default=ROOT / "models/Qwen3-VL-4B-Instruct") ap.add_argument("--tag", default="sft_x_v3") ap.add_argument("--out_dir", type=Path, default=ROOT / "data/belief_cache_v3") ap.add_argument("--limit", type=int, default=0) ap.add_argument("--window", choices=["legacy", "sil_wide", "obs_mid", "alr_narrow"], default="legacy", help="v4: pick which frame-index array to read from the " "manifest ({window}_frame_indices). legacy uses the " "original 'frame_indices' field (v3 behaviour).") args = ap.parse_args() args.out_dir.mkdir(parents=True, exist_ok=True) device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"[load] ckpt={args.ckpt}") proc = AutoProcessor.from_pretrained(str(args.ckpt)) base = Qwen3VLForConditionalGeneration.from_pretrained( str(args.base_model), dtype=torch.bfloat16, device_map={"": device}, attn_implementation="sdpa") base.resize_token_embeddings(len(proc.tokenizer)) model = PeftModel.from_pretrained(base, str(args.ckpt)).eval() logger.info(f"[load] manifest={args.manifest} window={args.window}") fi_field = "frame_indices" if args.window == "legacy" \ else f"{args.window.split('_')[0]}_frame_indices" logger.info(f" reading frame indices from field: {fi_field}") records = [] with args.manifest.open() as f: for ln in f: if not ln.strip(): continue obj = json.loads(ln) if not obj.get("beliefs_per_frame") or len(obj["beliefs_per_frame"]) != 8: continue if fi_field not in obj: continue records.append(obj) if args.limit > 0: records = records[:args.limit] N = len(records) logger.info(f" N={N} (with GT beliefs_per_frame + {fi_field})") belief_dim = 2560 * len(BELIEF_LAYERS) out_belief = torch.zeros(N, 8, belief_dim, dtype=torch.float16) out_policy = torch.zeros(N, 8, 2560, dtype=torch.float16) out_valid = torch.zeros(N, 8, dtype=torch.bool) out_actions = torch.zeros(N, 8, dtype=torch.long) out_danger = torch.zeros(N, 8, dtype=torch.float32) out_tta = torch.zeros(N, 8, dtype=torch.float32) out_tick_action = torch.zeros(N, dtype=torch.long) out_tick_tta = torch.full((N,), -1.0) # v4 additions out_prev_action = torch.full((N,), 3, dtype=torch.long) out_oracle_window = torch.zeros(N, dtype=torch.long) out_boundary = torch.zeros(N, dtype=torch.bool) out_category, out_source, out_video_id, out_ids = [], [], [], [] action_map = {"SILENT": 0, "OBSERVE": 1, "ALERT": 2} failed = 0 for i, r in enumerate(tqdm(records, desc="gtb_cache", ncols=80)): try: frames = sample_frames(Path(r["video_path"]), frame_indices=r[fi_field], resize_short=336) except Exception: failed += 1; continue bf, pf, v = extract_one(model, proc, frames, r["beliefs_per_frame"], device) out_belief[i] = bf out_policy[i] = pf out_valid[i] = v actions_pf = r.get("actions_per_frame", ["SILENT"]*8) out_actions[i] = torch.tensor( [action_map.get(a, 0) for a in actions_pf], dtype=torch.long) out_danger[i] = torch.tensor(r.get("danger_per_frame", [0.0]*8)) out_tta[i] = torch.tensor(r.get("tta_per_frame", [-1.0]*8)) out_tick_action[i] = action_map.get(r.get("tick_action", "SILENT"), 0) out_tick_tta[i] = float(r.get("tick_tta_raw", -1.0)) # v4 fields (read if present, else default) out_prev_action[i] = int(r.get("prev_action", 3)) out_oracle_window[i] = int(r.get("oracle_window", 1)) out_boundary[i] = bool(r.get("boundary", False)) out_category.append(r.get("category", "")) out_source.append(r.get("source", "")) out_video_id.append(r.get("video_id", "")) out_ids.append(r.get("id", r.get("video_id", ""))) out_path = args.out_dir / f"{args.tag}__{args.split}.pt" cache = { "ids": out_ids, "belief_content": out_belief, "policy_position": out_policy, "valid_frames": out_valid, "actions_pf": out_actions, "danger_pf": out_danger, "tta_pf": out_tta, "tick_action": out_tick_action, "tick_tta_raw": out_tick_tta, "prev_action": out_prev_action, "oracle_window": out_oracle_window, "boundary": out_boundary, "window": args.window, "category": out_category, "source": out_source, "video_id": out_video_id, "schema": "vlalert_x_v4_gt_belief_fill", "belief_layers": list(BELIEF_LAYERS), "policy_layer": POLICY_LAYER, "ckpt": str(args.ckpt), } torch.save(cache, out_path) logger.info(f"[save] {out_path} failed={failed}") if __name__ == "__main__": main()