# VLAlert / training/Policy/make_cot_belief_cache.py
# (Hugging Face file view: commit 1e05592 "Add VLAlert code" by AsianPlayer, verified; 21.3 kB)
#!/usr/bin/env python3
"""
make_cot_belief_cache.py
═══════════════════════════════════════════════════════════════════════════════
Per-frame belief cache extraction for the CoT+BeliefToken Qwen3-VL-4B
checkpoint (output of training/VLA/train_cot_belief.py).
Why a new script:
make_belief_cache_v2.py is glued to PolicyModel / SFTModel, which expect
{config.json, vlm_lora/, hazard_head.pt, tta_head.pt}. The CoT+BeliefToken
checkpoint has a different layout (pure PEFT adapter; tokenizer extended
with 5 new tokens; no aux heads). This script loads the PEFT adapter
directly, runs the same per-frame visual-token pooling, and writes a cache
identical in schema to the v2 per_frame format, so existing temporal heads
(temporal_long, traj_full_long, etc.) can consume it with --hidden_dim 2560.
Output schema (matches v2 per_frame):
beliefs_frame [N, T, D] fp16 — per-frame pooled visual token hiddens
valid_frames [N, T] bool — True where a frame was present
beliefs_text [N, D] fp16 — mean of non-image valid tokens
tta_means [N] fp32 — zeros (no tta_head on this backbone)
tta_vars [N] fp32 — ones (variance placeholder)
meta dict — schema_version, hidden_dim, n_frames, ids, action_labels, ...
Usage
─────
python -m training.Policy.make_cot_belief_cache \\
--ckpt_dir checkpoints/VLA/qwen3vl4b_cot_belief/best \\
--base_model models/Qwen3-VL-4B-Instruct \\
--split val \\
--out data/belief_cache_qwen3vl4b_multisrc/val_perframe_t16.pt \\
--n_frames 16 --sampling last_biased --chunk_size 2000
"""
from __future__ import annotations
import argparse
import json
import logging
import shutil
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("Policy.make_cot_belief_cache")

# Written into the cache meta; bumped for the Qwen3-VL-4B + CoT+BeliefToken backbone.
SCHEMA_VERSION = 3

# Prompts rendered through the chat template in _build_inputs, which is meant to
# reproduce the CoT+BeliefToken training template, so this text should stay
# byte-identical to the training prompt. (The arrow below had been corrupted to
# "β†’" by a UTF-8 -> cp1253 mis-decode; restored to "→".)
SYSTEM_PROMPT = (
    "You are a driving-safety assistant. Given N dashcam frames (earliest → latest), "
    "produce a short chain-of-thought analysis and then emit a single risk action token "
    "wrapped in <|BELIEF|> ... </|BELIEF|>. "
    "The action is <|ALERT|> (imminent collision < ~1.5s), "
    "<|OBSERVE|> (near-term threat, ~1.5-4s), or <|SILENT|> (no threat). "
    "Keep prose minimal; the <|BELIEF|> block is mandatory."
)
USER_PROMPT = "Analyze the frames and emit scene analysis + belief block."
# ── model loader ────────────────────────────────────────────────────────────
def load_model(base_model: str, ckpt_dir: str,
               attn_impl: str = "flash_attention_2",
               device: str = "cuda") -> Tuple[AutoModelForImageTextToText, AutoProcessor]:
    """Load the base VLM, attach the CoT+BeliefToken PEFT adapter and merge it.

    Args:
        base_model: path or hub id of the base image-text-to-text checkpoint.
        ckpt_dir: PEFT adapter directory; the processor (with the extended
            tokenizer) is also loaded from here.
        attn_impl: forwarded to ``attn_implementation`` of ``from_pretrained``.
        device: device the merged model is moved to (default ``"cuda"``).

    Returns:
        ``(model, processor)``: the merged model in eval mode on ``device``
        (weights in bf16) and the checkpoint's processor.
    """
    logger.info(f"Loading processor (w/ special tokens) from {ckpt_dir}")
    processor = AutoProcessor.from_pretrained(ckpt_dir, trust_remote_code=True)
    logger.info(f"Loading base model {base_model} (bf16)")
    model = AutoModelForImageTextToText.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )
    # Resize to match the extended vocab so the PEFT adapter's modules_to_save
    # (embed_tokens, lm_head) can be loaded cleanly.
    # NOTE(review): the `!=` test also *shrinks* an embedding table that is larger
    # than len(tokenizer) (e.g. a padded vocab). Confirm training resized the same
    # way, otherwise the adapter's saved embed_tokens/lm_head shapes will not match.
    new_vocab = len(processor.tokenizer)
    old_vocab = model.get_input_embeddings().weight.shape[0]
    if old_vocab != new_vocab:
        logger.info(f"Resizing embeddings: {old_vocab} -> {new_vocab}")
        model.resize_token_embeddings(new_vocab)

    logger.info(f"Attaching PEFT adapter from {ckpt_dir}")
    peft_model = PeftModel.from_pretrained(model, ckpt_dir, is_trainable=False)
    # Merge LoRA into base weights for much faster inference (an unmerged LoRA
    # forward has ~2-3× overhead per attn/mlp layer). modules_to_save
    # (embed_tokens, lm_head) are kept as-is after merge.
    logger.info(" merging LoRA adapters into base weights (inference-only)")
    model = peft_model.merge_and_unload()
    model.eval()
    model.to(device)
    hs = _config_hidden_size(model.config)
    logger.info(f" hidden_size = {hs}")
    return model, processor
