VLAlert / tools /make_belief_cache_x.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
16.7 kB
"""VLAlert-X belief cache extractor β€” multi-layer + action-pool, per-frame.
Reads a cot_belief_dataset-format JSONL manifest (e.g.
data/cot_corpus_v2/vlalert_x_sft.jsonl), forwards each clip through the
SFT'd Qwen3-VL-4B + LoRA, and saves per-frame belief vectors at the
action-token positions, with the last `n_layers` transformer layers
concatenated.
Output schema (mirrors `data/belief_cache_perframe_qwen3vl4b/*.pt`):
{
"beliefs_frame": [N, n_frames, n_layers*D] fp16 (D=2560 → 10240 if n_layers=4)
"valid_frames": [N, n_frames] bool
"ids": list[str] (clip_id per row)
"category": list[str] (ego_positive/safe_neg)
"source": list[str] (nexar/dada/...)
"action_per_frame": list[list[str]] (oracle, from manifest)
"tta_raw": [N] float (clip-level TTA)
"schema": "vlalert_x_belief_v1"
"n_layers": int
"pool_mode": str
"belief_dim": int (= n_layers*D)
"ckpt": str (checkpoint directory)
}
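The action-pool mode finds the per-frame action-token positions in the
tokenized sequence and reads the hidden state at the first `n_frames` of
them. If fewer than `n_frames` are found, only those frames are filled (the
rest stay invalid) and the sample counts as a pool fallback; if none are
found, the mean of the last 64 tokens is stored for every frame and
valid_frames stays False.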
Usage (single pass, single manifest):
python tools/make_belief_cache_x.py \
--ckpt checkpoints/sft_x/best \
--manifest data/cot_corpus_v2/vlalert_x_sft.jsonl \
--out data/belief_cache_x/sft_x__action.pt \
--n_layers 4 --pool_mode action
Designed to be called by tools/extract_3window_cache.py, once per
{split, window} combination.
"""
from __future__ import annotations
# Apply Conv3d→Linear patch BEFORE any model load
import sys; sys.path.insert(0, ".")
from tools import run_train_cot_belief_fast # noqa: F401
import argparse
import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional
import torch
from tqdm import tqdm
# Repository root: two levels above this file (the directory holding tools/).
ROOT = Path(__file__).resolve().parent.parent
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Fixed logger name (not __name__), so records carry the same name whether
# this file runs as a script or is imported.
logger = logging.getLogger("make_belief_cache_x")
# Pooling strategies understood by extract_per_frame_beliefs; same set (and
# order) as the --pool_mode CLI choices.
_POOL_MODES = ("open", "range", "action", "token_mean", "random_span")


def _stub_record(r: Dict, n_frames: int) -> Optional[Dict]:
    """Normalise one manifest row into the record shape the extractor reads.

    Rows that already carry a non-empty ``cot`` and ``belief.frame_indices``
    are returned unchanged (the very same object). Otherwise a stub record is
    synthesised with placeholder CoT text so the assistant string can still
    be rendered. Its per-frame actions come from ``actions_per_frame``
    (padded/truncated to ``n_frames``). Failing that, they come from the
    clip-level ``action_label`` (0/1/2 -> SILENT/OBSERVE/ALERT, anything
    else SILENT).

    Returns None when the row has no ``video_path`` or no frame indices.
    Raises TypeError/ValueError when ``action_label`` is not int-convertible,
    and AttributeError when ``belief`` is present but not a dict.
    """
    if not r.get("video_path"):
        return None
    # `or {}` also covers an explicit JSON null for "belief".
    belief = r.get("belief") or {}
    if r.get("cot") and belief.get("frame_indices"):
        return r
    # stub mode
    clip_action = {0: "SILENT", 1: "OBSERVE", 2: "ALERT"}.get(
        int(r.get("action_label", 0)), "SILENT")
    actions_pf = r.get("actions_per_frame") or [clip_action] * n_frames
    if len(actions_pf) != n_frames:
        actions_pf = (actions_pf + [clip_action] * n_frames)[:n_frames]
    frame_idx = r.get("frame_indices") or belief.get("frame_indices")
    if not frame_idx:
        return None
    return {
        "id": r.get("id") or r.get("video_id", ""),
        "video_path": r["video_path"],
        "category": r.get("category", ""),
        "source": r.get("source", ""),
        "tta_raw": r.get("tta_raw", -1.0),
        "cot": {
            "scene": "(n/a)",
            "critical_objects": [],
            "threat_analysis": "(n/a)",
        },
        "belief": {
            "action": clip_action,
            "actions_per_frame": actions_pf,
            "frame_indices": frame_idx,
        },
    }


def _load_manifest(manifest_path: Path, n_frames: int,
                   limit: int) -> tuple[List[Dict], int, int]:
    """Read a JSONL manifest into extractor records.

    Blank lines are ignored. Rows that `_stub_record` rejects (no video or
    no frame indices) are dropped. Rows that fail to parse or normalise are
    dropped too, but are logged and counted rather than swallowed silently.

    Returns:
        ``(records, n_stub, n_bad)``:
        - ``records``: the kept records, truncated to ``limit`` when
          ``limit > 0``.
        - ``n_stub``: how many kept records (counted before truncation) are
          synthesised stubs.
        - ``n_bad``: how many rows were malformed.
    """
    records: List[Dict] = []
    n_stub = 0
    n_bad = 0
    with open(manifest_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
                rec = _stub_record(row, n_frames)
            except (ValueError, TypeError, AttributeError, KeyError) as e:
                # JSONDecodeError is a ValueError; non-dict rows raise
                # AttributeError on .get; bad action_label raises
                # TypeError/ValueError.
                n_bad += 1
                logger.warning(f"[load] {manifest_path}:{lineno} bad row: {e}")
                continue
            if rec is None:
                continue
            if rec is not row:
                n_stub += 1
            records.append(rec)
    if limit > 0:
        records = records[:limit]
    return records, n_stub, n_bad


def _tta_value(rec: Dict) -> float:
    """Return the record's clip-level TTA as float; -1.0 if missing/null/non-numeric."""
    try:
        return float(rec.get("tta_raw", -1.0))
    except (TypeError, ValueError):
        return -1.0


def _token_positions(ids_t: torch.Tensor, token_id: int) -> List[int]:
    """Return the ascending indices at which ``token_id`` occurs in 1-D ``ids_t``."""
    return (ids_t == token_id).nonzero(as_tuple=False).flatten().tolist()


def _to_cache(vec: torch.Tensor) -> torch.Tensor:
    """Convert one pooled hidden vector to the cache storage format (fp16, CPU)."""
    return vec.float().to(torch.float16).cpu()


def _forward_hidden(model, inputs: Dict, n_layers: int) -> torch.Tensor:
    """Run one no-grad forward pass and return batch row 0's last
    ``n_layers`` hidden-state layers concatenated on the last axis.

    The model output object (which holds every layer's hidden states) is
    released when this function returns. It is therefore not kept alive
    while the next sample's forward pass runs.
    """
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True, return_dict=True)
    if n_layers == 1:
        return out.hidden_states[-1][0]
    # Explicit negative indexing (not slicing) so an n_layers larger than
    # the number of layers raises instead of silently using fewer.
    return torch.cat([out.hidden_states[k][0] for k in range(-n_layers, 0)],
                     dim=-1)


def _pool_frames(
    hs: torch.Tensor,
    ids_t: torch.Tensor,
    *,
    pool_mode: str,
    n_frames: int,
    belief_open_id: int,
    belief_close_id: int,
    action_ids: set,
    rng_seed: int,
    random_span_len: int,
) -> tuple[List[tuple[int, torch.Tensor]], bool, bool]:
    """Pool one sample's hidden states into per-frame cache vectors.

    Args:
        hs: per-token hidden states, indexed ``hs[token] -> vector``.
        ids_t: matching 1-D token-id tensor (same length as ``hs``).
        pool_mode: one of ``_POOL_MODES``:
            - ``action`` / ``open``: vector at each of the first ``n_frames``
              action-token / BELIEF-open positions. If fewer are found, only
              those frames are filled and the sample counts as a fallback.
            - ``range``: mean over each BELIEF-open..BELIEF-close span
              (inclusive), paired in order. Spans with close <= open are
              skipped.
            - ``token_mean`` (V1): mean from the first BELIEF-open (or the
              last 200 tokens when there is none) to the end, replicated
              across frames.
            - ``random_span`` (V4): ``n_frames`` random windows inside that
              same response range. Window length is the sample's mean
              ``close - open`` (min 3), or ``random_span_len`` when it has
              no valid spans.
        rng_seed: seed for the ``random_span`` RNG (deterministic per sample).
        random_span_len: ``random_span`` window length when no spans are
            usable.

    If nothing can be pooled (action/open with no positions, or range
    without both tag kinds), the mean of the last 64 tokens is used for
    every frame, marked invalid and counted as a fallback.

    Returns:
        ``(writes, valid, fallback)``:
        - ``writes``: a list of ``(frame_index, fp16 CPU vector)`` pairs to
          store.
        - ``valid``: the valid_frames flag for every written frame.
        - ``fallback``: whether this sample counts towards the fallback tally.

    Raises:
        ValueError: unknown ``pool_mode``.
    """
    if pool_mode not in _POOL_MODES:
        raise ValueError(f"pool_mode={pool_mode}")
    T_total = hs.shape[0]

    if pool_mode in ("action", "open"):
        if pool_mode == "action":
            # NOTE(review): scans the whole sequence in causal order; assumes
            # the prompt from build_chat contains no action tokens — confirm.
            positions = [p for p, t in enumerate(ids_t.tolist())
                         if t in action_ids]
        else:
            positions = _token_positions(ids_t, belief_open_id)
        if positions:
            use = positions[:n_frames]
            writes = [(f, _to_cache(hs[p])) for f, p in enumerate(use)]
            return writes, True, len(use) < n_frames

    elif pool_mode == "range":
        opens = _token_positions(ids_t, belief_open_id)
        closes = _token_positions(ids_t, belief_close_id)
        if opens and closes:
            pairs = zip(opens[:n_frames], closes[:n_frames])
            writes = [(f, _to_cache(hs[o:c + 1].mean(dim=0)))
                      for f, (o, c) in enumerate(pairs) if c > o]
            return writes, True, False

    elif pool_mode == "token_mean":
        # Response span = first BELIEF-open .. end, which excludes the user
        # prompt and image tokens.
        opens = _token_positions(ids_t, belief_open_id)
        start = opens[0] if opens else max(0, T_total - 200)
        pooled = _to_cache(hs[start:].mean(dim=0))
        return [(f, pooled) for f in range(n_frames)], True, False

    else:  # random_span
        import random

        opens = _token_positions(ids_t, belief_open_id)
        closes = _token_positions(ids_t, belief_close_id)
        # NOTE(review): close - open is one token shorter than the inclusive
        # spans `range` mode averages over — confirm this is intended.
        span_lens = [c - o for o, c in zip(opens, closes) if c > o]
        if span_lens:
            span_len = max(3, int(round(sum(span_lens) / len(span_lens))))
        else:
            # Guard against a non-positive length (empty windows -> NaN).
            span_len = max(1, int(random_span_len))
        start = opens[0] if opens else max(0, T_total - 200)
        if T_total - start <= span_len:
            # response too short — just mean the available range
            pooled = _to_cache(hs[start:].mean(dim=0))
            return [(f, pooled) for f in range(n_frames)], True, False
        rng = random.Random(rng_seed)
        writes = []
        for f in range(n_frames):
            s = rng.randint(start, T_total - span_len)
            writes.append((f, _to_cache(hs[s:s + span_len].mean(dim=0))))
        return writes, True, False

    # fallback: mean-pool last 64 tokens, replicate, leave frames invalid
    pooled = _to_cache(hs[-64:].mean(dim=0))
    return [(f, pooled) for f in range(n_frames)], False, True


def extract_per_frame_beliefs(
    ckpt_dir: Path,
    base_model: Path,
    manifest_path: Path,
    out_path: Path,
    n_frames: int = 8,
    n_layers: int = 4,
    pool_mode: str = "action",
    random_span_seed: int = 0,
    random_span_len: int = 25,
    limit: int = 0,
):
    """Extract the per-frame belief cache for VLAlert-X and save it to ``out_path``.

    Each manifest record is rendered as a chat (sampled frames plus the
    assistant CoT/action text). It is run once through ``base_model``, with
    the PEFT adapter applied when ``ckpt_dir/adapter_config.json`` exists.
    The last ``n_layers`` hidden-state layers are then concatenated and pooled
    per frame according to ``pool_mode`` (see `_pool_frames`).

    Args:
        ckpt_dir: adapter directory; used only if it has adapter_config.json.
        base_model: path given to ``from_pretrained`` for processor and model.
        manifest_path: JSONL manifest, one record per line.
        out_path: destination ``.pt``; if it already exists nothing is done.
        n_frames: frames sampled per clip (= belief rows per record).
        n_layers: number of final hidden-state layers to concatenate (>= 1).
        pool_mode: one of "open", "range", "action", "token_mean",
            "random_span".
        random_span_seed: base seed for ``random_span``; the per-record seed
            is ``random_span_seed * 100003 + row_index``.
        random_span_len: ``random_span`` window length when a sample has no
            usable BELIEF spans.
        limit: if > 0, only the first ``limit`` manifest records are used.

    Returns:
        None. The saved dict holds these keys: beliefs_frame, valid_frames,
        ids, category, source, action_per_frame, tta_raw, schema, n_layers,
        pool_mode, belief_dim, ckpt.

    Raises:
        ValueError: unknown ``pool_mode`` or ``n_layers < 1``.
        RuntimeError: no record got as far as a forward pass.

    Any exception while processing a single record is logged and counted.
    That record's row keeps its id/category/source/tta_raw, but gets zero
    beliefs, all-False valid_frames and an empty action list.
    """
    if out_path.exists():
        logger.info(f"[skip] {out_path} exists; reuse")
        return
    # Validate cheaply before the (slow) model load.
    if pool_mode not in _POOL_MODES:
        raise ValueError(f"pool_mode={pool_mode}")
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")

    # ── load manifest (allow stub-CoT records for val/policy_labels) ──
    records, n_stub, n_bad = _load_manifest(manifest_path, n_frames, limit)
    logger.info(f"[load] {manifest_path} n={len(records)} stub_cot={n_stub}")
    if n_bad:
        logger.warning(f"[load] {manifest_path}: skipped {n_bad} malformed rows")
    if not records:
        raise RuntimeError("no successful extractions")

    from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
    from peft import PeftModel
    from training.VLA.cot_belief_dataset import (
        ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE,
        ACTION_ALERT, ACTION_OBSERVE, ACTION_SILENT,
        build_chat, format_assistant, _resolve_actions,
    )
    from training.VLA.frame_utils import sample_frames

    logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}")
    logger.info(f"  n_layers={n_layers} pool_mode={pool_mode}")
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL})
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True)
    # Embeddings must cover the special tokens added above.
    model.resize_token_embeddings(len(processor.tokenizer))
    if (ckpt_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(model, ckpt_dir)
    model.eval()

    tok = processor.tokenizer
    belief_open_id = tok.convert_tokens_to_ids(BELIEF_OPEN)
    belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE)
    action_ids = {tok.convert_tokens_to_ids(t)
                  for t in (ACTION_SILENT, ACTION_OBSERVE, ACTION_ALERT)}

    # ── allocate output tensors ──
    # The feature width is unknown until the first forward pass, so
    # beliefs are allocated lazily.
    out_beliefs: Optional[torch.Tensor] = None
    out_valid = torch.zeros(len(records), n_frames, dtype=torch.bool)
    ids_list, cat_list, src_list, actions_list = [], [], [], []
    tta_list = torch.zeros(len(records), dtype=torch.float32)
    n_failed = 0
    n_pool_fallback = 0
    t0 = time.time()

    for i, rec in enumerate(tqdm(records, ncols=80, desc="cache_x")):
        actions: List[str] = []
        try:
            frame_idx = rec["belief"].get("frame_indices")
            frames = sample_frames(rec["video_path"], n_frames=n_frames,
                                   resize_short=336,
                                   frame_indices=frame_idx)
            actions = _resolve_actions(rec["belief"], n_frames)
            assistant_text = format_assistant(rec["cot"], actions)
            full_msgs = build_chat(frames, assistant_text=assistant_text)
            full_text = processor.apply_chat_template(
                full_msgs, tokenize=False, add_generation_prompt=False)
            inputs = processor(text=[full_text], images=[frames],
                               return_tensors="pt", padding=False,
                               truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            # multi-layer concat: indexed hs[token] -> n_layers*D vector
            hs = _forward_hidden(model, inputs, n_layers)
            ids_t = inputs["input_ids"][0]

            if out_beliefs is None:
                out_beliefs = torch.zeros(len(records), n_frames, hs.shape[-1],
                                          dtype=torch.float16)

            writes, valid, fallback = _pool_frames(
                hs, ids_t,
                pool_mode=pool_mode,
                n_frames=n_frames,
                belief_open_id=belief_open_id,
                belief_close_id=belief_close_id,
                action_ids=action_ids,
                rng_seed=int(random_span_seed) * 100003 + i,
                random_span_len=random_span_len,
            )
            for f, vec in writes:
                out_beliefs[i, f] = vec
                out_valid[i, f] = valid
            n_pool_fallback += int(fallback)
        except Exception as e:
            n_failed += 1
            logger.warning(f"[skip] {rec.get('id')}: {e}")
            # Discard anything written before the failure so the row is
            # uniformly empty/invalid.
            actions = []
            out_valid[i] = False
            if out_beliefs is not None:
                out_beliefs[i] = 0
        # Metadata is appended exactly once per record (success or failure),
        # so every list stays aligned with the tensor rows.
        ids_list.append(rec.get("id", str(i)))
        cat_list.append(rec.get("category", ""))
        src_list.append(rec.get("source", ""))
        actions_list.append(actions)
        tta_list[i] = _tta_value(rec)

    if out_beliefs is None:
        raise RuntimeError("no successful extractions")

    out_dict = {
        "beliefs_frame": out_beliefs,
        "valid_frames": out_valid,
        "ids": ids_list,
        "category": cat_list,
        "source": src_list,
        "action_per_frame": actions_list,
        "tta_raw": tta_list,
        "schema": "vlalert_x_belief_v1",
        "n_layers": n_layers,
        "pool_mode": pool_mode,
        "belief_dim": out_beliefs.shape[-1],
        "ckpt": str(ckpt_dir),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(out_dict, out_path)
    dt = time.time() - t0
    logger.info(f"[save] {out_path}")
    logger.info(f"  shape={out_beliefs.shape} failed={n_failed} "
                f"fallback={n_pool_fallback} elapsed={dt:.0f}s")
def main():
    """Command-line entry point: parse flags and build one belief cache.

    Each flag is forwarded to the matching parameter of
    ``extract_per_frame_beliefs``. ``--base_model`` defaults to the
    Qwen3-VL-4B-Instruct checkout under ``ROOT/models``.
    """
    parser = argparse.ArgumentParser()
    # Inputs / outputs.
    parser.add_argument("--ckpt", required=True, type=Path)
    parser.add_argument("--base_model", type=Path,
                        default=ROOT / "models/Qwen3-VL-4B-Instruct")
    parser.add_argument("--manifest", required=True, type=Path)
    parser.add_argument("--out", required=True, type=Path)
    # Extraction geometry.
    parser.add_argument("--n_frames", default=8, type=int)
    parser.add_argument("--n_layers", default=4, type=int)
    # Knobs for the random_span control baseline.
    parser.add_argument(
        "--random_span_seed", default=0, type=int,
        help="RNG seed for --pool_mode random_span (deterministic per-sample)",
    )
    parser.add_argument(
        "--random_span_len", default=25, type=int,
        help=("fallback span length for --pool_mode random_span when "
              "no BELIEF tags found on a sample"),
    )
    parser.add_argument(
        "--pool_mode", default="action",
        choices=["open", "range", "action", "token_mean", "random_span"],
    )
    parser.add_argument(
        "--limit", default=0, type=int,
        help="If >0, truncate manifest to first N rows (smoke test)",
    )
    opts = parser.parse_args()

    extract_per_frame_beliefs(
        opts.ckpt,
        opts.base_model,
        opts.manifest,
        opts.out,
        n_frames=opts.n_frames,
        n_layers=opts.n_layers,
        pool_mode=opts.pool_mode,
        random_span_seed=opts.random_span_seed,
        random_span_len=opts.random_span_len,
        limit=opts.limit,
    )


if __name__ == "__main__":
    main()