#!/usr/bin/env python3
"""
make_cot_belief_cache.py
═══════════════════════════════════════════════════════════════════════════════
Per-frame belief cache extraction for the CoT+BeliefToken Qwen3-VL-4B
checkpoint (output of training/VLA/train_cot_belief.py).

Why a new script
  make_belief_cache_v2.py is glued to PolicyModel / SFTModel, which expect
  {config.json, vlm_lora/, hazard_head.pt, tta_head.pt}. The CoT+BeliefToken
  checkpoint has a different layout (pure PEFT adapter; tokenizer extended
  with 5 new tokens; no aux heads). This script loads the PEFT adapter
  directly, runs the same per-frame visual-token pooling, and writes a cache
  with the same tensor schema as the v2 per_frame format, so existing temporal
  heads (temporal_long, traj_full_long, etc.) can consume it with
  --hidden_dim 2560.

Output schema (tensor keys match v2 per_frame; meta["schema_version"] is
SCHEMA_VERSION below)
  beliefs_frame  [N, T, D] fp16 — per-frame mean-pooled visual-token hiddens
  valid_frames   [N, T]    bool — True where a frame was present
  beliefs_text   [N, D]    fp16 — mean of non-image valid tokens
  tta_means      [N]       fp32 — zeros (no tta_head on this backbone)
  tta_vars       [N]       fp32 — ones (variance placeholder)
  meta           dict      — schema_version, hidden_dim, n_frames, ids,
                             action_labels, ...
  A slim copy of meta (without ids / action_labels) is also written next to
  the output as <out stem>.meta.json.

Usage
─────
  python -m training.Policy.make_cot_belief_cache \\
      --ckpt_dir   checkpoints/VLA/qwen3vl4b_cot_belief/best \\
      --base_model models/Qwen3-VL-4B-Instruct \\
      --split      val \\
      --out        data/belief_cache_qwen3vl4b_multisrc/val_perframe_t16.pt \\
      --n_frames 16 --sampling last_biased --chunk_size 2000

  --n_frames defaults to 8 (the CoT SFT setting); the example above uses 16.
  With --chunk_size > 0, features are flushed to <out stem>.chunks/ every
  --chunk_size batches and an interrupted run resumes from those chunks.
"""
from __future__ import annotations

import argparse
import json
import logging
import shutil
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make the repo root (two directories above this file) importable so the
# `training.*` imports below resolve when the file is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor

from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("Policy.make_cot_belief_cache")

SCHEMA_VERSION = 3  # bumped: Qwen3-VL-4B + CoT+BeliefToken backbone

# System / user prompts used to build the chat template in _build_inputs,
# which is meant to reproduce the CoT+BeliefToken training template. The
# cached hidden states depend on the exact tokenized context, so these
# strings should stay byte-identical to the training prompts.
# NOTE(review): the "<|BELIEF|> ... ." fragment has no closing marker — if the
# training prompt had one (e.g. a closing belief token), confirm that this
# copy matches it exactly.
SYSTEM_PROMPT = (
    "You are a driving-safety assistant. Given N dashcam frames (earliest → latest), "
    "produce a short chain-of-thought analysis and then emit a single risk action token "
    "wrapped in <|BELIEF|> ... . "
    "The action is <|ALERT|> (imminent collision < ~1.5s), "
    "<|OBSERVE|> (near-term threat, ~1.5-4s), or <|SILENT|> (no threat). "
    "Keep prose minimal; the <|BELIEF|> block is mandatory."
)
USER_PROMPT = "Analyze the frames and emit scene analysis + belief block."
# ── model loader ────────────────────────────────────────────────────────────