def _config_hidden_size(cfg) -> int:
return int(getattr(cfg, "hidden_size", None) or cfg.text_config.hidden_size)
def _config_spatial_merge_size(cfg) -> int:
vc = getattr(cfg, "vision_config", None)
return int(getattr(vc, "spatial_merge_size", 2) if vc is not None else 2)
# ── per-frame token splitting (mirrors v2) ──────────────────────────────────
def _per_image_token_counts(image_grid_thw: torch.Tensor, sms: int) -> List[int]:
sms2 = sms * sms
return [int((r[0] * r[1] * r[2]) // sms2) for r in image_grid_thw.tolist()]
def _split_visual_tokens(hs_b: torch.Tensor,
                         ids_b: torch.Tensor,
                         attn_b: torch.Tensor,
                         igt_b: torch.Tensor,
                         image_token_id: int,
                         sms: int) -> List[torch.Tensor]:
    """Slice one sample's hidden states into per-image visual-token groups.

    Image-token positions (``ids_b == image_token_id`` where ``attn_b > 0``)
    are taken in sequence order and partitioned consecutively using the
    per-image token counts implied by ``igt_b`` and ``sms``.

    Returns:
        One ``[count_i, D]`` tensor per row of ``igt_b``.

    Raises:
        RuntimeError: if the number of image tokens differs from the total
            implied by ``igt_b``.
    """
    img_mask = (ids_b == image_token_id) & (attn_b > 0)
    img_pos = torch.nonzero(img_mask, as_tuple=False).squeeze(-1)
    per_image = _per_image_token_counts(igt_b, sms)
    found_total = int(img_pos.numel())
    expected_total = sum(per_image)
    if found_total != expected_total:
        raise RuntimeError(
            f"image-token count mismatch: {found_total} vs {expected_total} "
            f"(igt={igt_b.tolist()})"
        )
    pieces: List[torch.Tensor] = []
    start = 0
    for n_tok in per_image:
        stop = start + n_tok
        pieces.append(hs_b[img_pos[start:stop]])
        start = stop
    return pieces
# ── input builder ──────────────────────────────────────────────────────────
def _resize_short(img, short: int):
w, h = img.size
if min(w, h) <= short:
return img
if w < h:
nw = short; nh = int(round(h * (short / w)))
else:
nh = short; nw = int(round(w * (short / h)))
return img.resize((nw, nh))
def _build_inputs(processor, images_b: List[List], metadata_b: List[dict],
                  resize_short: int = 336):
    """Tokenize a batch with the CoT+BeliefToken chat template, minus the assistant turn.

    Only the prompt side (system + user with frames) is rendered, since only
    the visual tokens are needed. Every frame is first shrunk so its short side
    is at most ``resize_short`` to bound visual-token counts. ``metadata_b`` is
    accepted for call-site compatibility but not used.

    Returns:
        The processor's batch encoding (``return_tensors="pt"``, padded,
        not truncated).
    """
    resized_batch = [
        [_resize_short(frame, resize_short) for frame in sample_frames]
        for sample_frames in images_b
    ]

    def _render(sample_frames) -> str:
        content = [{"type": "image", "image": frame} for frame in sample_frames]
        content.append({"type": "text", "text": USER_PROMPT})
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": content},
        ]
        return processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    prompts = [_render(sample_frames) for sample_frames in resized_batch]
    return processor(text=prompts, images=resized_batch,
                     return_tensors="pt", padding=True, truncation=False)
# ── extract one batch ──────────────────────────────────────────────────────
@torch.no_grad()
def extract_batch(model, processor, inputs: Dict[str, torch.Tensor],
image_token_id: int, sms: int, n_frames: int,
amp_dtype=torch.bfloat16) -> Dict[str, torch.Tensor]:
device = next(model.parameters()).device
moved: Dict[str, torch.Tensor] = {}
for k, v in inputs.items():
if not isinstance(v, torch.Tensor):
moved[k] = v; continue
if k == "pixel_values":
moved[k] = v.to(device, dtype=amp_dtype, non_blocking=True)
else:
moved[k] = v.to(device, non_blocking=True)
# After merge_and_unload() model is a plain HF model; otherwise it's PeftModel.
base = model.get_base_model() if hasattr(model, "get_base_model") else model
core = getattr(base, "model", None)
with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
if core is not None:
out = core(
input_ids = moved["input_ids"],
attention_mask = moved.get("attention_mask"),
pixel_values = moved.get("pixel_values"),
image_grid_thw = moved.get("image_grid_thw"),
use_cache = False, return_dict = True,
)
hs = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
else:
out = base(
input_ids = moved["input_ids"],
attention_mask = moved.get("attention_mask"),
pixel_values = moved.get("pixel_values"),
image_grid_thw = moved.get("image_grid_thw"),
use_cache = False, return_dict = True,
output_hidden_states = True,
)
hs = out.hidden_states[-1]
B, _, D = hs.shape
attn = moved.get("attention_mask")
ids = moved.get("input_ids")
igt = moved.get("image_grid_thw")
beliefs_frame = torch.zeros(B, n_frames, D, dtype=torch.float16)
valid_frames = torch.zeros(B, n_frames, dtype=torch.bool)
beliefs_text = torch.zeros(B, D, dtype=torch.float16)
igt_cursor = 0
for b in range(B):
ids_b = ids[b]
attn_b = attn[b] if attn is not None else torch.ones_like(ids_b)
hs_b = hs[b]
valid = attn_b > 0
is_img_b = (ids_b == image_token_id) & valid
# Count contiguous image-token runs = number of images in this sample
x = is_img_b.to(torch.int8)
diff = torch.cat([x.new_zeros(1), x[1:] - x[:-1]])
n_imgs = int((diff == 1).sum().item())
if n_imgs > 0:
igt_b = igt[igt_cursor:igt_cursor + n_imgs]
igt_cursor += n_imgs
chunks = _split_visual_tokens(hs_b, ids_b, attn_b, igt_b,
image_token_id, sms)
for f in range(min(len(chunks), n_frames)):
beliefs_frame[b, f] = chunks[f].float().mean(dim=0).to(torch.float16).cpu()
valid_frames[b, f] = True
# Text pool: non-image valid tokens
is_text_b = (~is_img_b) & valid
m_text = is_text_b.unsqueeze(-1).to(hs_b.dtype)
denom = m_text.sum(dim=0).clamp(min=1e-6)
t_mean = (hs_b * m_text).sum(dim=0) / denom
beliefs_text[b] = t_mean.to(torch.float16).cpu()
return {
"beliefs_frame": beliefs_frame,
"valid_frames": valid_frames,
"beliefs_text": beliefs_text,
# tta placeholders β€” shape matches v2 schema
"tta_means": torch.zeros(B, dtype=torch.float32),
"tta_vars": torch.ones(B, dtype=torch.float32),
}
# ── chunked save/resume (mirrors v2 helpers) ───────────────────────────────
def _flush_chunk(acc, chunk_dir: Path, idx: int) -> int:
if not acc:
return 0
part = {k: torch.cat(v, dim=0) for k, v in acc.items()}
n = next(iter(part.values())).shape[0]
tmp = chunk_dir / f"chunk_{idx:05d}.pt.tmp"
fin = chunk_dir / f"chunk_{idx:05d}.pt"
torch.save(part, tmp); tmp.rename(fin)
return n
def _scan_chunks(chunk_dir: Path) -> Tuple[int, int]:
if not chunk_dir.exists():
return 0, 0
for t in chunk_dir.glob("*.tmp"):
t.unlink(missing_ok=True)
files = sorted(chunk_dir.glob("chunk_*.pt"))
n_samples = 0
for f in files:
try:
d = torch.load(f, map_location="cpu", weights_only=True)
n_samples += int(next(iter(d.values())).shape[0])
except Exception as e:
logger.warning(f" [resume] dropping unreadable chunk {f.name}: {e}")
f.unlink(missing_ok=True)
return len(list(chunk_dir.glob("chunk_*.pt"))), n_samples
def _merge_chunks(chunk_dir: Path) -> Dict[str, torch.Tensor]:
files = sorted(chunk_dir.glob("chunk_*.pt"))
if not files:
return {}
acc: Dict[str, List[torch.Tensor]] = {}
for f in files:
d = torch.load(f, map_location="cpu", weights_only=True)
for k, v in d.items():
acc.setdefault(k, []).append(v)
return {k: torch.cat(lst, dim=0) for k, lst in acc.items()}
# ── build cache ────────────────────────────────────────────────────────────
def build_cache(model, processor, loader: DataLoader, split: str,
                image_token_id: int, sms: int, n_frames: int,
                chunk_dir: Optional[Path], chunk_size: int,
                expected_n: Optional[int],
                resize_short: int = 336) -> Dict[str, torch.Tensor]:
    """Extract per-frame beliefs for every batch in ``loader`` and return the cache.

    With ``chunk_dir`` set, results are flushed to ``chunk_dir`` every
    ``chunk_size`` batches, and a rerun resumes after the chunks already on
    disk. Resume counts *batches*, so the loader must yield the same batches in
    the same order (same dataset and batch size) as the interrupted run.

    Args:
        model, processor: backbone and processor from ``load_model``.
        loader: DataLoader yielding dicts with ``images`` and ``metadata``.
        split: name used in progress/log output.
        image_token_id, sms, n_frames: forwarded to ``extract_batch``.
        chunk_dir: chunk directory for save/resume, or ``None`` to keep all in memory.
        chunk_size: batches per chunk file.
        expected_n: expected number of samples (rows), or ``None`` to skip the check.
        resize_short: forwarded to ``_build_inputs``.

    Returns:
        Dict of dim-0-concatenated tensors with the keys produced by ``extract_batch``.

    Raises:
        RuntimeError: if nothing was extracted, or if the row count differs
            from ``expected_n``. In that case rows would not line up with the
            dataset order that callers use for ids/labels.
    """
    start_batch = 0
    chunk_idx = 0
    if chunk_dir is not None:
        chunk_dir.mkdir(parents=True, exist_ok=True)
        n_chunks, n_done = _scan_chunks(chunk_dir)
        if n_chunks > 0:
            start_batch = n_chunks * chunk_size
            chunk_idx = n_chunks
            logger.info(f" [resume] {n_chunks} chunks ({n_done} samples); "
                        f"skipping first {start_batch} batches")
            if expected_n is not None and n_done >= expected_n:
                logger.info(f" [resume] covers all {expected_n}; merging")
                return _finalize_cache(_merge_chunks(chunk_dir), split, expected_n)

    acc: Dict[str, List[torch.Tensor]] = {}
    since_flush = 0
    pbar = tqdm(loader, desc=f"cot-cache[{split}]", ncols=80, leave=True)
    for bi, batch in enumerate(pbar):
        # NOTE(review): skipped batches are still fetched/decoded by the
        # DataLoader, so a late resume pays that I/O again.
        if bi < start_batch:
            continue
        inputs = _build_inputs(processor, batch["images"], batch["metadata"],
                               resize_short=resize_short)
        feats = extract_batch(model, processor, inputs,
                              image_token_id, sms, n_frames)
        for k, v in feats.items():
            acc.setdefault(k, []).append(v)
        since_flush += 1
        if chunk_dir is not None and since_flush >= chunk_size:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            pbar.set_postfix_str(f"chunk={chunk_idx} +{n}")
            acc = {}
            since_flush = 0
            chunk_idx += 1

    if chunk_dir is not None:
        if acc:
            n = _flush_chunk(acc, chunk_dir, chunk_idx)
            logger.info(f" [chunk] final flush (+{n})")
        cache = _merge_chunks(chunk_dir)
    else:
        cache = {k: torch.cat(lst, dim=0) for k, lst in acc.items()}
    return _finalize_cache(cache, split, expected_n)


def _finalize_cache(cache: Dict[str, torch.Tensor], split: str,
                    expected_n: Optional[int]) -> Dict[str, torch.Tensor]:
    """Log a size summary of ``cache`` and verify its row count against ``expected_n``."""
    if not cache:
        raise RuntimeError(f"{split}: no samples were extracted")
    n = next(iter(cache.values())).shape[0]
    size_gb = sum(t.element_size() * t.numel() for t in cache.values()) / 1e9
    logger.info(f" {split}: {n} samples keys={list(cache.keys())} size={size_gb:.2f} GB")
    if expected_n is not None and n != expected_n:
        raise RuntimeError(
            f"{split}: cache has {n} samples but {expected_n} were expected; rows "
            f"would not line up with the dataset order. Remove the chunk directory "
            f"(or fix the dataset/batching mismatch) and rebuild."
        )
    return cache
# ── main ───────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: build one per-frame CoT+BeliefToken belief cache file.

    Loads the merged model, runs ``build_cache`` over the selected manifest
    (chunked save/resume under ``<out stem>.chunks`` when ``--chunk_size > 0``),
    then writes ``--out`` via a ``.tmp`` file plus a slim ``.meta.json``
    sidecar. Returns early if ``--out`` exists and ``--overwrite`` is not set.
    """
    ap = argparse.ArgumentParser("make_cot_belief_cache")
    ap.add_argument("--ckpt_dir", required=True,
                    help="PEFT adapter dir (contains adapter_config.json + tokenizer)")
    ap.add_argument("--base_model",
                    default="PROJECT_ROOT/models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--label_dir", default="data/policy_labels")
    ap.add_argument("--split", default=None,
                    help="Shortcut: read {label_dir}/{split}.json")
    ap.add_argument("--manifest", default=None,
                    help="Explicit manifest path; overrides --split")
    ap.add_argument("--out", required=True, help="Output .pt path")
    ap.add_argument("--n_frames", type=int, default=8,
                    help="Match training (CoT SFT used n_frames=8)")
    ap.add_argument("--sampling", default="last_biased",
                    choices=["original", "uniform", "last_biased", "last_2s"])
    ap.add_argument("--source_filter", default="all",
                    choices=["all", "nexar", "multisrc", "dada", "dad"])
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--chunk_size", type=int, default=2000)
    ap.add_argument("--keep_chunks", action="store_true")
    ap.add_argument("--overwrite", action="store_true")
    ap.add_argument("--resize_short", type=int, default=336,
                    help="Resize PIL short side before feeding processor (match training)")
    ap.add_argument("--debug", action="store_true")
    ap.add_argument("--debug_samples", type=int, default=16)
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if out_path.exists() and not args.overwrite:
        logger.info(f"Cache exists: {out_path} — use --overwrite to rebuild")
        return

    if args.manifest is not None:
        label_path = Path(args.manifest)
    elif args.split is not None:
        label_path = Path(args.label_dir) / f"{args.split}.json"
    else:
        raise SystemExit("Provide either --split or --manifest")
    if not label_path.exists():
        raise SystemExit(f"manifest not found: {label_path}")

    # Monkey-patch MAX_FRAMES so dataset preallocates correctly for per-frame mode.
    import training.Policy.policy_dataset as pds
    pds.MAX_FRAMES = args.n_frames

    model, processor = load_model(args.base_model, args.ckpt_dir)
    img_tok_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    sms = _config_spatial_merge_size(model.config)
    hidden_dim = _config_hidden_size(model.config)
    logger.info(f" image_token_id={img_tok_id} spatial_merge_size={sms} hidden_dim={hidden_dim}")

    split_name = args.split or label_path.stem
    ds = PolicyDataset(
        manifests=[label_path],
        split=split_name,
        debug=args.debug,
        debug_samples=args.debug_samples,
        n_frames=args.n_frames,
        sampling=args.sampling,
        source_filter=args.source_filter,
    )
    if len(ds) == 0:
        raise SystemExit("dataset empty after filtering")
    # shuffle=False: cache rows must follow ds.samples order (ids/labels below).
    loader = DataLoader(
        ds, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, collate_fn=policy_collate_fn,
        pin_memory=True,
    )

    chunk_dir = out_path.parent / (out_path.stem + ".chunks") if args.chunk_size > 0 else None
    cache = build_cache(
        model, processor, loader, split_name,
        image_token_id=img_tok_id, sms=sms, n_frames=args.n_frames,
        chunk_dir=chunk_dir, chunk_size=args.chunk_size,
        expected_n=len(ds), resize_short=args.resize_short,
    )

    ids = [s.get("video_id") for s in ds.samples]
    labels = [int(s.get("action_label", -1)) for s in ds.samples]
    meta = {
        "schema_version": SCHEMA_VERSION,
        "cache_mode": "per_frame_cot_belief",
        "backbone": "Qwen3-VL-4B-Instruct",
        "hidden_dim": hidden_dim,
        "n_frames": args.n_frames,
        "sampling": args.sampling,
        "source_filter": args.source_filter,
        "n_samples": int(next(iter(cache.values())).shape[0]),
        "spatial_merge_size": sms,
        "image_token_id": int(img_tok_id),
        "ckpt_dir": str(args.ckpt_dir),
        "base_model": str(args.base_model),
        "label_path": str(label_path),
        "ids": ids,
        "action_labels": labels,
    }
    to_save = dict(cache)
    to_save["meta"] = meta

    tmp = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(to_save, tmp)
    # Path.replace (unlike rename) overwrites an existing target on every
    # platform, so --overwrite also works on Windows.
    tmp.replace(out_path)
    logger.info(f" Saved -> {out_path}")

    with open(out_path.with_suffix(".meta.json"), "w", encoding="utf-8") as f:
        slim = {k: v for k, v in meta.items() if k not in ("ids", "action_labels")}
        slim["n_ids"] = len(ids)
        json.dump(slim, f, indent=2)

    if chunk_dir is not None and chunk_dir.exists() and not args.keep_chunks:
        shutil.rmtree(chunk_dir)
        logger.info(f" removed {chunk_dir}")
    logger.info("cot belief cache complete.")


if __name__ == "__main__":
    main()