#!/usr/bin/env python
"""Test SFT model generation quality.

Loads the trained VLAlert SFT checkpoint and generates responses on a few
val samples to check:
  1. Does the model produce [Analysis] + [Safety Assessment] format?
  2. Are <|BELIEF|> tokens present and meaningful?
  3. Are action tokens correct relative to GT?
  4. Is the reasoning diverse (not template-like)?

Checks 1 and 2 (plus "at least one action token present") are scored
automatically. The GT actions and the first 500 characters of every
generation are printed so that 3 and 4 can be judged by eye.

Usage:
    python tools/test_sft_generation.py --ckpt checkpoints/vlalert_sft_a/best
"""
import argparse
import json
import random
import sys
from pathlib import Path

import torch
import torch.nn as nn

ROOT = Path("PROJECT_ROOT")
sys.path.insert(0, str(ROOT))

from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed  # noqa: E402


# Conv3d patch
def _fast(self, hs):
    """Patched ``Qwen3VLVisionPatchEmbed.forward`` that runs the projection as a Linear.

    On the first call, if ``self.proj`` is an ``nn.Conv3d``, it is replaced in
    place by an ``nn.Linear`` whose weight is the conv kernel reshaped to
    ``(out_channels, in_channels * kT * kH * kW)``, with the bias copied and the
    module placed on the conv's device/dtype. The input is then flattened to
    ``(-1, in_features)`` when needed, cast to the projection dtype, and projected.

    NOTE(review): this equals the Conv3d only if every input row is exactly one
    kernel-sized patch flattened in the conv weight's (C, T, H, W) order, which
    is presumably how the Qwen3-VL processor patchifies pixels. Confirm.
    """
    dt = self.proj.weight.dtype
    if isinstance(self.proj, nn.Conv3d):
        c = self.proj
        od = c.out_channels
        ind = c.in_channels * c.kernel_size[0] * c.kernel_size[1] * c.kernel_size[2]
        w = c.weight.detach().reshape(od, ind).contiguous()
        b = c.bias.detach().clone() if c.bias is not None else None
        np_l = nn.Linear(ind, od, bias=b is not None)
        np_l.weight.data.copy_(w)
        if b is not None:
            np_l.bias.data.copy_(b)
        # Swap lazily so the Linear lands on whatever device/dtype the conv has now.
        np_l.to(device=c.weight.device, dtype=c.weight.dtype)
        self.proj = np_l
    if hs.dim() > 2 or hs.shape[-1] != self.proj.in_features:
        hs = hs.reshape(-1, self.proj.in_features)
    return self.proj(hs.to(dtype=dt))


Qwen3VLVisionPatchEmbed.forward = _fast


# Special tokens counted (as substrings) in the decoded generation.
BELIEF_OPEN = "<|BELIEF|>"
SILENT_TOKEN = "<|SILENT|>"
OBSERVE_TOKEN = "<|OBSERVE|>"
ALERT_TOKEN = "<|ALERT|>"
# NOTE(review): the closing belief-tag literal was empty in the previous version
# of this script (``gen_text.count("")`` returns len+1, so the close check always
# passed). This default is a best guess. Confirm it against the special tokens
# the SFT tokenizer adds, or pass --belief_close explicitly.
DEFAULT_BELIEF_CLOSE = "</BELIEF>"


def _resolve(path):
    """Return *path* unchanged if it is absolute, else ``ROOT / path`` as a str."""
    return path if Path(path).is_absolute() else str(ROOT / path)


def _analyze_output(gen_text, belief_close):
    """Return format flags and special-token counts for one decoded generation.

    Args:
        gen_text: decoded generation, with special tokens kept.
        belief_close: closing belief tag. It must be non-empty because
            ``str.count("")`` returns ``len(s) + 1`` and would always pass.

    Returns:
        dict with boolean ``has_analysis`` / ``has_assessment``, integer
        ``n_belief_open`` / ``n_belief_close`` / ``n_silent`` / ``n_observe`` /
        ``n_alert``, and the derived booleans ``format_ok``, ``has_belief``,
        ``has_action``.

    Raises:
        ValueError: if *belief_close* is empty.
    """
    if not belief_close:
        raise ValueError("belief_close must be a non-empty string")
    has_analysis = "[Analysis]" in gen_text
    has_assessment = "[Safety Assessment]" in gen_text
    n_belief_open = gen_text.count(BELIEF_OPEN)
    n_belief_close = gen_text.count(belief_close)
    n_silent = gen_text.count(SILENT_TOKEN)
    n_observe = gen_text.count(OBSERVE_TOKEN)
    n_alert = gen_text.count(ALERT_TOKEN)
    return {
        "has_analysis": has_analysis,
        "has_assessment": has_assessment,
        "n_belief_open": n_belief_open,
        "n_belief_close": n_belief_close,
        "n_silent": n_silent,
        "n_observe": n_observe,
        "n_alert": n_alert,
        "format_ok": has_analysis and has_assessment and n_belief_open >= 1,
        "has_belief": n_belief_open >= 1 and n_belief_close >= 1,
        "has_action": (n_silent + n_observe + n_alert) >= 1,
    }