def load_model(base_model: str, ckpt_dir: str,
               attn_impl: str = "flash_attention_2"
               ) -> Tuple[AutoModelForImageTextToText, AutoProcessor]:
    """Load the base VLM, attach the CoT+BeliefToken PEFT adapter and merge it.

    Args:
        base_model: path / hub id of the base image-text-to-text checkpoint.
        ckpt_dir: PEFT adapter directory; the processor (with the extended
            tokenizer) is loaded from here as well.
        attn_impl: value passed as ``attn_implementation`` to ``from_pretrained``.

    Returns:
        ``(model, processor)``: the LoRA-merged model in eval mode, moved to
        CUDA, with bf16 weights; and the processor loaded from ``ckpt_dir``.
    """
    logger.info(f"Loading processor (w/ special tokens) from {ckpt_dir}")
    processor = AutoProcessor.from_pretrained(ckpt_dir, trust_remote_code=True)

    logger.info(f"Loading base model {base_model} (bf16)")
    model = AutoModelForImageTextToText.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )

    # Resize to match the extended vocab so the adapter's modules_to_save
    # (embed_tokens, lm_head) can be loaded without a shape mismatch.
    new_vocab = len(processor.tokenizer)
    old_vocab = model.get_input_embeddings().weight.shape[0]
    if old_vocab != new_vocab:
        logger.info(f"Resizing embeddings: {old_vocab} -> {new_vocab}")
        model.resize_token_embeddings(new_vocab)

    logger.info(f"Attaching PEFT adapter from {ckpt_dir}")
    peft_model = PeftModel.from_pretrained(model, ckpt_dir, is_trainable=False)

    # Merge LoRA into the base weights for faster inference (an unmerged LoRA
    # forward adds extra matmuls per adapted layer). modules_to_save
    # (embed_tokens, lm_head) are kept as-is by the merge.
    logger.info(" merging LoRA adapters into base weights (inference-only)")
    model = peft_model.merge_and_unload()
    model.eval()
    model.to("cuda")

    hs = _config_hidden_size(model.config)
    logger.info(f" hidden_size = {hs}")
    return model, processor


def _config_hidden_size(cfg) -> int:
    """Return ``cfg.hidden_size``, falling back to ``cfg.text_config.hidden_size``."""
    return int(getattr(cfg, "hidden_size", None) or cfg.text_config.hidden_size)


def _config_spatial_merge_size(cfg) -> int:
    """Return ``cfg.vision_config.spatial_merge_size``, defaulting to 2 when absent."""
    vc = getattr(cfg, "vision_config", None)
    return int(getattr(vc, "spatial_merge_size", 2) if vc is not None else 2)


# ── per-frame token splitting (mirrors v2) ──────────────────────────────────

