| |
| """ |
| make_cot_belief_cache.py |
───────────────────────────────────────────────────────────────────────────────
| Per-frame belief cache extraction for the CoT+BeliefToken Qwen3-VL-4B |
| checkpoint (output of training/VLA/train_cot_belief.py). |
| |
| Why a new script: |
| make_belief_cache_v2.py is glued to PolicyModel / SFTModel, which expect |
| {config.json, vlm_lora/, hazard_head.pt, tta_head.pt}. The CoT+BeliefToken |
| checkpoint has a different layout (pure PEFT adapter; tokenizer extended |
| with 5 new tokens; no aux heads). This script loads the PEFT adapter |
| directly, runs the same per-frame visual-token pooling, and writes a cache |
| identical in schema to the v2 per_frame format, so existing temporal heads |
| (temporal_long, traj_full_long, etc.) can consume it with --hidden_dim 2560. |
| |
| Output schema (matches v2 per_frame): |
    beliefs_frame  [N, T, D]  fp16  — per-frame pooled visual-token hidden states
    valid_frames   [N, T]     bool  — True where a frame was present
    beliefs_text   [N, D]     fp16  — mean of the valid non-image token hidden states
    tta_means      [N]        fp32  — zeros (no tta_head on this backbone)
    tta_vars       [N]        fp32  — ones (variance placeholder)
    meta           dict             — schema_version, hidden_dim, n_frames, ids, action_labels, ...
| |
| Usage |
─────
| python -m training.Policy.make_cot_belief_cache \\ |
| --ckpt_dir checkpoints/VLA/qwen3vl4b_cot_belief/best \\ |
| --base_model models/Qwen3-VL-4B-Instruct \\ |
| --split val \\ |
| --out data/belief_cache_qwen3vl4b_multisrc/val_perframe_t16.pt \\ |
| --n_frames 16 --sampling last_biased --chunk_size 2000 |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import shutil |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.amp import autocast |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) |
|
|
| from peft import PeftModel |
| from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
| from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn |
|
|
# Configure root logging at import time so progress messages are visible when run as a script.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("Policy.make_cot_belief_cache")


# Written verbatim into meta["schema_version"] of every cache produced by this script.
SCHEMA_VERSION = 3


# Prompts used to build the chat template in _build_inputs. That function is meant to
# reproduce the CoT+BeliefToken training template, so these strings should stay
# byte-identical to the ones used in training (any drift changes the tokenized context).
SYSTEM_PROMPT = (
    "You are a driving-safety assistant. Given N dashcam frames (earliest → latest), "
    "produce a short chain-of-thought analysis and then emit a single risk action token "
    "wrapped in <|BELIEF|> ... </|BELIEF|>. "
    "The action is <|ALERT|> (imminent collision < ~1.5s), "
    "<|OBSERVE|> (near-term threat, ~1.5-4s), or <|SILENT|> (no threat). "
    "Keep prose minimal; the <|BELIEF|> block is mandatory."
)
USER_PROMPT = "Analyze the frames and emit scene analysis + belief block."
|
|
|
|
| |
|
|
def load_model(base_model: str, ckpt_dir: str,
               attn_impl: str = "flash_attention_2") -> Tuple[AutoModelForImageTextToText, AutoProcessor]:
    """Load the base VLM, attach the PEFT adapter from ``ckpt_dir`` and merge it.

    The processor (including its possibly extended tokenizer) is read from
    ``ckpt_dir``. If the base model's input-embedding row count differs from the
    tokenizer length, the embeddings are resized *before* the adapter is attached.
    The LoRA weights are then merged into the base weights, and the merged model is
    switched to eval mode and moved to CUDA.

    Args:
        base_model: path / id of the base model, loaded in bfloat16.
        ckpt_dir: PEFT adapter directory that also contains the processor files.
        attn_impl: forwarded as ``attn_implementation`` to ``from_pretrained``.

    Returns:
        ``(merged_model, processor)``.
    """
    logger.info(f"Loading processor (w/ special tokens) from {ckpt_dir}")
    processor = AutoProcessor.from_pretrained(ckpt_dir, trust_remote_code=True)

    logger.info(f"Loading base model {base_model} (bf16)")
    backbone = AutoModelForImageTextToText.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )

    # The checkpoint's tokenizer may be larger than the base vocab (see module
    # docstring); the embedding matrix has to match it before adapter weights load.
    target_vocab = len(processor.tokenizer)
    current_rows = backbone.get_input_embeddings().weight.shape[0]
    if current_rows != target_vocab:
        logger.info(f"Resizing embeddings: "
                    f"{current_rows} -> {target_vocab}")
        backbone.resize_token_embeddings(target_vocab)

    logger.info(f"Attaching PEFT adapter from {ckpt_dir}")
    adapted = PeftModel.from_pretrained(backbone, ckpt_dir, is_trainable=False)

    # Inference only: fold the LoRA deltas into the base weights.
    logger.info(" merging LoRA adapters into base weights (inference-only)")
    merged = adapted.merge_and_unload()
    merged.eval()
    merged.to("cuda")
    logger.info(f" hidden_size = {_config_hidden_size(merged.config)}")
    return merged, processor
|
|
|
|
| def _config_hidden_size(cfg) -> int: |
| return int(getattr(cfg, "hidden_size", None) or cfg.text_config.hidden_size) |
|
|
|
|
| def _config_spatial_merge_size(cfg) -> int: |
| vc = getattr(cfg, "vision_config", None) |
| return int(getattr(vc, "spatial_merge_size", 2) if vc is not None else 2) |
|
|
|
|
| |
|
|
| def _per_image_token_counts(image_grid_thw: torch.Tensor, sms: int) -> List[int]: |
| sms2 = sms * sms |
| return [int((r[0] * r[1] * r[2]) // sms2) for r in image_grid_thw.tolist()] |
|
|
|
|
def _split_visual_tokens(hs_b: torch.Tensor,
                         ids_b: torch.Tensor,
                         attn_b: torch.Tensor,
                         igt_b: torch.Tensor,
                         image_token_id: int,
                         sms: int) -> List[torch.Tensor]:
    """Split one sample's image-token hidden states into one group per image.

    Args:
        hs_b: hidden states of the sample, indexed by sequence position on dim 0.
        ids_b: token ids of the sample.
        attn_b: attention mask; only positions with ``attn_b > 0`` are considered.
        igt_b: ``image_grid_thw`` rows for this sample's images, in image order.
        image_token_id: id of the image placeholder token.
        sms: vision spatial merge size.

    Returns:
        A list with one ``[count_i, D]`` tensor per image (rows of ``hs_b`` at that
        image's placeholder positions).

    Raises:
        RuntimeError: if the number of image tokens in the sequence differs from
            the total implied by ``igt_b``.
    """
    image_mask = (ids_b == image_token_id) & (attn_b > 0)
    image_pos = torch.nonzero(image_mask, as_tuple=False).squeeze(-1)
    found = int(image_pos.numel())

    per_image = _per_image_token_counts(igt_b, sms)
    expected = sum(per_image)
    if found != expected:
        raise RuntimeError(
            f"image-token count mismatch: {found} vs {expected} "
            f"(igt={igt_b.tolist()})"
        )

    # Assumes image tokens appear in the same order as the igt rows, so
    # consecutive slices of the position list belong to consecutive images.
    bounds = [0]
    for n_tok in per_image:
        bounds.append(bounds[-1] + n_tok)
    return [hs_b[image_pos[lo:hi]] for lo, hi in zip(bounds[:-1], bounds[1:])]
|
|
|
|
| |
|
|
| def _resize_short(img, short: int): |
| w, h = img.size |
| if min(w, h) <= short: |
| return img |
| if w < h: |
| nw = short; nh = int(round(h * (short / w))) |
| else: |
| nh = short; nw = int(round(w * (short / h))) |
| return img.resize((nw, nh)) |
|
|
|
|
def _build_inputs(processor, images_b: List[List], metadata_b: List[dict],
                  resize_short: int = 336):
    """Tokenize a batch of frame lists with the CoT+BeliefToken chat template.

    Each sample becomes a system turn (SYSTEM_PROMPT) plus a user turn holding
    all of its frames followed by USER_PROMPT. The generation prompt is appended,
    but no assistant answer, because only the visual-token hidden states are needed.
    Frames are first shrunk with ``_resize_short`` so the number of visual tokens
    stays bounded.

    Args:
        processor: multimodal processor exposing ``apply_chat_template`` and ``__call__``.
        images_b: per-sample lists of frames (PIL-like images).
        metadata_b: per-sample metadata; currently unused.
        resize_short: target length of the shorter image side.

    Returns:
        The processor output for the whole batch (padded, PyTorch tensors, no truncation).
    """
    resized = [[_resize_short(frame, resize_short) for frame in clip]
               for clip in images_b]

    prompts: List[str] = []
    for clip in resized:
        content = [{"type": "image", "image": frame} for frame in clip]
        content += [{"type": "text", "text": USER_PROMPT}]
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": content},
        ]
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        prompts.append(prompt)

    return processor(
        text=prompts,
        images=resized,
        return_tensors="pt",
        padding=True,
        truncation=False,
    )
|
|
|
|
| |
|
|
@torch.no_grad()
def extract_batch(model, processor, inputs: Dict[str, torch.Tensor],
                  image_token_id: int, sms: int, n_frames: int,
                  amp_dtype=torch.bfloat16) -> Dict[str, torch.Tensor]:
    """Run one forward pass and pool per-frame and text hidden states.

    Args:
        model: the (merged) VLM. If it exposes a ``.model`` sub-module, that
            sub-module is called directly and its ``last_hidden_state`` is used.
            Otherwise the full model is called with ``output_hidden_states=True``
            and the last entry of ``hidden_states`` is used.
        processor: unused; kept for call-site compatibility.
        inputs: processor output (``input_ids``, ``attention_mask``,
            ``pixel_values``, ``image_grid_thw``, ...).
        image_token_id: id of the image placeholder token.
        sms: vision spatial merge size.
        n_frames: number of frame slots in the output. Images beyond this index
            are ignored; samples with fewer images leave trailing slots zero/invalid.
        amp_dtype: autocast dtype; ``pixel_values`` is cast to it as well.

    Returns:
        CPU tensors:
          ``beliefs_frame`` [B, n_frames, D] fp16 (mean hidden state per image),
          ``valid_frames`` [B, n_frames] bool,
          ``beliefs_text`` [B, D] fp16 (mean over valid non-image tokens),
          ``tta_means`` [B] fp32 zeros and ``tta_vars`` [B] fp32 ones (placeholders).

    Raises:
        RuntimeError: if image tokens cannot be matched with ``image_grid_thw`` rows.
    """
    device = next(model.parameters()).device
    moved: Dict[str, Any] = {}
    for k, v in inputs.items():
        if not isinstance(v, torch.Tensor):
            moved[k] = v
        elif k == "pixel_values":
            moved[k] = v.to(device, dtype=amp_dtype, non_blocking=True)
        else:
            moved[k] = v.to(device, non_blocking=True)

    # Prefer the backbone without the LM head: it returns last_hidden_state directly
    # and avoids materialising logits / every layer's hidden states.
    base = model.get_base_model() if hasattr(model, "get_base_model") else model
    core = getattr(base, "model", None)
    fwd_kwargs = dict(
        input_ids=moved["input_ids"],
        attention_mask=moved.get("attention_mask"),
        pixel_values=moved.get("pixel_values"),
        image_grid_thw=moved.get("image_grid_thw"),
        use_cache=False,
        return_dict=True,
    )
    with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
        if core is not None:
            out = core(**fwd_kwargs)
            hs = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
        else:
            out = base(**fwd_kwargs, output_hidden_states=True)
            hs = out.hidden_states[-1]

    B, _, D = hs.shape
    attn = moved.get("attention_mask")
    ids = moved["input_ids"]
    igt = moved.get("image_grid_thw")

    beliefs_frame = torch.zeros(B, n_frames, D, dtype=torch.float16)
    valid_frames = torch.zeros(B, n_frames, dtype=torch.bool)
    beliefs_text = torch.zeros(B, D, dtype=torch.float16)

    # image_grid_thw is flat over the whole batch; walk it sample by sample.
    igt_cursor = 0
    for b in range(B):
        ids_b = ids[b]
        attn_b = attn[b] if attn is not None else torch.ones_like(ids_b)
        hs_b = hs[b]
        valid = attn_b > 0
        is_img_b = (ids_b == image_token_id) & valid

        # Each image is one contiguous run of image tokens; count runs as rising
        # edges. The leading 0 makes a run that starts at position 0 count too.
        x = is_img_b.to(torch.int8)
        padded = torch.cat([x.new_zeros(1), x])
        n_imgs = int(((padded[1:] - padded[:-1]) == 1).sum().item())

        if n_imgs > 0:
            if igt is None:
                raise RuntimeError("image tokens present but no image_grid_thw in inputs")
            igt_b = igt[igt_cursor:igt_cursor + n_imgs]
            igt_cursor += n_imgs
            chunks = _split_visual_tokens(hs_b, ids_b, attn_b, igt_b,
                                          image_token_id, sms)
            k = min(len(chunks), n_frames)
            if k > 0:
                # fp32 accumulation, then a single device->host copy per sample.
                frame_means = torch.stack([c.float().mean(dim=0) for c in chunks[:k]])
                beliefs_frame[b, :k] = frame_means.to(torch.float16).cpu()
                valid_frames[b, :k] = True

        # Mean over valid non-image tokens, accumulated in fp32 like the frame
        # means (a bf16 sum/denominator loses precision). Stays zero if none exist.
        is_text_b = (~is_img_b) & valid
        if bool(is_text_b.any()):
            beliefs_text[b] = hs_b[is_text_b].float().mean(dim=0).to(torch.float16).cpu()

    if igt is not None and igt_cursor != int(igt.shape[0]):
        raise RuntimeError(
            f"image_grid_thw has {int(igt.shape[0])} rows but {igt_cursor} image "
            f"token runs were found in the batch"
        )

    return {
        "beliefs_frame": beliefs_frame,
        "valid_frames": valid_frames,
        "beliefs_text": beliefs_text,
        # No TTA head on this backbone: fixed placeholders keep the v2 schema.
        "tta_means": torch.zeros(B, dtype=torch.float32),
        "tta_vars": torch.ones(B, dtype=torch.float32),
    }
|
|
|
|
| |
|
|
| def _flush_chunk(acc, chunk_dir: Path, idx: int) -> int: |
| if not acc: |
| return 0 |
| part = {k: torch.cat(v, dim=0) for k, v in acc.items()} |
| n = next(iter(part.values())).shape[0] |
| tmp = chunk_dir / f"chunk_{idx:05d}.pt.tmp" |
| fin = chunk_dir / f"chunk_{idx:05d}.pt" |
| torch.save(part, tmp); tmp.rename(fin) |
| return n |
|
|
|
|
| def _scan_chunks(chunk_dir: Path) -> Tuple[int, int]: |
| if not chunk_dir.exists(): |
| return 0, 0 |
| for t in chunk_dir.glob("*.tmp"): |
| t.unlink(missing_ok=True) |
| files = sorted(chunk_dir.glob("chunk_*.pt")) |
| n_samples = 0 |
| for f in files: |
| try: |
| d = torch.load(f, map_location="cpu", weights_only=True) |
| n_samples += int(next(iter(d.values())).shape[0]) |
| except Exception as e: |
| logger.warning(f" [resume] dropping unreadable chunk {f.name}: {e}") |
| f.unlink(missing_ok=True) |
| return len(list(chunk_dir.glob("chunk_*.pt"))), n_samples |
|
|
|
|
| def _merge_chunks(chunk_dir: Path) -> Dict[str, torch.Tensor]: |
| files = sorted(chunk_dir.glob("chunk_*.pt")) |
| if not files: |
| return {} |
| acc: Dict[str, List[torch.Tensor]] = {} |
| for f in files: |
| d = torch.load(f, map_location="cpu", weights_only=True) |
| for k, v in d.items(): |
| acc.setdefault(k, []).append(v) |
| return {k: torch.cat(lst, dim=0) for k, lst in acc.items()} |
|
|
|
|
| |
|
|
def _iter_from_batch(loader: DataLoader, start_batch: int):
    """Return an iterable over ``loader``'s batches that starts at ``start_batch``.

    For a map-style loader with a sequential sampler and automatic batching, the
    skipped prefix is never loaded: a DataLoader over a ``Subset`` of the
    remaining samples, with the same batching/worker settings, is returned.
    Any other loader falls back to loading and discarding the first batches.
    """
    from itertools import islice
    from torch.utils.data import SequentialSampler, Subset

    if start_batch <= 0:
        return loader
    if isinstance(loader.sampler, SequentialSampler) and loader.batch_size is not None:
        n_total = len(loader.dataset)
        start = min(start_batch * loader.batch_size, n_total)
        return DataLoader(
            Subset(loader.dataset, range(start, n_total)),
            batch_size=loader.batch_size,
            shuffle=False,
            num_workers=loader.num_workers,
            collate_fn=loader.collate_fn,
            pin_memory=loader.pin_memory,
            drop_last=loader.drop_last,
            timeout=loader.timeout,
            worker_init_fn=loader.worker_init_fn,
            persistent_workers=loader.persistent_workers,
        )
    return islice(loader, start_batch, None)


def _finalize_cache(cache: Dict[str, torch.Tensor], split: str,
                    expected_n: Optional[int]) -> Dict[str, torch.Tensor]:
    """Log a size summary of ``cache`` and check it holds exactly ``expected_n`` samples."""
    if not cache:
        raise RuntimeError(f"{split}: no samples were extracted (empty loader / chunk dir)")
    n = next(iter(cache.values())).shape[0]
    size_gb = sum(t.element_size() * t.numel() for t in cache.values()) / 1e9
    logger.info(f" {split}: {n} samples keys={list(cache.keys())} size={size_gb:.2f} GB")
    if expected_n is not None and n != expected_n:
        raise RuntimeError(
            f"{split}: cache has {n} samples but {expected_n} were expected; "
            f"stale or inconsistent chunks? delete the .chunks dir and rerun"
        )
    return cache


def build_cache(model, processor, loader: DataLoader, split: str,
                image_token_id: int, sms: int, n_frames: int,
                chunk_dir: Optional[Path], chunk_size: int,
                expected_n: Optional[int],
                resize_short: int = 336) -> Dict[str, torch.Tensor]:
    """Extract features for every batch of ``loader`` and return the merged cache.

    With ``chunk_dir`` set, features are flushed every ``chunk_size`` batches to
    ``chunk_dir/chunk_XXXXX.pt``. A later run resumes after the existing
    contiguous chunks. Resumption assumes the same loader order and batch size as
    the interrupted run. If the existing chunks already cover ``expected_n``
    samples they are merged directly.

    Args:
        model, processor: as returned by ``load_model``.
        loader: DataLoader yielding dicts with ``"images"`` and ``"metadata"``.
        split: name used in progress / log messages.
        image_token_id, sms, n_frames: forwarded to ``extract_batch``.
        chunk_dir: chunk directory, or ``None`` to keep everything in memory.
        chunk_size: number of batches per chunk file.
        expected_n: expected number of samples (dataset length), or ``None``.
        resize_short: forwarded to ``_build_inputs``.

    Returns:
        Dict of CPU tensors concatenated along the sample dimension.

    Raises:
        RuntimeError: if nothing was extracted or the sample count differs from
            ``expected_n``.
    """
    start_batch = 0
    chunk_idx = 0
    if chunk_dir is not None:
        chunk_dir.mkdir(parents=True, exist_ok=True)
        n_chunks, n_done = _scan_chunks(chunk_dir)
        if n_chunks > 0:
            start_batch = n_chunks * chunk_size
            chunk_idx = n_chunks
            logger.info(f" [resume] {n_chunks} chunks ({n_done} samples); "
                        f"skipping first {start_batch} batches")
            if expected_n is not None and n_done >= expected_n:
                logger.info(f" [resume] covers all {expected_n}; merging")
                return _finalize_cache(_merge_chunks(chunk_dir), split, expected_n)

    try:
        total = max(len(loader) - start_batch, 0)
    except TypeError:  # loader without a length
        total = None

    acc: Dict[str, List[torch.Tensor]] = {}
    since_flush = 0
    pbar = tqdm(_iter_from_batch(loader, start_batch), desc=f"cot-cache[{split}]",
                ncols=80, leave=True, total=total)
    for batch in pbar:
        inputs = _build_inputs(processor, batch["images"], batch["metadata"],
                               resize_short=resize_short)
        feats = extract_batch(model, processor, inputs,
                              image_token_id, sms, n_frames)
        for k, v in feats.items():
            acc.setdefault(k, []).append(v)
        since_flush += 1
        if chunk_dir is not None and since_flush >= chunk_size:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            pbar.set_postfix_str(f"chunk={chunk_idx} +{n}")
            acc = {}
            since_flush = 0
            chunk_idx += 1

    if chunk_dir is not None:
        if acc:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            logger.info(f" [chunk] final flush (+{n})")
            acc = {}
            chunk_idx += 1
        cache = _merge_chunks(chunk_dir)
    else:
        cache = {k: torch.cat(lst, dim=0) for k, lst in acc.items()}
    return _finalize_cache(cache, split, expected_n)
|
|
|
|
| |
|
|
def main():
    """Command-line entry point: build one per-frame CoT+BeliefToken cache file.

    Loads the adapter-merged model, runs it over the selected manifest, and
    writes ``--out`` (cache tensors plus a ``meta`` dict) together with a slim
    ``<out>.meta.json`` summary that omits the per-sample ids and labels.
    Intermediate chunks go to ``<out stem>.chunks/`` and are removed at the end
    unless ``--keep_chunks`` is given.
    """
    # Same project root that this module inserts into sys.path at import time.
    project_root = Path(__file__).resolve().parents[2]

    ap = argparse.ArgumentParser("make_cot_belief_cache")
    ap.add_argument("--ckpt_dir", required=True,
                    help="PEFT adapter dir (contains adapter_config.json + tokenizer)")
    ap.add_argument("--base_model",
                    default=str(project_root / "models" / "Qwen3-VL-4B-Instruct"),
                    help="Base model dir (default: <project root>/models/Qwen3-VL-4B-Instruct)")
    ap.add_argument("--attn_impl", default="flash_attention_2",
                    help="attn_implementation for the base model "
                         "(e.g. 'sdpa' when flash-attn is not installed)")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--split", default=None,
                    help="Shortcut: read {label_dir}/{split}.json")
    ap.add_argument("--manifest", default=None,
                    help="Explicit manifest path; overrides --split")
    ap.add_argument("--out", required=True, help="Output .pt path")
    ap.add_argument("--n_frames", type=int, default=8,
                    help="Match training (CoT SFT used n_frames=8)")
    ap.add_argument("--sampling", default="last_biased",
                    choices=["original", "uniform", "last_biased", "last_2s"])
    ap.add_argument("--source_filter", default="all",
                    choices=["all", "nexar", "multisrc", "dada", "dad"])
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--chunk_size", type=int, default=2000)
    ap.add_argument("--keep_chunks", action="store_true")
    ap.add_argument("--overwrite", action="store_true")
    ap.add_argument("--resize_short", type=int, default=336,
                    help="Resize PIL short side before feeding processor (match training)")
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=16)
    args = ap.parse_args()
    if args.n_frames < 1:
        ap.error("--n_frames must be >= 1")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists() and not args.overwrite:
        logger.info(f"Cache exists: {out_path} — use --overwrite to rebuild")
        return

    if args.manifest is not None:
        label_path = Path(args.manifest)
    elif args.split is not None:
        label_path = Path(args.label_dir) / f"{args.split}.json"
    else:
        raise SystemExit("Provide either --split or --manifest")
    if not label_path.exists():
        raise SystemExit(f"manifest not found: {label_path}")

    # Keep policy_dataset's MAX_FRAMES module global in sync with --n_frames
    # (presumably a frame cap used inside PolicyDataset — confirm there).
    import training.Policy.policy_dataset as pds
    pds.MAX_FRAMES = args.n_frames

    model, processor = load_model(args.base_model, args.ckpt_dir, attn_impl=args.attn_impl)
    img_tok_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    sms = _config_spatial_merge_size(model.config)
    hidden_dim = _config_hidden_size(model.config)
    logger.info(f" image_token_id={img_tok_id} spatial_merge_size={sms} hidden_dim={hidden_dim}")

    split_name = args.split or label_path.stem
    ds = PolicyDataset(
        manifests=[label_path],
        split=split_name,
        debug=args.debug,
        debug_samples=args.debug_samples,
        n_frames=args.n_frames,
        sampling=args.sampling,
        source_filter=args.source_filter,
    )
    if len(ds) == 0:
        raise SystemExit("dataset empty after filtering")

    # shuffle=False: cache rows must line up with ds.samples (ids / labels in meta).
    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, collate_fn=policy_collate_fn,
        pin_memory=True,
    )

    chunk_dir = out_path.parent / (out_path.stem + ".chunks") if args.chunk_size > 0 else None
    cache = build_cache(
        model, processor, loader, split_name,
        image_token_id=img_tok_id, sms=sms, n_frames=args.n_frames,
        chunk_dir=chunk_dir, chunk_size=args.chunk_size,
        expected_n=len(ds), resize_short=args.resize_short,
    )
    if not cache:
        raise SystemExit("no features were extracted")

    n_cached = int(next(iter(cache.values())).shape[0])
    if n_cached != len(ds):
        # Misaligned rows would silently pair features with the wrong ids/labels.
        raise SystemExit(
            f"cache has {n_cached} samples but the dataset has {len(ds)}; "
            f"refusing to save (stale chunks in {chunk_dir}?)"
        )

    ids = [s.get("video_id") for s in ds.samples]
    labels = [int(s.get("action_label", -1)) for s in ds.samples]
    if len(ids) != n_cached:
        logger.warning(f" ds.samples has {len(ids)} entries but the cache has "
                       f"{n_cached} rows; meta ids/action_labels may not align")
    meta = {
        "schema_version": SCHEMA_VERSION,
        "cache_mode": "per_frame_cot_belief",
        "backbone": "Qwen3-VL-4B-Instruct",
        "hidden_dim": hidden_dim,
        "n_frames": args.n_frames,
        "sampling": args.sampling,
        "source_filter": args.source_filter,
        "n_samples": n_cached,
        "spatial_merge_size": sms,
        "image_token_id": int(img_tok_id),
        "ckpt_dir": str(args.ckpt_dir),
        "base_model": str(args.base_model),
        "label_path": str(label_path),
        "ids": ids,
        "action_labels": labels,
    }
    to_save = dict(cache)
    to_save["meta"] = meta

    # Write-then-replace so a crash never leaves a truncated cache at out_path.
    # Path.replace also overwrites an existing file on Windows (--overwrite).
    tmp = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(to_save, tmp)
    tmp.replace(out_path)
    logger.info(f" Saved -> {out_path}")

    slim = {k: v for k, v in meta.items() if k not in ("ids", "action_labels")}
    slim["n_ids"] = len(ids)
    with open(out_path.with_suffix(".meta.json"), "w", encoding="utf-8") as f:
        json.dump(slim, f, indent=2)

    if chunk_dir is not None and chunk_dir.exists() and not args.keep_chunks:
        shutil.rmtree(chunk_dir)
        logger.info(f" removed {chunk_dir}")

    logger.info("cot belief cache complete.")
|
|
|
|
if __name__ == "__main__":
    # Script entry point (also reached via ``python -m training.Policy.make_cot_belief_cache``).
    main()
|
|