# VLAlert / tools/test_sft_generation.py
# AsianPlayer's picture
# Add VLAlert code
# 1e05592 verified
# Raw
# History Blame Contribute Delete
# 6.85 kB
#!/usr/bin/env python
"""Test SFT model generation quality.
Loads the trained VLAlert SFT checkpoint and generates responses
on a few val samples to check:
1. Does the model produce [Analysis] + [Safety Assessment] format?
2. Are <|BELIEF|> tokens present and meaningful?
3. Are action tokens correct relative to GT?
4. Is the reasoning diverse (not template-like)?
Usage:
python tools/test_sft_generation.py --ckpt checkpoints/vlalert_sft_a/best
"""
import sys, json, torch, argparse
from pathlib import Path
# Project root: relative --ckpt / --val_jsonl paths and the base-model
# directory in main() are resolved against it, and it is put first on
# sys.path (presumably so the project-local `training` package imports).
# NOTE(review): "PROJECT_ROOT" looks like a placeholder — point it at the
# real checkout before running.
ROOT = Path("PROJECT_ROOT")
sys.path[0:0] = [str(ROOT)]
# Speed patch: run the vision patch embedding's Conv3d as an nn.Linear over flattened patches (see _fast below).
import torch.nn as nn
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
def _fast(self, hs):
dt = self.proj.weight.dtype
if isinstance(self.proj, nn.Conv3d):
c = self.proj; od = c.out_channels
ind = c.in_channels * c.kernel_size[0] * c.kernel_size[1] * c.kernel_size[2]
w = c.weight.detach().reshape(od, ind).contiguous()
b = c.bias.detach().clone() if c.bias is not None else None
np_l = nn.Linear(ind, od, bias=b is not None)
np_l.weight.data.copy_(w)
if b is not None: np_l.bias.data.copy_(b)
np_l.to(device=c.weight.device, dtype=c.weight.dtype)
self.proj = np_l
if hs.dim() > 2 or hs.shape[-1] != self.proj.in_features:
hs = hs.reshape(-1, self.proj.in_features)
return self.proj(hs.to(dtype=dt))
# Class-level monkeypatch: every Qwen3VLVisionPatchEmbed instance now uses the
# Linear-based _fast forward instead of the stock Conv3d one.
setattr(Qwen3VLVisionPatchEmbed, "forward", _fast)
def _resolve_path(path_str):
    """Return *path_str* as-is if absolute, otherwise resolved against ROOT."""
    return path_str if Path(path_str).is_absolute() else str(ROOT / path_str)


def _analyze_generation(gen_text):
    """Count format markers and special tokens in one decoded generation.

    Returns a dict with the per-marker counts, the predicted action tokens in
    generation order, and the three pass/fail flags aggregated in the summary.
    """
    import re

    has_analysis = "[Analysis]" in gen_text
    has_assessment = "[Safety Assessment]" in gen_text
    n_belief_open = gen_text.count("<|BELIEF|>")
    n_belief_close = gen_text.count("</|BELIEF|>")
    n_silent = gen_text.count("<|SILENT|>")
    n_observe = gen_text.count("<|OBSERVE|>")
    n_alert = gen_text.count("<|ALERT|>")
    return {
        "has_analysis": has_analysis,
        "has_assessment": has_assessment,
        "n_belief_open": n_belief_open,
        "n_belief_close": n_belief_close,
        "n_silent": n_silent,
        "n_observe": n_observe,
        "n_alert": n_alert,
        # Action tokens in the order they were generated, so they can be
        # compared against the GT sequence by eye.
        "pred_actions": re.findall(r"<\|(SILENT|OBSERVE|ALERT)\|>", gen_text),
        "format_ok": has_analysis and has_assessment and n_belief_open >= 1,
        "has_belief": n_belief_open >= 1 and n_belief_close >= 1,
        "has_action": (n_silent + n_observe + n_alert) >= 1,
    }


def _print_summary(results, n_skipped=0):
    """Print aggregate pass counts and a coarse quality verdict."""
    t = results["total"]
    print(f"\n{'='*80}")
    print(f" SUMMARY ({t} samples)")
    print(f"{'='*80}")
    if n_skipped:
        print(f" Skipped (frame load error): {n_skipped}")
    print(f" Format OK ([Analysis]+[Assessment]+belief): {results['format_ok']}/{t}")
    print(f" Has belief tokens: {results['has_belief']}/{t}")
    print(f" Has action tokens: {results['has_action']}/{t}")
    if t == 0:
        return
    # Mean of the three per-sample pass rates.
    score = (results['format_ok'] + results['has_belief'] + results['has_action']) / (3 * t)
    print(f" Overall quality score: {score:.1%}")
    if score >= 0.8:
        print(" → GOOD: Model learned the format well")
    elif score >= 0.5:
        print(" → PARTIAL: Format partially learned, may need more training")
    else:
        print(" → POOR: Model didn't learn the format")