def _per_image_token_counts(image_grid_thw: torch.Tensor, sms: int) -> List[int]:
    """LLM-side token count per image: t*h*w patches divided by sms**2."""
    sms2 = sms * sms
    return [int((r[0] * r[1] * r[2]) // sms2) for r in image_grid_thw.tolist()]


def _count_image_runs(is_img: torch.Tensor) -> int:
    """Count maximal runs of True in a 1-D mask (one run per image).

    The element before index 0 is treated as False, so a run that starts at
    the very first position is counted too.
    """
    x = is_img.to(torch.int8)
    prev = torch.cat([x.new_zeros(1), x[:-1]])
    return int(((x == 1) & (prev == 0)).sum().item())


def _split_visual_tokens(hs_b: torch.Tensor, ids_b: torch.Tensor,
                         attn_b: torch.Tensor, igt_b: torch.Tensor,
                         image_token_id: int, sms: int) -> List[torch.Tensor]:
    """Return a list of ``[count_i, D]`` per-image hidden slices for one sample.

    Image-token positions (attended tokens equal to ``image_token_id``) are
    split in order using the per-image counts derived from ``igt_b``.

    Raises:
        RuntimeError: if the number of image tokens disagrees with ``igt_b``.
    """
    valid = attn_b > 0
    is_img = (ids_b == image_token_id) & valid
    positions = torch.nonzero(is_img, as_tuple=False).squeeze(-1)
    n_img_tokens = int(positions.numel())
    counts = _per_image_token_counts(igt_b, sms)
    if n_img_tokens != sum(counts):
        raise RuntimeError(
            f"image-token count mismatch: {n_img_tokens} vs {sum(counts)} "
            f"(igt={igt_b.tolist()})"
        )
    chunks: List[torch.Tensor] = []
    cursor = 0
    for c in counts:
        chunks.append(hs_b[positions[cursor:cursor + c]])
        cursor += c
    return chunks


# ── input builder ──────────────────────────────────────────────────────────

def _resize_short(img, short: int):
    """Downscale a PIL-style image so its short side equals ``short``.

    Images whose short side is already ``<= short`` are returned unchanged;
    the aspect ratio is preserved (long side rounded to the nearest pixel).
    """
    w, h = img.size
    if min(w, h) <= short:
        return img
    if w < h:
        nw, nh = short, int(round(h * (short / w)))
    else:
        nh, nw = short, int(round(w * (short / h)))
    return img.resize((nw, nh))


def _build_inputs(processor, images_b: List[List], metadata_b: List[dict],
                  resize_short: int = 336):
    """Build processor inputs with the CoT+BeliefToken chat template.

    Uses the same system/user prompts as training but no assistant turn
    (only the visual-token hidden states are needed). Frames are resized to
    ``resize_short`` on the short side to keep visual-token counts bounded.
    ``metadata_b`` is accepted for interface compatibility and is unused.
    """
    texts: List[str] = []
    images_b_resized = [[_resize_short(img, resize_short) for img in frames]
                        for frames in images_b]
    for frames in images_b_resized:
        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": USER_PROMPT})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": user_content},
        ]
        texts.append(
            processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        )
    return processor(text=texts, images=images_b_resized, return_tensors="pt",
                     padding=True, truncation=False)


# ── extract one batch ──────────────────────────────────────────────────────

def _forward_last_hidden(model, moved: Dict[str, Any], amp_dtype) -> torch.Tensor:
    """Run the backbone under CUDA autocast and return last-layer hiddens [B, L, D]."""
    # After merge_and_unload() model is a plain HF model; otherwise it's PeftModel.
    base = model.get_base_model() if hasattr(model, "get_base_model") else model
    core = getattr(base, "model", None)
    kwargs = dict(
        input_ids=moved["input_ids"],
        attention_mask=moved.get("attention_mask"),
        pixel_values=moved.get("pixel_values"),
        image_grid_thw=moved.get("image_grid_thw"),
        use_cache=False,
        return_dict=True,
    )
    with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
        if core is not None:
            # Inner model (no LM head): last_hidden_state is returned directly.
            out = core(**kwargs)
            return out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
        out = base(**kwargs, output_hidden_states=True)
        return out.hidden_states[-1]


