"""VLAlert-X belief cache extractor — multi-layer + action-pool, per-frame. Reads a cot_belief_dataset-format JSONL manifest (e.g. data/cot_corpus_v2/vlalert_x_sft.jsonl), forwards each clip through the SFT'd Qwen3-VL-4B + LoRA, and saves per-frame belief vectors at the action-token positions, with the last `n_layers` transformer layers concatenated. Output schema (mirrors `data/belief_cache_perframe_qwen3vl4b/*.pt`): { "beliefs_frame": [N, 8, n_layers*D] fp16 (D=2560 → 10240 if L=4) "valid_frames": [N, 8] bool "ids": list[str] (clip_id per row) "category": list[str] (ego_positive/safe_neg) "source": list[str] (nexar/dada/...) "action_per_frame": list[list[str]] (oracle, from manifest) "tta_raw": [N] float (clip-level TTA) "schema": "vlalert_x_belief_v1" "n_layers": int "pool_mode": str } The action-pool mode finds the per-frame action token positions in the assistant string and reads the hidden state at each. Falls back to BELIEF-open positions if action_pool returns wrong number of tokens. Usage (single pass, single manifest): python tools/make_belief_cache_x.py \ --ckpt checkpoints/sft_x/best \ --manifest data/cot_corpus_v2/vlalert_x_sft.jsonl \ --out data/belief_cache_x/sft_x__action.pt \ --n_layers 4 --pool_mode action Designed to be called by tools/extract_3window_cache.py, once per {split, window} combination. """ from __future__ import annotations # Apply Conv3d→Linear patch BEFORE any model load import sys; sys.path.insert(0, ".") from tools import run_train_cot_belief_fast # noqa: F401 import argparse import json import logging import time from pathlib import Path from typing import Dict, List, Optional import torch from tqdm import tqdm ROOT = Path(__file__).resolve().parents[1] logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("make_belief_cache_x") def extract_per_frame_beliefs( ckpt_dir: Path, base_model: Path, manifest_path: Path, out_path: Path, n_frames: int = 8, n_layers: int = 4, pool_mode: str = "action", random_span_seed: int = 0, random_span_len: int = 25, limit: int = 0, ): """Extract per-frame belief cache for VLAlert-X.""" if out_path.exists(): logger.info(f"[skip] {out_path} exists; reuse") return from transformers import AutoProcessor, Qwen3VLForConditionalGeneration from peft import PeftModel from training.VLA.cot_belief_dataset import ( ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE, ACTION_ALERT, ACTION_OBSERVE, ACTION_SILENT, build_chat, format_assistant, _resolve_actions, ) from training.VLA.frame_utils import sample_frames logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}") logger.info(f" n_layers={n_layers} pool_mode={pool_mode}") processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True) processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL}) model = Qwen3VLForConditionalGeneration.from_pretrained( base_model, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True) model.resize_token_embeddings(len(processor.tokenizer)) if (ckpt_dir / "adapter_config.json").exists(): model = PeftModel.from_pretrained(model, ckpt_dir) model.eval() tok = processor.tokenizer belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN) belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE) action_ids = {tok.convert_tokens_to_ids(t) for t in (ACTION_SILENT, ACTION_OBSERVE, ACTION_ALERT)} # ── load manifest (allow stub-CoT records for val/policy_labels) ── def _ensure_record(r: Dict) -> Optional[Dict]: """If record lacks cot/belief, synthesise a stub so the assistant string still has 8 BELIEF blocks. Action labels are derived from whatever the manifest provides (or all-SILENT).""" if not r.get("video_path"): return None if r.get("cot") and r.get("belief", {}).get("frame_indices"): return r # stub mode action_lbl = r.get("action_label", 0) clip_action = {0: "SILENT", 1: "OBSERVE", 2: "ALERT"}.get(int(action_lbl), "SILENT") actions_pf = r.get("actions_per_frame") or [clip_action] * n_frames if len(actions_pf) != n_frames: actions_pf = (actions_pf + [clip_action] * n_frames)[:n_frames] frame_idx = (r.get("frame_indices") or (r.get("belief") or {}).get("frame_indices")) if not frame_idx: return None return { "id": r.get("id") or r.get("video_id", ""), "video_path": r["video_path"], "category": r.get("category", ""), "source": r.get("source", ""), "tta_raw": r.get("tta_raw", -1.0), "cot": { "scene": "(n/a)", "critical_objects": [], "threat_analysis": "(n/a)", }, "belief": { "action": clip_action, "actions_per_frame": actions_pf, "frame_indices": frame_idx, }, } records: List[Dict] = [] n_stub = 0 with open(manifest_path) as f: for ln in f: ln = ln.strip() if not ln: continue try: r = json.loads(ln) rec = _ensure_record(r) if rec is not None: if not r.get("cot"): n_stub += 1 records.append(rec) except Exception: pass if limit > 0: records = records[:limit] logger.info(f"[load] {manifest_path} n={len(records)} stub_cot={n_stub}") # ── allocate output tensors ───────────────────────────────────── # We don't know D until first forward; allocate after first sample out_beliefs: Optional[torch.Tensor] = None out_valid = torch.zeros(len(records), n_frames, dtype=torch.bool) ids_list, cat_list, src_list, actions_list = [], [], [], [] tta_list = torch.zeros(len(records), dtype=torch.float32) n_failed = 0 n_pool_fallback = 0 t0 = time.time() for i, rec in enumerate(tqdm(records, ncols=80, desc="cache_x")): try: video_path = rec["video_path"] frame_idx = rec["belief"].get("frame_indices") frames = sample_frames(video_path, n_frames=n_frames, resize_short=336, frame_indices=frame_idx) actions = _resolve_actions(rec["belief"], n_frames) assistant_text = format_assistant(rec["cot"], actions) full_msgs = build_chat(frames, assistant_text=assistant_text) full_text = processor.apply_chat_template( full_msgs, tokenize=False, add_generation_prompt=False) inputs = processor(text=[full_text], images=[frames], return_tensors="pt", padding=False, truncation=True, max_length=4096) inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): out = model(**inputs, output_hidden_states=True, return_dict=True) # multi-layer concat: [T, n_layers * D] if n_layers == 1: hs = out.hidden_states[-1][0] else: hs_list = [out.hidden_states[k][0] for k in range(-n_layers, 0)] hs = torch.cat(hs_list, dim=-1) ids_t = inputs["input_ids"][0] T_total, D_full = hs.shape # find per-frame pool positions if pool_mode == "action": # one action token per frame (in causal order) pos_list = [int(p) for p, t in enumerate(ids_t.tolist()) if t in action_ids] elif pool_mode == "open": pos_list = (ids_t == belief_open_id).nonzero( as_tuple=False).flatten().tolist() elif pool_mode == "range": opens = (ids_t == belief_open_id).nonzero( as_tuple=False).flatten().tolist() closes = (ids_t == belief_close_id).nonzero( as_tuple=False).flatten().tolist() # group into per-frame mean ranges pos_list = [] # not used; we pool per-range below elif pool_mode == "token_mean": # Format-agnostic baseline: mean over ALL valid (non-image, non-pad) # tokens of the assistant response. Replicated across n_frames so # the downstream tensor shape matches V0. pos_list = [] elif pool_mode == "random_span": # Control baseline: same span length as BELIEF (default 25 tokens) # but at random positions in the response. Same per-frame structure # as V0 (n_frames independent random spans). pos_list = [] else: raise ValueError(f"pool_mode={pool_mode}") # Lazy-allocate output tensor if out_beliefs is None: out_beliefs = torch.zeros(len(records), n_frames, D_full, dtype=torch.float16) # ── case 1: per-position single-vector pool ── if pool_mode in ("action", "open") and len(pos_list) >= 1: # take first n_frames positions use_pos = pos_list[:n_frames] if len(use_pos) < n_frames: n_pool_fallback += 1 for f, p in enumerate(use_pos): out_beliefs[i, f] = hs[p].float().to(torch.float16).cpu() out_valid[i, f] = True # ── case 2: range pool — mean over each <|BELIEF|>... ── elif pool_mode == "range" and len(opens) >= 1 and len(closes) >= 1: pairs = list(zip(opens[:n_frames], closes[:n_frames])) for f, (o, c) in enumerate(pairs): if c > o: out_beliefs[i, f] = hs[o:c+1].mean(dim=0).float().to( torch.float16).cpu() out_valid[i, f] = True # ── case 3 (V1): token-mean pool — mean over ALL response tokens ── elif pool_mode == "token_mean": # Find the assistant-response span: from first BELIEF-open to last # token. This excludes the user prompt and image tokens. opens_local = (ids_t == belief_open_id).nonzero( as_tuple=False).flatten().tolist() resp_start = opens_local[0] if opens_local else max(0, T_total - 200) pooled = hs[resp_start:].mean(dim=0) for f in range(n_frames): out_beliefs[i, f] = pooled.float().to(torch.float16).cpu() out_valid[i, f] = True # ── case 4 (V4): random-span pool — same-length spans at random positions ── elif pool_mode == "random_span": # Use deterministic per-sample RNG so the cache is reproducible. import random as _rnd rng = _rnd.Random(int(random_span_seed) * 100003 + i) # Estimate span length from actual BELIEF spans on this sample opens_local = (ids_t == belief_open_id).nonzero( as_tuple=False).flatten().tolist() closes_local = (ids_t == belief_close_id).nonzero( as_tuple=False).flatten().tolist() if opens_local and closes_local and len(opens_local) >= 1: span_lens = [c - o for o, c in zip(opens_local, closes_local) if c > o] L_span = max(3, int(round(sum(span_lens) / max(len(span_lens), 1)))) else: L_span = int(random_span_len) resp_start = opens_local[0] if opens_local else max(0, T_total - 200) resp_end = T_total if resp_end - resp_start <= L_span: # response too short — just mean the available range pooled = hs[resp_start:resp_end].mean(dim=0) for f in range(n_frames): out_beliefs[i, f] = pooled.float().to(torch.float16).cpu() out_valid[i, f] = True else: for f in range(n_frames): start = rng.randint(resp_start, resp_end - L_span) out_beliefs[i, f] = hs[start:start + L_span].mean(dim=0).float().to( torch.float16).cpu() out_valid[i, f] = True else: # fallback: mean-pool last 64 tokens, replicate across frames pooled = hs[-64:].mean(dim=0) for f in range(n_frames): out_beliefs[i, f] = pooled.float().to(torch.float16).cpu() # leave valid_frames = False n_pool_fallback += 1 ids_list.append(rec.get("id", str(i))) cat_list.append(rec.get("category", "")) src_list.append(rec.get("source", "")) actions_list.append(actions) tta_list[i] = float(rec.get("tta_raw", -1.0)) except Exception as e: n_failed += 1 logger.warning(f"[skip] {rec.get('id')}: {e}") ids_list.append(rec.get("id", str(i))) cat_list.append(rec.get("category", "")) src_list.append(rec.get("source", "")) actions_list.append([]) continue if out_beliefs is None: raise RuntimeError("no successful extractions") out_dict = { "beliefs_frame": out_beliefs, "valid_frames": out_valid, "ids": ids_list, "category": cat_list, "source": src_list, "action_per_frame": actions_list, "tta_raw": tta_list, "schema": "vlalert_x_belief_v1", "n_layers": n_layers, "pool_mode": pool_mode, "belief_dim": out_beliefs.shape[-1], "ckpt": str(ckpt_dir), } out_path.parent.mkdir(parents=True, exist_ok=True) torch.save(out_dict, out_path) dt = time.time() - t0 logger.info(f"[save] {out_path}") logger.info(f" shape={out_beliefs.shape} failed={n_failed} " f"fallback={n_pool_fallback} elapsed={dt:.0f}s") def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", type=Path, required=True) ap.add_argument("--base_model", type=Path, default=ROOT / "models/Qwen3-VL-4B-Instruct") ap.add_argument("--manifest", type=Path, required=True) ap.add_argument("--out", type=Path, required=True) ap.add_argument("--n_frames", type=int, default=8) ap.add_argument("--n_layers", type=int, default=4) ap.add_argument("--random_span_seed", type=int, default=0, help="RNG seed for --pool_mode random_span (deterministic per-sample)") ap.add_argument("--random_span_len", type=int, default=25, help="fallback span length for --pool_mode random_span when " "no BELIEF tags found on a sample") ap.add_argument("--pool_mode", choices=["open", "range", "action", "token_mean", "random_span"], default="action") ap.add_argument("--limit", type=int, default=0, help="If >0, truncate manifest to first N rows (smoke test)") args = ap.parse_args() extract_per_frame_beliefs( args.ckpt, args.base_model, args.manifest, args.out, n_frames=args.n_frames, n_layers=args.n_layers, pool_mode=args.pool_mode, random_span_seed=args.random_span_seed, random_span_len=args.random_span_len, limit=args.limit, ) if __name__ == "__main__": main()