def main():
    """Load the SFT checkpoint, generate on sampled val clips, and print format stats."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/vlalert_sft_a/best")
    ap.add_argument("--val_jsonl", default="data/cot_corpus_v3/v6_stage_a_val.jsonl")
    ap.add_argument("--n_samples", type=int, default=5)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    ap.add_argument("--belief_close", default=DEFAULT_BELIEF_CLOSE,
                    help="closing belief tag counted in generations (non-empty)")
    args = ap.parse_args()
    if not args.belief_close:
        ap.error("--belief_close must be a non-empty string")

    ckpt = _resolve(args.ckpt)
    val_jsonl = _resolve(args.val_jsonl)
    device = "cuda"

    # Load model (heavy imports deferred until after argument parsing)
    print(f"Loading model from {ckpt}...")
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel

    base_model = str(ROOT / "models/Qwen3-VL-4B-Instruct")
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
    tokenizer = processor.tokenizer
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # The checkpoint's tokenizer may have a different vocab size from the base
    # model, so resize the embeddings before attaching the adapter.
    model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(model, ckpt).to(device)
    model.eval()
    print(f"Model loaded. GPU: {torch.cuda.memory_allocated()//1024**2}MB")

    # Warn early if a counted token is unknown to the tokenizer (likely a typo).
    vocab = tokenizer.get_vocab()
    for tok_str in (BELIEF_OPEN, args.belief_close, SILENT_TOKEN, OBSERVE_TOKEN, ALERT_TOKEN):
        if tok_str not in vocab:
            print(f"  [WARN] {tok_str!r} is not in the tokenizer vocab; check its spelling")

    # Explicit None check: a pad id of 0 is valid and must not fall back to EOS.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Load frames helper
    from training.VLA.train_vlalert_sft_v3 import load_frames, SYSTEM_PROMPT_V3, user_prompt_v3

    # Load val samples (skip blank lines; a local RNG keeps sampling reproducible)
    lines = [ln for ln in Path(val_jsonl).read_text(encoding="utf-8").splitlines() if ln.strip()]
    rng = random.Random(42)
    samples = rng.sample(lines, min(args.n_samples, len(lines)))

    print(f"\n{'='*80}")
    print(f"  Testing {len(samples)} val samples")
    print(f"{'='*80}")

    results = {"format_ok": 0, "has_belief": 0, "has_action": 0, "total": 0}

    for i, line in enumerate(samples):
        rec = json.loads(line)
        vid = rec["video_id"]
        src = rec["source"]
        gt_actions = rec["actions_per_frame"]
        n_frames = rec.get("n_frames", 8)

        print(f"\n--- Sample {i+1}: {vid} ({src}) ---")
        print(f"GT actions: {gt_actions}")

        # Load frames (best-effort: a broken clip is reported and skipped)
        try:
            frames = load_frames(rec["video_path"], rec["frame_indices"], resize_short=336)
        except Exception as e:
            print(f"  [SKIP] frame load error: {e}")
            continue

        # Build prompt (without assistant)
        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": user_prompt_v3(n_frames)})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_V3}]},
            {"role": "user", "content": user_content},
        ]
        text = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=[text], images=[frames], return_tensors="pt", padding=True).to(device)

        # Generate (greedy)
        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                pad_token_id=pad_id,
            )

        # Keep only the newly generated tokens; keep special tokens so they can be counted.
        prefix_len = inputs["input_ids"].shape[1]
        gen_text = tokenizer.decode(gen[0, prefix_len:], skip_special_tokens=False)

        # Analyze output
        stats = _analyze_output(gen_text, args.belief_close)

        results["total"] += 1
        for key in ("format_ok", "has_belief", "has_action"):
            if stats[key]:
                results[key] += 1

        print(f"  Format: [Analysis]={'✓' if stats['has_analysis'] else '✗'} "
              f"[Safety Assessment]={'✓' if stats['has_assessment'] else '✗'}")
        print(f"  Belief tokens: {stats['n_belief_open']} open, {stats['n_belief_close']} close")
        print(f"  Action tokens: S={stats['n_silent']} O={stats['n_observe']} A={stats['n_alert']}")
        print("  --- Generated text (first 500 chars) ---")
        print(f"  {gen_text[:500]}")
        print("  --- End ---")

    # Summary
    t = results["total"]
    print(f"\n{'='*80}")
    print(f"  SUMMARY ({t} samples)")
    print(f"{'='*80}")
    print(f"  Format OK ([Analysis]+[Assessment]+belief): {results['format_ok']}/{t}")
    print(f"  Has belief tokens:  {results['has_belief']}/{t}")
    print(f"  Has action tokens:  {results['has_action']}/{t}")

    if t > 0:
        score = (results['format_ok'] + results['has_belief'] + results['has_action']) / (3 * t)
        print(f"  Overall quality score: {score:.1%}")
        if score >= 0.8:
            print("  → GOOD: Model learned the format well")
        elif score >= 0.5:
            print("  → PARTIAL: Format partially learned, may need more training")
        else:
            print("  → POOR: Model didn't learn the format")


if __name__ == "__main__":
    main()