@torch.no_grad()
def extract_batch(model, processor, inputs: Dict[str, torch.Tensor],
                  image_token_id: int, sms: int, n_frames: int,
                  amp_dtype=torch.bfloat16) -> Dict[str, torch.Tensor]:
    """Pool per-frame visual-token and text-token hidden states for one batch.

    Args:
        model: backbone returned by ``load_model``.
        processor: unused; kept for interface compatibility.
        inputs: processor output (``input_ids``, ``attention_mask``,
            ``pixel_values``, ``image_grid_thw``, ...).
        image_token_id: id of the per-patch image placeholder token.
        sms: vision spatial-merge size (patches per LLM token = sms**2).
        n_frames: number of frame slots T in the output; extra images are dropped.
        amp_dtype: autocast dtype; ``pixel_values`` are also cast to it.

    Returns:
        CPU tensors: ``beliefs_frame`` [B, T, D] fp16 (left-aligned frame means),
        ``valid_frames`` [B, T] bool, ``beliefs_text`` [B, D] fp16, and the
        ``tta_means`` (zeros) / ``tta_vars`` (ones) [B] fp32 placeholders.

    Raises:
        RuntimeError: if image tokens and ``image_grid_thw`` rows disagree.
    """
    device = next(model.parameters()).device
    moved: Dict[str, Any] = {}
    for k, v in inputs.items():
        if not isinstance(v, torch.Tensor):
            moved[k] = v
        elif k == "pixel_values":
            moved[k] = v.to(device, dtype=amp_dtype, non_blocking=True)
        else:
            moved[k] = v.to(device, non_blocking=True)

    hs = _forward_last_hidden(model, moved, amp_dtype)
    B, _, D = hs.shape
    attn = moved.get("attention_mask")
    ids = moved.get("input_ids")
    igt = moved.get("image_grid_thw")

    beliefs_frame = torch.zeros(B, n_frames, D, dtype=torch.float16)
    valid_frames = torch.zeros(B, n_frames, dtype=torch.bool)
    beliefs_text = torch.zeros(B, D, dtype=torch.float16)

    # image_grid_thw rows are concatenated over the batch in sample order.
    igt_cursor = 0
    for b in range(B):
        ids_b = ids[b]
        attn_b = attn[b] if attn is not None else torch.ones_like(ids_b)
        hs_b = hs[b]
        valid = attn_b > 0
        is_img_b = (ids_b == image_token_id) & valid

        # Each image contributes one contiguous run of image tokens.
        n_imgs = _count_image_runs(is_img_b)
        if n_imgs > 0:
            if igt is None:
                raise RuntimeError("image tokens present but image_grid_thw is missing")
            igt_b = igt[igt_cursor:igt_cursor + n_imgs]
            igt_cursor += n_imgs
            chunks = _split_visual_tokens(hs_b, ids_b, attn_b, igt_b, image_token_id, sms)
            kept = chunks[:n_frames]
            if kept:
                # Pool in fp32 on device, then do a single host copy per sample.
                means = torch.stack([c.float().mean(dim=0) for c in kept])
                beliefs_frame[b, :len(kept)] = means.to(torch.float16).cpu()
                valid_frames[b, :len(kept)] = True

        # Text pool: mean over non-image valid tokens, accumulated in fp32 to
        # match the precision of the frame pooling above.
        is_text_b = (~is_img_b) & valid
        m_text = is_text_b.unsqueeze(-1).float()
        denom = m_text.sum(dim=0).clamp(min=1e-6)
        t_mean = (hs_b.float() * m_text).sum(dim=0) / denom
        beliefs_text[b] = t_mean.to(torch.float16).cpu()

    if igt is not None and igt_cursor != int(igt.shape[0]):
        raise RuntimeError(
            f"image_grid_thw has {int(igt.shape[0])} rows but only {igt_cursor} "
            f"image-token runs were found in the batch"
        )

    return {
        "beliefs_frame": beliefs_frame,
        "valid_frames": valid_frames,
        "beliefs_text": beliefs_text,
        # tta placeholders — shape matches v2 schema
        "tta_means": torch.zeros(B, dtype=torch.float32),
        "tta_vars": torch.ones(B, dtype=torch.float32),
    }


# ── chunked save/resume (mirrors v2 helpers) ───────────────────────────────

def _flush_chunk(acc, chunk_dir: Path, idx: int) -> int:
    """Concatenate accumulated batches and write them as ``chunk_{idx:05d}.pt``.

    The file is written to a ``.tmp`` path first and then moved into place so
    a crash never leaves a half-written chunk under the final name.

    Returns:
        Number of samples written (0 if ``acc`` is empty; nothing is written).
    """
    if not acc:
        return 0
    part = {k: torch.cat(v, dim=0) for k, v in acc.items()}
    n = next(iter(part.values())).shape[0]
    tmp = chunk_dir / f"chunk_{idx:05d}.pt.tmp"
    fin = chunk_dir / f"chunk_{idx:05d}.pt"
    torch.save(part, tmp)
    tmp.replace(fin)
    return n


