| """VLAlert-X belief cache extractor — multi-layer + action-pool, per-frame. |
| |
| Reads a cot_belief_dataset-format JSONL manifest (e.g. |
| data/cot_corpus_v2/vlalert_x_sft.jsonl), forwards each clip through the |
| SFT'd Qwen3-VL-4B + LoRA, and saves per-frame belief vectors at the |
| action-token positions, with the last `n_layers` transformer layers |
| concatenated. |
| |
| Output schema (mirrors `data/belief_cache_perframe_qwen3vl4b/*.pt`): |
| |
| { |
| "beliefs_frame": [N, 8, n_layers*D] fp16 (D=2560 → 10240 if L=4) |
| "valid_frames": [N, 8] bool |
| "ids": list[str] (clip_id per row) |
| "category": list[str] (ego_positive/safe_neg) |
| "source": list[str] (nexar/dada/...) |
| "action_per_frame": list[list[str]] (oracle, from manifest) |
| "tta_raw": [N] float (clip-level TTA) |
| "schema": "vlalert_x_belief_v1" |
| "n_layers": int |
| "pool_mode": str |
| } |
| |
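| The action-pool mode finds the per-frame action token positions in the |
| assistant string and reads the hidden state at each. If no such position |
| is found, the mean of the last 64 token states is stored for every frame |
| and the frames stay marked invalid; if fewer than n_frames positions are |
| found, only the frames that were found are filled and marked valid. |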
| |
| Usage (single pass, single manifest): |
| python tools/make_belief_cache_x.py \ |
| --ckpt checkpoints/sft_x/best \ |
| --manifest data/cot_corpus_v2/vlalert_x_sft.jsonl \ |
| --out data/belief_cache_x/sft_x__action.pt \ |
| --n_layers 4 --pool_mode action |
| |
| Designed to be called by tools/extract_3window_cache.py, once per |
| {split, window} combination. |
| """ |
| from __future__ import annotations |
|
|
| |
| import sys; sys.path.insert(0, ".") |
| from tools import run_train_cot_belief_fast |
|
|
| import argparse |
| import json |
| import logging |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
| import torch |
| from tqdm import tqdm |
|
|
# Two levels up from this file: the directory that contains the folder this
# script lives in. Used as the base for the default --base_model path.
ROOT = Path(__file__).resolve().parents[1]

# Process-wide logging setup: INFO and above, timestamped, one line per record.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)

# Module logger; the name is fixed rather than derived from __name__.
logger = logging.getLogger("make_belief_cache_x")
|
|
|
|
# Pooling strategies accepted by extract_per_frame_beliefs (same set as the
# --pool_mode choices offered by main()).
_POOL_MODES = ("open", "range", "action", "token_mean", "random_span")


def _stub_record(r: Dict, n_frames: int) -> Optional[Dict]:
    """Return a manifest row in the shape the extractor needs, or None.

    Rows that already carry both ``cot`` and ``belief.frame_indices`` are
    returned unchanged (same object). Otherwise a stub row is synthesised
    with placeholder CoT text so the assistant string still gets one BELIEF
    block per frame; per-frame actions come from ``actions_per_frame`` when
    present, else from the clip-level ``action_label`` (unknown -> SILENT).

    Returns None when the row has no ``video_path`` or no frame indices.
    May raise TypeError/ValueError/AttributeError on malformed field types;
    the caller is expected to treat that as a bad row.
    """
    if not r.get("video_path"):
        return None
    # `belief` may be missing OR explicitly null in the manifest.
    belief = r.get("belief") or {}
    if r.get("cot") and belief.get("frame_indices"):
        return r

    action_lbl = r.get("action_label", 0)
    clip_action = {0: "SILENT", 1: "OBSERVE", 2: "ALERT"}.get(
        int(action_lbl), "SILENT")
    actions_pf = r.get("actions_per_frame") or [clip_action] * n_frames
    if len(actions_pf) != n_frames:
        # Pad with the clip-level action, then cut to exactly n_frames.
        actions_pf = (list(actions_pf) + [clip_action] * n_frames)[:n_frames]
    frame_idx = r.get("frame_indices") or belief.get("frame_indices")
    if not frame_idx:
        return None
    return {
        "id": r.get("id") or r.get("video_id", ""),
        "video_path": r["video_path"],
        "category": r.get("category", ""),
        "source": r.get("source", ""),
        "tta_raw": r.get("tta_raw", -1.0),
        "cot": {
            "scene": "(n/a)",
            "critical_objects": [],
            "threat_analysis": "(n/a)",
        },
        "belief": {
            "action": clip_action,
            "actions_per_frame": actions_pf,
            "frame_indices": frame_idx,
        },
    }


def _load_manifest(manifest_path: Path, n_frames: int,
                   limit: int) -> tuple[List[Dict], int]:
    """Read a JSONL manifest into extractor-ready records.

    Blank lines are ignored. Lines that are not valid JSON, or whose fields
    have unusable types, are skipped with a warning that names the line
    number (they are not fatal). Reading stops once ``limit`` records have
    been collected when ``limit > 0``.

    Returns ``(records, n_stub)`` where ``n_stub`` counts the kept records
    that had to be replaced by a synthesised stub.
    """
    records: List[Dict] = []
    n_stub = 0
    with open(manifest_path, encoding="utf-8") as fh:
        for line_no, raw in enumerate(fh, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
                rec = _stub_record(row, n_frames)
            except (ValueError, TypeError, AttributeError, KeyError) as err:
                # json.JSONDecodeError is a ValueError; the rest cover rows
                # that parse but are not dict-shaped as expected.
                logger.warning(
                    f"[manifest] {manifest_path}:{line_no} skipped: {err}")
                continue
            if rec is None:
                continue
            if rec is not row:
                n_stub += 1
            records.append(rec)
            if limit > 0 and len(records) >= limit:
                break
    return records, n_stub


def _token_positions(ids_t: torch.Tensor, token_id: int) -> List[int]:
    """Return the sequence indices at which ``ids_t`` equals ``token_id``."""
    return (ids_t == token_id).nonzero(as_tuple=False).flatten().tolist()


def _to_cache(vec: torch.Tensor) -> torch.Tensor:
    """Convert a pooled hidden-state vector to the fp16 CPU storage format."""
    return vec.float().to(torch.float16).cpu()


def extract_per_frame_beliefs(
    ckpt_dir: Path,
    base_model: Path,
    manifest_path: Path,
    out_path: Path,
    n_frames: int = 8,
    n_layers: int = 4,
    pool_mode: str = "action",
    random_span_seed: int = 0,
    random_span_len: int = 25,
    limit: int = 0,
) -> None:
    """Extract a per-frame belief cache for VLAlert-X and save it to disk.

    Each manifest record is run through the model once (teacher-forced on
    the formatted assistant text); the hidden states of the last
    ``n_layers`` layers are concatenated along the feature axis and pooled
    into ``n_frames`` vectors according to ``pool_mode``:

    * ``action`` / ``open``: hidden state at each action token / BELIEF-open
      token (first ``n_frames`` hits). Fewer hits than ``n_frames`` fills
      only those frames and counts as a fallback.
    * ``range``: mean over each BELIEF-open..BELIEF-close span (inclusive).
    * ``token_mean``: one mean from the first BELIEF-open token (or the last
      200 tokens if none) to the end, copied to every frame.
    * ``random_span``: mean over a random window per frame, seeded from
      ``random_span_seed`` and the record index; window length is the mean
      BELIEF span length (min 3) or ``random_span_len`` if no tags exist.

    When the chosen mode finds no anchor tokens at all, the mean of the last
    64 token states is stored for every frame but the frames are left
    invalid. A record that raises is logged, counted as failed, and its row
    is zeroed and marked invalid; metadata lists always get exactly one
    entry per record so they stay aligned with the tensors.

    Args:
        ckpt_dir: Checkpoint directory; a LoRA adapter is loaded from it
            only if it contains ``adapter_config.json``.
        base_model: Path of the base model / processor.
        manifest_path: JSONL manifest to read.
        out_path: Destination ``.pt`` file. If it already exists the call
            is a no-op.
        n_frames: Frames (belief vectors) per clip.
        n_layers: Number of final hidden layers to concatenate (>= 1).
        pool_mode: One of ``_POOL_MODES``.
        random_span_seed: Seed component for ``random_span``.
        random_span_len: Fallback window length for ``random_span``.
        limit: If > 0, use only the first ``limit`` usable records.

    Raises:
        ValueError: ``pool_mode`` is unknown or ``n_layers < 1`` (checked
            before any model is loaded).
        RuntimeError: no record could be extracted.
    """
    # Validate configuration first: inside the per-record loop these errors
    # would be swallowed and repeated for every clip.
    if pool_mode not in _POOL_MODES:
        raise ValueError(f"pool_mode={pool_mode}")
    if n_layers < 1:
        raise ValueError(f"n_layers={n_layers}")

    if out_path.exists():
        logger.info(f"[skip] {out_path} exists; reuse")
        return

    import random

    # Heavy / project-local imports are deferred so the skip path above and
    # argument validation stay cheap.
    from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
    from peft import PeftModel
    from training.VLA.cot_belief_dataset import (
        ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE,
        ACTION_ALERT, ACTION_OBSERVE, ACTION_SILENT,
        build_chat, format_assistant, _resolve_actions,
    )
    from training.VLA.frame_utils import sample_frames

    logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}")
    logger.info(f" n_layers={n_layers} pool_mode={pool_mode}")

    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    # Register the project's special tokens and grow the embedding table to
    # match before (optionally) attaching the adapter.
    processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL})
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True)
    model.resize_token_embeddings(len(processor.tokenizer))
    if (ckpt_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(model, ckpt_dir)
    model.eval()

    tok = processor.tokenizer
    belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN)
    belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE)
    action_ids = {tok.convert_tokens_to_ids(t)
                  for t in (ACTION_SILENT, ACTION_OBSERVE, ACTION_ALERT)}

    records, n_stub = _load_manifest(manifest_path, n_frames, limit)
    logger.info(f"[load] {manifest_path} n={len(records)} stub_cot={n_stub}")

    # The belief tensor is allocated lazily: its feature width is only known
    # after the first successful forward pass.
    out_beliefs: Optional[torch.Tensor] = None
    out_valid = torch.zeros(len(records), n_frames, dtype=torch.bool)
    ids_list, cat_list, src_list, actions_list = [], [], [], []
    tta_list = torch.zeros(len(records), dtype=torch.float32)

    n_failed = 0
    n_pool_fallback = 0
    t0 = time.time()
    for i, rec in enumerate(tqdm(records, ncols=80, desc="cache_x")):
        row_actions: List[str] = []
        try:
            # Parse TTA before doing any work; a null value means "unknown".
            tta_raw = rec.get("tta_raw", -1.0)
            tta_value = -1.0 if tta_raw is None else float(tta_raw)

            # NOTE(review): sample_frames / _resolve_actions / format_assistant
            # / build_chat are project helpers; their exact contracts are not
            # visible here.
            frames = sample_frames(rec["video_path"], n_frames=n_frames,
                                   resize_short=336,
                                   frame_indices=rec["belief"].get("frame_indices"))
            actions = _resolve_actions(rec["belief"], n_frames)
            assistant_text = format_assistant(rec["cot"], actions)
            full_msgs = build_chat(frames, assistant_text=assistant_text)
            full_text = processor.apply_chat_template(
                full_msgs, tokenize=False, add_generation_prompt=False)
            inputs = processor(text=[full_text], images=[frames],
                               return_tensors="pt", padding=False,
                               truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                out = model(**inputs, output_hidden_states=True,
                            return_dict=True)

            # [T, D] for one layer, or [T, n_layers*D] with the last
            # n_layers layers concatenated feature-wise (batch item 0).
            if n_layers == 1:
                hs = out.hidden_states[-1][0]
            else:
                hs = torch.cat([out.hidden_states[k][0]
                                for k in range(-n_layers, 0)], dim=-1)
            ids_t = inputs["input_ids"][0]
            seq_len, feat_dim = hs.shape

            if out_beliefs is None:
                out_beliefs = torch.zeros(len(records), n_frames, feat_dim,
                                          dtype=torch.float16)

            opens = _token_positions(ids_t, belief_open_id)
            closes = _token_positions(ids_t, belief_close_id)
            if pool_mode == "action":
                anchors = [p for p, t in enumerate(ids_t.tolist())
                           if t in action_ids]
            elif pool_mode == "open":
                anchors = opens
            else:
                anchors = []

            if pool_mode in ("action", "open") and anchors:
                picked = anchors[:n_frames]
                if len(picked) < n_frames:
                    # Partial hit: missing frames stay zero / invalid.
                    n_pool_fallback += 1
                for f, p in enumerate(picked):
                    out_beliefs[i, f] = _to_cache(hs[p])
                    out_valid[i, f] = True

            elif pool_mode == "range" and opens and closes:
                for f, (o, c) in enumerate(zip(opens[:n_frames],
                                               closes[:n_frames])):
                    if c > o:
                        out_beliefs[i, f] = _to_cache(hs[o:c + 1].mean(dim=0))
                        out_valid[i, f] = True

            elif pool_mode == "token_mean":
                resp_start = opens[0] if opens else max(0, seq_len - 200)
                # Same vector for every frame (broadcast over the frame axis).
                out_beliefs[i] = _to_cache(hs[resp_start:].mean(dim=0))
                out_valid[i] = True

            elif pool_mode == "random_span":
                # Deterministic per (seed, record index).
                rng = random.Random(int(random_span_seed) * 100003 + i)
                if opens and closes:
                    span_lens = [c - o for o, c in zip(opens, closes) if c > o]
                    span_len = max(3, int(round(
                        sum(span_lens) / max(len(span_lens), 1))))
                else:
                    # Clamp so a non-positive CLI value cannot yield an
                    # empty window (mean of nothing is NaN).
                    span_len = max(1, int(random_span_len))
                resp_start = opens[0] if opens else max(0, seq_len - 200)
                if seq_len - resp_start <= span_len:
                    # Response shorter than one window: pool all of it.
                    out_beliefs[i] = _to_cache(hs[resp_start:seq_len].mean(dim=0))
                    out_valid[i] = True
                else:
                    for f in range(n_frames):
                        start = rng.randint(resp_start, seq_len - span_len)
                        out_beliefs[i, f] = _to_cache(
                            hs[start:start + span_len].mean(dim=0))
                        out_valid[i, f] = True

            else:
                # No anchor tokens for this mode: store a tail mean but leave
                # the frames invalid so consumers can mask them.
                out_beliefs[i] = _to_cache(hs[-64:].mean(dim=0))
                n_pool_fallback += 1

            row_actions = actions
            tta_list[i] = tta_value
        except Exception as e:
            # Best-effort per record: log, then make sure nothing half-written
            # survives for this row.
            n_failed += 1
            logger.warning(f"[skip] {rec.get('id')}: {e}")
            out_valid[i] = False
            if out_beliefs is not None:
                out_beliefs[i].zero_()
            row_actions = []

        # Exactly one metadata entry per record, success or failure, so the
        # lists stay index-aligned with the tensors.
        ids_list.append(rec.get("id", str(i)))
        cat_list.append(rec.get("category", ""))
        src_list.append(rec.get("source", ""))
        actions_list.append(row_actions)

    if out_beliefs is None:
        raise RuntimeError("no successful extractions")

    out_dict = {
        "beliefs_frame": out_beliefs,
        "valid_frames": out_valid,
        "ids": ids_list,
        "category": cat_list,
        "source": src_list,
        "action_per_frame": actions_list,
        "tta_raw": tta_list,
        "schema": "vlalert_x_belief_v1",
        "n_layers": n_layers,
        "pool_mode": pool_mode,
        "belief_dim": out_beliefs.shape[-1],
        "ckpt": str(ckpt_dir),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a sibling temp file and rename: out_path's existence is the
    # "already done" marker, so it must never hold a truncated file.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    torch.save(out_dict, tmp_path)
    tmp_path.replace(out_path)
    dt = time.time() - t0
    logger.info(f"[save] {out_path}")
    logger.info(f" shape={out_beliefs.shape} failed={n_failed} "
                f"fallback={n_pool_fallback} elapsed={dt:.0f}s")
|
|
|
|
def main():
    """Command-line entry point: parse options and run the extractor once.

    ``--ckpt``, ``--manifest`` and ``--out`` are required; every other
    option falls back to the default shown below. The parsed values are
    forwarded unchanged to ``extract_per_frame_beliefs``.
    """
    parser = argparse.ArgumentParser()

    # Paths.
    parser.add_argument("--ckpt", type=Path, required=True)
    parser.add_argument(
        "--base_model",
        type=Path,
        default=ROOT / "models/Qwen3-VL-4B-Instruct",
    )
    parser.add_argument("--manifest", type=Path, required=True)
    parser.add_argument("--out", type=Path, required=True)

    # Extraction geometry.
    parser.add_argument("--n_frames", type=int, default=8)
    parser.add_argument("--n_layers", type=int, default=4)

    # Pooling controls.
    parser.add_argument(
        "--random_span_seed",
        type=int,
        default=0,
        help="RNG seed for --pool_mode random_span (deterministic per-sample)",
    )
    parser.add_argument(
        "--random_span_len",
        type=int,
        default=25,
        help=(
            "fallback span length for --pool_mode random_span "
            "when no BELIEF tags found on a sample"
        ),
    )
    parser.add_argument(
        "--pool_mode",
        choices=["open", "range", "action", "token_mean", "random_span"],
        default="action",
    )

    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="If >0, truncate manifest to first N rows (smoke test)",
    )

    cli = parser.parse_args()

    extract_per_frame_beliefs(
        cli.ckpt,
        cli.base_model,
        cli.manifest,
        cli.out,
        n_frames=cli.n_frames,
        n_layers=cli.n_layers,
        pool_mode=cli.pool_mode,
        random_span_seed=cli.random_span_seed,
        random_span_len=cli.random_span_len,
        limit=cli.limit,
    )


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|