# VLAlert / tools /make_cache_gt_belief.py
# AsianPlayer's picture
# Add VLAlert code
# 1e05592 verified
# Raw
# History Blame Contribute Delete
# 9.86 kB
"""Phase D-experimental (C) — Cache extractor that FILLS assistant_text with
GT BELIEF descriptions instead of empty placeholders.
Original v3 cache extracts hidden states with assistant_text =
<|BELIEF|> </|BELIEF|>\n × 8 frames ← empty placeholders
This version fills each block with the GT description from
manifest's beliefs_per_frame field:
<|BELIEF|> lead vehicle drifting </|BELIEF|>\n
<|BELIEF|> side-street vehicle approaching </|BELIEF|>\n ...
Then range-pools the BELIEF span (now contains actual descriptive tokens)
to get features that ARE visually-informed (because text content varies
per-frame and reflects scene description).
Output schema matches make_cache_x_v2.py.
Usage:
python tools/make_cache_gt_belief.py \
--split train_9k_gtb \
--manifest data/cot_corpus_v2/vlalert_x_perframe_v2_train.jsonl
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
# Conv3d→Linear patch
from tools import run_train_cot_belief_fast # noqa: F401
import torch
from tqdm import tqdm
from transformers import AutoProcessor
from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration
from peft import PeftModel
from training.VLA.cot_belief_dataset import (
BELIEF_OPEN, BELIEF_CLOSE, SYSTEM_PROMPT, USER_PROMPT
)
from training.VLA.frame_utils import sample_frames
# Script-wide logging: INFO level, each line prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("gtb_cache")
# Indices into the model's hidden-state tuple. extract_one mean-pools the
# BELIEF span at each of these layers and concatenates the results, so the
# belief feature width is hidden_size * len(BELIEF_LAYERS).
BELIEF_LAYERS = (20, 24, 28, 32)
# Hidden-state index read at the closing BELIEF token to form the policy feature.
POLICY_LAYER = 33
@torch.no_grad()
def extract_one(model, proc, frames, beliefs, device,
                belief_layers=BELIEF_LAYERS, policy_layer=POLICY_LAYER):
    """Run one forward pass and pool per-frame features from the BELIEF spans.

    The assistant turn is built from ``beliefs``: one
    ``BELIEF_OPEN <text> BELIEF_CLOSE`` block per string, joined by newlines,
    so each pooled span holds the GT description tokens instead of an empty
    placeholder.

    Args:
        model: Called as ``model(**inputs, output_hidden_states=True,
            return_dict=True)``; its ``hidden_states`` tuple is indexed by
            ``belief_layers`` and ``policy_layer``.
        proc: Processor providing ``apply_chat_template``, ``__call__`` and
            a ``tokenizer`` attribute.
        frames: Images for the user turn (one image entry per element).
        beliefs: Exactly 8 description strings, one per block.
        device: Device the model inputs and work tensors are placed on.
        belief_layers: Hidden-state indices that are mean-pooled over each
            span and concatenated into the belief feature.
        policy_layer: Hidden-state index read at each closing-token position.

    Returns:
        ``(belief_feat, policy_feat, valid)`` as CPU tensors: float16
        ``[8, D * len(belief_layers)]``, float16 ``[8, D]`` and bool ``[8]``,
        where ``D`` is the last dimension of the final hidden state. Rows
        without a usable open/close pair stay zero with ``valid`` False.

    Raises:
        AssertionError: If ``beliefs`` does not hold exactly 8 entries.
    """
    assert len(beliefs) == 8, f"need 8 belief strings, got {len(beliefs)}"

    # One filled BELIEF block per frame description.
    blocks = [f"{BELIEF_OPEN} {b.strip()} {BELIEF_CLOSE}" for b in beliefs]
    assistant_text = "\n".join(blocks)

    user_content = [
        *({"type": "image", "image": img} for img in frames),
        {"type": "text", "text": USER_PROMPT},
    ]
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]},
    ]

    prompt = proc.apply_chat_template(messages, tokenize=False,
                                      add_generation_prompt=False)
    batch = proc(text=[prompt], images=[frames], return_tensors="pt",
                 padding=True, truncation=False, max_length=8192)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    outputs = model(**batch, output_hidden_states=True, return_dict=True)
    hidden = outputs.hidden_states  # tuple of [1, T, D]

    token_ids = batch["input_ids"][0]
    is_real = batch["attention_mask"][0].bool()

    def _marker_positions(marker):
        # Sequence positions of `marker`, ignoring padded positions.
        marker_id = proc.tokenizer.convert_tokens_to_ids(marker)
        hits = (token_ids == marker_id) & is_real
        return hits.nonzero(as_tuple=False).flatten().tolist()

    open_pos = _marker_positions(BELIEF_OPEN)
    close_pos = _marker_positions(BELIEF_CLOSE)
    # The i-th opener is paired with the i-th closer; at most 8 pairs are used.
    n_blocks = min(len(open_pos), len(close_pos), 8)

    width = hidden[-1].shape[-1]
    belief_feat = torch.zeros(8, width * len(belief_layers),
                              dtype=torch.float16, device=device)
    policy_feat = torch.zeros(8, width, dtype=torch.float16, device=device)
    valid = torch.zeros(8, dtype=torch.bool, device=device)

    for slot in range(n_blocks):
        first, stop = open_pos[slot] + 1, close_pos[slot]
        if stop <= first:
            # Nothing between the markers (or they are out of order): leave
            # the row zeroed and flagged invalid.
            continue
        # Mean over the tokens strictly between the markers, per layer.
        pooled = [hidden[layer][0, first:stop].mean(dim=0)
                  for layer in belief_layers]
        belief_feat[slot] = torch.cat(pooled, dim=-1).to(torch.float16)
        # Policy feature is the hidden state at the closing marker itself.
        policy_feat[slot] = hidden[policy_layer][0, stop].to(torch.float16)
        valid[slot] = True

    return belief_feat.cpu(), policy_feat.cpu(), valid.cpu()
def main():
    """Build and save the GT-belief feature cache for one manifest split.

    Steps:
      1. Parse CLI arguments and load the processor, base model and PEFT
         adapter.
      2. Read the JSONL manifest, keeping records that have exactly 8
         ``beliefs_per_frame`` entries and the selected frame-index field.
      3. For each record, sample frames, run ``extract_one`` and store the
         features together with the per-frame / per-tick labels.
      4. Save everything as one dict to ``<out_dir>/<tag>__<split>.pt``.

    Records whose frames cannot be sampled are logged and dropped. Successful
    records are written to consecutive rows and the tensors are trimmed to the
    number of kept records, so every tensor row lines up with the same index
    in the ``ids`` / ``category`` / ``source`` / ``video_id`` lists.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--split", required=True)
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v3/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--tag", default="sft_x_v3")
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v3")
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--window",
                    choices=["legacy", "sil_wide", "obs_mid", "alr_narrow"],
                    default="legacy",
                    help="v4: pick which frame-index array to read from the "
                         "manifest ({window}_frame_indices). legacy uses the "
                         "original 'frame_indices' field (v3 behaviour).")
    args = ap.parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    logger.info(f"[load] ckpt={args.ckpt}")
    proc = AutoProcessor.from_pretrained(str(args.ckpt))
    base = Qwen3VLForConditionalGeneration.from_pretrained(
        str(args.base_model), dtype=torch.bfloat16, device_map={"": device},
        attn_implementation="sdpa")
    # Match the embedding table to the checkpoint tokenizer before attaching
    # the adapter.
    base.resize_token_embeddings(len(proc.tokenizer))
    model = PeftModel.from_pretrained(base, str(args.ckpt)).eval()

    logger.info(f"[load] manifest={args.manifest} window={args.window}")
    # NOTE(review): only the prefix before the first "_" is used, e.g.
    # "sil_wide" -> "sil_frame_indices", whereas the --window help text reads
    # "{window}_frame_indices". Confirm against the manifest schema.
    fi_field = "frame_indices" if args.window == "legacy" \
        else f"{args.window.split('_')[0]}_frame_indices"
    logger.info(f"  reading frame indices from field: {fi_field}")

    records = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with args.manifest.open(encoding="utf-8") as f:
        for ln in f:
            if not ln.strip():
                continue
            obj = json.loads(ln)
            if not obj.get("beliefs_per_frame") or len(obj["beliefs_per_frame"]) != 8:
                continue
            if fi_field not in obj:
                continue
            records.append(obj)
    if args.limit > 0:
        records = records[:args.limit]
    N = len(records)
    logger.info(f"  N={N}  (with GT beliefs_per_frame + {fi_field})")

    # 2560 must equal the hidden width extract_one returns; a mismatch makes
    # the row assignments below raise a shape error.
    belief_dim = 2560 * len(BELIEF_LAYERS)
    out_belief = torch.zeros(N, 8, belief_dim, dtype=torch.float16)
    out_policy = torch.zeros(N, 8, 2560, dtype=torch.float16)
    out_valid = torch.zeros(N, 8, dtype=torch.bool)
    out_actions = torch.zeros(N, 8, dtype=torch.long)
    out_danger = torch.zeros(N, 8, dtype=torch.float32)
    out_tta = torch.zeros(N, 8, dtype=torch.float32)
    out_tick_action = torch.zeros(N, dtype=torch.long)
    out_tick_tta = torch.full((N,), -1.0)
    # v4 additions
    out_prev_action = torch.full((N,), 3, dtype=torch.long)
    out_oracle_window = torch.zeros(N, dtype=torch.long)
    out_boundary = torch.zeros(N, dtype=torch.bool)
    out_category, out_source, out_video_id, out_ids = [], [], [], []
    action_map = {"SILENT": 0, "OBSERVE": 1, "ALERT": 2}

    failed = 0
    kept = 0  # next free row; advances only for records that were processed
    for r in tqdm(records, desc="gtb_cache", ncols=80):
        try:
            frames = sample_frames(Path(r["video_path"]),
                                   frame_indices=r[fi_field],
                                   resize_short=336)
        except Exception as exc:
            # Best effort: skip unreadable videos, but leave a trace.
            failed += 1
            logger.warning("[skip] id=%s frame sampling failed: %r",
                           r.get("id", r.get("video_id", "")), exc)
            continue

        bf, pf, v = extract_one(model, proc, frames,
                                r["beliefs_per_frame"], device)
        row = kept
        out_belief[row] = bf
        out_policy[row] = pf
        out_valid[row] = v
        actions_pf = r.get("actions_per_frame", ["SILENT"]*8)
        out_actions[row] = torch.tensor(
            [action_map.get(a, 0) for a in actions_pf], dtype=torch.long)
        out_danger[row] = torch.tensor(r.get("danger_per_frame", [0.0]*8))
        out_tta[row] = torch.tensor(r.get("tta_per_frame", [-1.0]*8))
        out_tick_action[row] = action_map.get(r.get("tick_action", "SILENT"), 0)
        out_tick_tta[row] = float(r.get("tick_tta_raw", -1.0))
        # v4 fields (read if present, else default)
        out_prev_action[row] = int(r.get("prev_action", 3))
        out_oracle_window[row] = int(r.get("oracle_window", 1))
        out_boundary[row] = bool(r.get("boundary", False))
        out_category.append(r.get("category", ""))
        out_source.append(r.get("source", ""))
        out_video_id.append(r.get("video_id", ""))
        out_ids.append(r.get("id", r.get("video_id", "")))
        kept += 1

    if kept < N:
        # Drop the unused tail left by skipped records so tensors and metadata
        # lists have the same length. clone() keeps torch.save from writing
        # the whole pre-allocated storage behind a sliced view.
        (out_belief, out_policy, out_valid, out_actions, out_danger, out_tta,
         out_tick_action, out_tick_tta, out_prev_action, out_oracle_window,
         out_boundary) = (
            t[:kept].clone() for t in (
                out_belief, out_policy, out_valid, out_actions, out_danger,
                out_tta, out_tick_action, out_tick_tta, out_prev_action,
                out_oracle_window, out_boundary))

    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    cache = {
        "ids": out_ids,
        "belief_content": out_belief,
        "policy_position": out_policy,
        "valid_frames": out_valid,
        "actions_pf": out_actions,
        "danger_pf": out_danger,
        "tta_pf": out_tta,
        "tick_action": out_tick_action,
        "tick_tta_raw": out_tick_tta,
        "prev_action": out_prev_action,
        "oracle_window": out_oracle_window,
        "boundary": out_boundary,
        "window": args.window,
        "category": out_category,
        "source": out_source,
        "video_id": out_video_id,
        "schema": "vlalert_x_v4_gt_belief_fill",
        "belief_layers": list(BELIEF_LAYERS),
        "policy_layer": POLICY_LAYER,
        "ckpt": str(args.ckpt),
    }
    torch.save(cache, out_path)
    logger.info(f"[save] {out_path}  failed={failed}")
# Run the cache build only when executed as a script, not when imported.
if __name__ == "__main__":
    main()