def _scan_chunks(chunk_dir: Path) -> Tuple[int, int]:
    """Validate on-disk chunks for resume and return ``(n_chunks, n_samples)``.

    Leftover ``*.tmp`` files are removed. Only the contiguous, readable prefix
    chunk_00000, chunk_00001, ... is kept: resume assumes chunk k holds the
    k-th slice of samples, so every chunk after the first gap or unreadable
    file is deleted (otherwise it would be overwritten or leave a hole in the
    sample order).
    """
    if not chunk_dir.exists():
        return 0, 0
    for t in chunk_dir.glob("*.tmp"):
        t.unlink(missing_ok=True)

    files = sorted(chunk_dir.glob("chunk_*.pt"))
    n_good = 0
    n_samples = 0
    for f in files:
        if f.name != f"chunk_{n_good:05d}.pt":
            break  # gap in the chunk sequence
        try:
            d = torch.load(f, map_location="cpu", weights_only=True)
            n = int(next(iter(d.values())).shape[0])
        except Exception as e:  # best-effort: any unreadable chunk is re-extracted
            logger.warning(f" [resume] unreadable chunk {f.name}: {e}")
            break
        n_samples += n
        n_good += 1

    for f in files[n_good:]:
        logger.warning(f" [resume] dropping chunk {f.name} (at/after a gap or unreadable chunk)")
        f.unlink(missing_ok=True)
    return n_good, n_samples


def _merge_chunks(chunk_dir: Path) -> Dict[str, torch.Tensor]:
    """Load all ``chunk_*.pt`` files in name order and concatenate per key."""
    files = sorted(chunk_dir.glob("chunk_*.pt"))
    if not files:
        return {}
    acc: Dict[str, List[torch.Tensor]] = {}
    for f in files:
        d = torch.load(f, map_location="cpu", weights_only=True)
        for k, v in d.items():
            acc.setdefault(k, []).append(v)
    return {k: torch.cat(lst, dim=0) for k, lst in acc.items()}


def _resume_start_batch(n_chunks: int, n_done: int, chunk_size: int,
                        batch_size: Optional[int]) -> int:
    """Number of loader batches already covered by ``n_done`` cached samples.

    Derived from the sample count rather than ``n_chunks * chunk_size`` so a
    resume with a different ``--chunk_size`` (or a compatible batch size)
    stays aligned. Falls back to chunk arithmetic when the loader has no
    fixed ``batch_size`` (custom batch sampler).

    Raises:
        RuntimeError: if ``n_done`` is not a multiple of ``batch_size``.
    """
    if batch_size is None:
        return n_chunks * chunk_size
    if n_done % batch_size != 0:
        raise RuntimeError(
            f"cached sample count {n_done} is not a multiple of batch_size={batch_size}; "
            f"resume with the original --batch_size or delete the chunk directory"
        )
    return n_done // batch_size


def _resumed_loader(loader: DataLoader, start_batch: int) -> Tuple[Any, int]:
    """Return ``(iterable, n_batches_to_skip)`` for resuming at ``start_batch``.

    For a map-style loader with a sequential sampler and fixed batch size, a
    new loader over a ``Subset`` starting at the first uncached sample is
    built, so skipped samples are never fetched (iterating and discarding
    them would still run the dataset's ``__getitem__``, i.e. load their
    frames). Any other loader is returned unchanged and the caller skips the
    first ``start_batch`` batches itself.
    """
    from torch.utils.data import SequentialSampler, Subset

    if start_batch <= 0:
        return loader, 0
    if loader.batch_size is None or not isinstance(loader.sampler, SequentialSampler):
        return loader, start_batch
    start = start_batch * loader.batch_size
    remaining = Subset(loader.dataset, range(start, len(loader.dataset)))
    resumed = DataLoader(
        remaining,
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        worker_init_fn=loader.worker_init_fn,
        timeout=loader.timeout,
    )
    return resumed, 0


# ── build cache ────────────────────────────────────────────────────────────