def main():
    """Generate on a few sampled val records with the SFT checkpoint and report format stats."""
    import random

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/vlalert_sft_a/best")
    ap.add_argument("--val_jsonl", default="data/cot_corpus_v3/v6_stage_a_val.jsonl")
    ap.add_argument("--n_samples", type=int, default=5)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    args = ap.parse_args()
    ckpt = _resolve_path(args.ckpt)
    val_jsonl = _resolve_path(args.val_jsonl)
    device = "cuda"

    # Load model: base weights + LoRA adapter from the checkpoint.
    print(f"Loading model from {ckpt}...")
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel
    base_model = str(ROOT / "models/Qwen3-VL-4B-Instruct")
    # The processor/tokenizer comes from the checkpoint (presumably carrying the
    # added special tokens), so the base embeddings are resized to its vocab
    # size before the adapter is attached.
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
    tokenizer = processor.tokenizer
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
    model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(model, ckpt).to(device)
    model.eval()
    print(f"Model loaded. GPU: {torch.cuda.memory_allocated()//1024**2}MB")

    from training.VLA.train_vlalert_sft_v3 import load_frames, SYSTEM_PROMPT_V3, user_prompt_v3

    # Split on "\n" only (str.splitlines would also split on U+2028 etc., which
    # may legally appear inside JSON strings) and skip blank lines.
    text_all = Path(val_jsonl).read_text(encoding="utf-8")
    lines = [ln for ln in text_all.split("\n") if ln.strip()]
    # Private RNG: same draw as random.seed(42); random.sample(...) without
    # touching the global random state.
    samples = random.Random(42).sample(lines, min(args.n_samples, len(lines)))
    print(f"\n{'='*80}")
    print(f" Testing {len(samples)} val samples")
    print(f"{'='*80}")

    # `or` would wrongly fall back to EOS for a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    results = {"format_ok": 0, "has_belief": 0, "has_action": 0, "total": 0}
    n_skipped = 0
    for i, line in enumerate(samples):
        rec = json.loads(line)
        vid = rec["video_id"]
        src = rec["source"]
        gt_actions = rec["actions_per_frame"]
        n_frames = rec.get("n_frames", 8)
        print(f"\n--- Sample {i+1}: {vid} ({src}) ---")
        print(f"GT actions: {gt_actions}")

        try:
            frames = load_frames(rec["video_path"], rec["frame_indices"], resize_short=336)
        except Exception as e:  # best-effort: report, count, and move on
            print(f" [SKIP] frame load error: {e}")
            n_skipped += 1
            continue

        # Build the prompt (system + user with images); the assistant turn is generated.
        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": user_prompt_v3(n_frames)})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_V3}]},
            {"role": "user", "content": user_content},
        ]
        text = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=[text], images=[frames], return_tensors="pt",
                           padding=True).to(device)

        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                pad_token_id=pad_id,
            )
        prefix_len = inputs["input_ids"].shape[1]
        # Keep special tokens: the checks below count them.
        gen_text = tokenizer.decode(gen[0, prefix_len:], skip_special_tokens=False)

        stats = _analyze_generation(gen_text)
        results["total"] += 1
        for key in ("format_ok", "has_belief", "has_action"):
            results[key] += int(stats[key])

        print(f" Format: [Analysis]={'✓' if stats['has_analysis'] else '✗'} "
              f"[Safety Assessment]={'✓' if stats['has_assessment'] else '✗'}")
        print(f" Belief tokens: {stats['n_belief_open']} open, {stats['n_belief_close']} close")
        print(f" Action tokens: S={stats['n_silent']} O={stats['n_observe']} A={stats['n_alert']}")
        # NOTE(review): GT action encoding is not visible here; compare by eye.
        print(f" Pred actions: {stats['pred_actions']}")
        print(" --- Generated text (first 500 chars) ---")
        print(f" {gen_text[:500]}")
        print(" --- End ---")

    _print_summary(results, n_skipped)


if __name__ == "__main__":
    main()