def build_cache(model, processor, loader: DataLoader, split: str,
                image_token_id: int, sms: int, n_frames: int,
                chunk_dir: Optional[Path], chunk_size: int,
                expected_n: Optional[int],
                resize_short: int = 336) -> Dict[str, torch.Tensor]:
    """Run extraction over ``loader`` and return the concatenated cache dict.

    When ``chunk_dir`` is given, features are flushed every ``chunk_size``
    batches and an existing chunk directory is resumed; if its chunks already
    cover ``expected_n`` samples they are merged without running the model.

    Raises:
        RuntimeError: if resume state is inconsistent or nothing was extracted.
    """
    start_batch = 0
    chunk_idx = 0
    if chunk_dir is not None:
        chunk_dir.mkdir(parents=True, exist_ok=True)
        n_chunks, n_done = _scan_chunks(chunk_dir)
        if n_chunks > 0:
            if expected_n is not None and n_done >= expected_n:
                logger.info(f" [resume] {n_chunks} chunks ({n_done} samples) "
                            f"cover all {expected_n}; merging")
                return _merge_chunks(chunk_dir)
            start_batch = _resume_start_batch(n_chunks, n_done, chunk_size, loader.batch_size)
            chunk_idx = n_chunks
            logger.info(f" [resume] {n_chunks} chunks ({n_done} samples); "
                        f"skipping first {start_batch} batches")

    batches, skip = _resumed_loader(loader, start_batch)
    acc: Dict[str, List[torch.Tensor]] = {}
    since_flush = 0
    pbar = tqdm(batches, desc=f"cot-cache[{split}]", ncols=80, leave=True)
    for bi, batch in enumerate(pbar):
        if bi < skip:
            continue
        inputs = _build_inputs(processor, batch["images"], batch["metadata"],
                               resize_short=resize_short)
        feats = extract_batch(model, processor, inputs, image_token_id, sms, n_frames)
        for k, v in feats.items():
            acc.setdefault(k, []).append(v)
        since_flush += 1
        if chunk_dir is not None and since_flush >= chunk_size:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            pbar.set_postfix_str(f"chunk={chunk_idx} +{n}")
            acc = {}
            since_flush = 0
            chunk_idx += 1

    if chunk_dir is not None:
        if acc:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            logger.info(f" [chunk] final flush (+{n})")
        cache = _merge_chunks(chunk_dir)
    else:
        cache = {k: torch.cat(lst, dim=0) for k, lst in acc.items()}

    if not cache:
        raise RuntimeError(f"no samples were extracted for split {split!r}")
    n = next(iter(cache.values())).shape[0]
    size_gb = sum(t.element_size() * t.numel() for t in cache.values()) / 1e9
    logger.info(f" {split}: {n} samples keys={list(cache.keys())} size={size_gb:.2f} GB")
    return cache


# ── main ───────────────────────────────────────────────────────────────────

def main():
    """CLI entry point: build the per-frame CoT+BeliefToken belief cache."""
    ap = argparse.ArgumentParser("make_cot_belief_cache")
    ap.add_argument("--ckpt_dir", required=True,
                    help="PEFT adapter dir (contains adapter_config.json + tokenizer)")
    ap.add_argument("--base_model", default="PROJECT_ROOT/models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--split", default=None,
                    help="Shortcut: read {label_dir}/{split}.json")
    ap.add_argument("--manifest", default=None,
                    help="Explicit manifest path; overrides --split")
    ap.add_argument("--out", required=True, help="Output .pt path")
    ap.add_argument("--n_frames", type=int, default=8,
                    help="Match training (CoT SFT used n_frames=8)")
    ap.add_argument("--sampling", default="last_biased",
                    choices=["original", "uniform", "last_biased", "last_2s"])
    ap.add_argument("--source_filter", default="all",
                    choices=["all", "nexar", "multisrc", "dada", "dad"])
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--chunk_size", type=int, default=2000)
    ap.add_argument("--keep_chunks", action="store_true")
    ap.add_argument("--overwrite", action="store_true")
    ap.add_argument("--resize_short", type=int, default=336,
                    help="Resize PIL short side before feeding processor (match training)")
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=16)
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists() and not args.overwrite:
        logger.info(f"Cache exists: {out_path} — use --overwrite to rebuild")
        return

    if args.manifest is not None:
        label_path = Path(args.manifest)
    elif args.split is not None:
        label_path = Path(args.label_dir) / f"{args.split}.json"
    else:
        raise SystemExit("Provide either --split or --manifest")
    if not label_path.exists():
        raise SystemExit(f"manifest not found: {label_path}")

    # Monkey-patch MAX_FRAMES so dataset preallocates correctly for per-frame mode.
    import training.Policy.policy_dataset as pds
    pds.MAX_FRAMES = args.n_frames

    model, processor = load_model(args.base_model, args.ckpt_dir)
    img_tok_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    sms = _config_spatial_merge_size(model.config)
    hidden_dim = _config_hidden_size(model.config)
    logger.info(f" image_token_id={img_tok_id} spatial_merge_size={sms} hidden_dim={hidden_dim}")

    split_name = args.split or label_path.stem
    ds = PolicyDataset(
        manifests=[label_path],
        split=split_name,
        debug=args.debug,
        debug_samples=args.debug_samples,
        n_frames=args.n_frames,
        sampling=args.sampling,
        source_filter=args.source_filter,
    )
    if len(ds) == 0:
        raise SystemExit("dataset empty after filtering")

    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, collate_fn=policy_collate_fn, pin_memory=True,
    )

    chunk_dir = out_path.parent / (out_path.stem + ".chunks") if args.chunk_size > 0 else None
    if args.overwrite and chunk_dir is not None and chunk_dir.exists():
        # --overwrite only replaces the final file; existing chunks are resumed.
        logger.warning(f" {chunk_dir} exists and will be resumed; delete it to "
                       f"re-extract from scratch (e.g. after changing --ckpt_dir)")

    cache = build_cache(
        model, processor, loader, split_name,
        image_token_id=img_tok_id, sms=sms, n_frames=args.n_frames,
        chunk_dir=chunk_dir, chunk_size=args.chunk_size,
        expected_n=len(ds), resize_short=args.resize_short,
    )

    ids = [s.get("video_id") for s in ds.samples]
    labels = [int(s.get("action_label", -1)) for s in ds.samples]
    n_cached = int(next(iter(cache.values())).shape[0])
    if n_cached != len(ids):
        # ids / labels are paired with cache rows by position; a mismatch
        # would silently mislabel every sample.
        raise RuntimeError(
            f"cache has {n_cached} samples but the dataset has {len(ids)}; "
            f"stale chunks? delete {chunk_dir} and rerun"
        )

    meta = {
        "schema_version": SCHEMA_VERSION,
        "cache_mode": "per_frame_cot_belief",
        "backbone": "Qwen3-VL-4B-Instruct",
        "hidden_dim": hidden_dim,
        "n_frames": args.n_frames,
        "sampling": args.sampling,
        "source_filter": args.source_filter,
        "n_samples": n_cached,
        "spatial_merge_size": sms,
        "image_token_id": int(img_tok_id),
        "ckpt_dir": str(args.ckpt_dir),
        "base_model": str(args.base_model),
        "label_path": str(label_path),
        "ids": ids,
        "action_labels": labels,
    }
    to_save = dict(cache)
    to_save["meta"] = meta

    # Atomic write: save to a temp file, then move it over the final path.
    tmp = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(to_save, tmp)
    tmp.replace(out_path)
    logger.info(f" Saved -> {out_path}")

    with open(out_path.with_suffix(".meta.json"), "w") as f:
        slim = {k: v for k, v in meta.items() if k not in ("ids", "action_labels")}
        slim["n_ids"] = len(ids)
        json.dump(slim, f, indent=2)

    if chunk_dir is not None and chunk_dir.exists() and not args.keep_chunks:
        shutil.rmtree(chunk_dir)
        logger.info(f" removed {chunk_dir}")

    logger.info("cot belief cache complete.")


if __name__ == "__main__":
    main()