| |
| """Test SFT model generation quality. |
| |
| Loads the trained VLAlert SFT checkpoint and generates responses |
| on a few val samples to check: |
| 1. Does the model produce [Analysis] + [Safety Assessment] format? |
| 2. Are <|BELIEF|> tokens present and meaningful? |
| 3. Are action tokens present? (GT actions are printed for manual comparison.) |
| 4. Is the reasoning diverse (not template-like)? (Judged by reading the printed text.) |
| |
| Usage: |
| python tools/test_sft_generation.py --ckpt checkpoints/vlalert_sft_a/best |
| """ |
| import sys, json, torch, argparse |
| from pathlib import Path |
|
|
# Repository root. The literal "PROJECT_ROOT" looks like a placeholder for the
# real checkout path — TODO confirm. It is prepended to sys.path so that
# project-local packages (e.g. `training`) resolve ahead of installed ones.
ROOT = Path("PROJECT_ROOT")
sys.path[:0] = [str(ROOT)]
|
|
| |
| import torch.nn as nn |
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed |
| def _fast(self, hs): |
| dt = self.proj.weight.dtype |
| if isinstance(self.proj, nn.Conv3d): |
| c = self.proj; od = c.out_channels |
| ind = c.in_channels * c.kernel_size[0] * c.kernel_size[1] * c.kernel_size[2] |
| w = c.weight.detach().reshape(od, ind).contiguous() |
| b = c.bias.detach().clone() if c.bias is not None else None |
| np_l = nn.Linear(ind, od, bias=b is not None) |
| np_l.weight.data.copy_(w) |
| if b is not None: np_l.bias.data.copy_(b) |
| np_l.to(device=c.weight.device, dtype=c.weight.dtype) |
| self.proj = np_l |
| if hs.dim() > 2 or hs.shape[-1] != self.proj.in_features: |
| hs = hs.reshape(-1, self.proj.in_features) |
| return self.proj(hs.to(dtype=dt)) |
# Monkeypatch at the class level: every Qwen3VLVisionPatchEmbed instance created
# or used after this point runs `_fast` as its forward.
setattr(Qwen3VLVisionPatchEmbed, "forward", _fast)
|
|
|
|
def _resolve_path(p):
    """Return *p* unchanged when absolute, otherwise as a string path under ``ROOT``."""
    return p if Path(p).is_absolute() else str(ROOT / p)


def _read_jsonl_lines(path):
    """Return the non-blank lines of the UTF-8 JSONL file at *path*.

    Skipping blank lines means an empty file yields ``[]`` rather than
    ``[""]``, and stray blank lines never reach ``json.loads``.
    """
    text = Path(path).read_text(encoding="utf-8")
    return [ln for ln in text.splitlines() if ln.strip()]


def _inspect_generation(gen_text):
    """Count the expected format markers in *gen_text* and derive pass flags.

    Returns a dict containing the raw marker counts plus three booleans:
    ``format_ok`` ([Analysis] + [Safety Assessment] + at least one opening
    belief tag), ``has_belief`` (at least one opening and one closing belief
    tag) and ``has_action`` (at least one SILENT/OBSERVE/ALERT token).
    """
    has_analysis = "[Analysis]" in gen_text
    has_assessment = "[Safety Assessment]" in gen_text
    n_belief_open = gen_text.count("<|BELIEF|>")
    n_belief_close = gen_text.count("</|BELIEF|>")
    n_silent = gen_text.count("<|SILENT|>")
    n_observe = gen_text.count("<|OBSERVE|>")
    n_alert = gen_text.count("<|ALERT|>")
    return {
        "has_analysis": has_analysis,
        "has_assessment": has_assessment,
        "n_belief_open": n_belief_open,
        "n_belief_close": n_belief_close,
        "n_silent": n_silent,
        "n_observe": n_observe,
        "n_alert": n_alert,
        "format_ok": has_analysis and has_assessment and n_belief_open >= 1,
        "has_belief": n_belief_open >= 1 and n_belief_close >= 1,
        "has_action": (n_silent + n_observe + n_alert) >= 1,
    }


def _print_summary(results, n_skipped=0):
    """Print the aggregate pass counts, skipped count and a coarse quality verdict."""
    t = results["total"]
    print(f"\n{'='*80}")
    print(f" SUMMARY ({t} samples)")
    print(f"{'='*80}")
    print(f" Format OK ([Analysis]+[Assessment]+belief): {results['format_ok']}/{t}")
    print(f" Has belief tokens: {results['has_belief']}/{t}")
    print(f" Has action tokens: {results['has_action']}/{t}")
    if n_skipped:
        # Samples whose frames failed to load are excluded from `t`; say so.
        print(f" Skipped (frame load error): {n_skipped}")
    if t > 0:
        score = (results['format_ok'] + results['has_belief'] + results['has_action']) / (3 * t)
        print(f" Overall quality score: {score:.1%}")
        if score >= 0.8:
            print(f" → GOOD: Model learned the format well")
        elif score >= 0.5:
            print(f" → PARTIAL: Format partially learned, may need more training")
        else:
            print(f" → POOR: Model didn't learn the format")


def main():
    """Generate on a few validation samples and report how well the SFT format was learned.

    CLI flags:
        --ckpt: directory that both ``AutoProcessor`` and ``PeftModel`` load from.
        --val_jsonl: validation JSONL; one record per line.
        --n_samples: number of records drawn at random (clamped to the file size).
        --max_new_tokens: greedy-decoding length limit.
        --device: torch device for model and inputs (default ``cuda``).
        --seed: seed for picking samples (default ``42``).
    Relative paths are resolved against ``ROOT``.
    """
    import random

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/vlalert_sft_a/best")
    ap.add_argument("--val_jsonl", default="data/cot_corpus_v3/v6_stage_a_val.jsonl")
    ap.add_argument("--n_samples", type=int, default=5)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    ap.add_argument("--device", default="cuda",
                    help="torch device for the model and inputs")
    ap.add_argument("--seed", type=int, default=42,
                    help="seed used to pick validation samples")
    args = ap.parse_args()

    ckpt = _resolve_path(args.ckpt)
    val_jsonl = _resolve_path(args.val_jsonl)
    device = args.device

    print(f"Loading model from {ckpt}...")
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel

    base_model = str(ROOT / "models/Qwen3-VL-4B-Instruct")
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
    tokenizer = processor.tokenizer

    model = AutoModelForImageTextToText.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # The checkpoint's tokenizer may hold tokens the base vocab lacks (presumably
    # the added BELIEF/action tokens). Resize before attaching the adapter.
    model.resize_token_embeddings(len(tokenizer))
    model = PeftModel.from_pretrained(model, ckpt).to(device)
    model.eval()
    if str(device).startswith("cuda"):
        print(f"Model loaded. GPU: {torch.cuda.memory_allocated()//1024**2}MB")
    else:
        print("Model loaded.")

    from training.VLA.train_vlalert_sft_v3 import load_frames, SYSTEM_PROMPT_V3, user_prompt_v3

    lines = _read_jsonl_lines(val_jsonl)
    # A private RNG keeps the pick reproducible without reseeding the global
    # `random` module. Random(42).sample draws the same items as
    # random.seed(42) followed by random.sample.
    rng = random.Random(args.seed)
    samples = rng.sample(lines, max(0, min(args.n_samples, len(lines))))

    print(f"\n{'='*80}")
    print(f" Testing {len(samples)} val samples")
    print(f"{'='*80}")

    # `pad_token_id or eos` would wrongly discard a valid pad id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    results = {"format_ok": 0, "has_belief": 0, "has_action": 0, "total": 0}
    n_skipped = 0

    for i, line in enumerate(samples):
        rec = json.loads(line)
        vid = rec["video_id"]
        src = rec["source"]
        gt_actions = rec["actions_per_frame"]
        n_frames = rec.get("n_frames", 8)

        print(f"\n--- Sample {i+1}: {vid} ({src}) ---")
        print(f"GT actions: {gt_actions}")

        try:
            frames = load_frames(rec["video_path"], rec["frame_indices"], resize_short=336)
        except Exception as e:
            # Best-effort: one unreadable video should not abort the whole check.
            print(f" [SKIP] frame load error: {e}")
            n_skipped += 1
            continue

        # The prompt announces `n_frames`. Flag records where that disagrees with
        # the images actually supplied (frames is iterated twice below, so it is
        # presumably a sized sequence).
        if n_frames != len(frames):
            print(f" [WARN] n_frames={n_frames} but {len(frames)} frames were loaded")

        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": user_prompt_v3(n_frames)})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_V3}]},
            {"role": "user", "content": user_content},
        ]
        text = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=[text], images=[frames], return_tensors="pt",
                           padding=True).to(device)

        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                # Ignored under greedy decoding. Presumably passed to override a
                # non-default temperature in the model's generation_config —
                # TODO confirm.
                temperature=1.0,
                pad_token_id=pad_token_id,
            )

        # generate() returns prompt + continuation; decode only the new tokens.
        prefix_len = inputs["input_ids"].shape[1]
        gen_text = tokenizer.decode(gen[0, prefix_len:], skip_special_tokens=False)

        chk = _inspect_generation(gen_text)
        results["total"] += 1
        for key in ("format_ok", "has_belief", "has_action"):
            if chk[key]:
                results[key] += 1

        has_analysis = chk["has_analysis"]
        has_assessment = chk["has_assessment"]
        print(f" Format: [Analysis]={'✓' if has_analysis else '✗'} "
              f"[Safety Assessment]={'✓' if has_assessment else '✗'}")
        print(f" Belief tokens: {chk['n_belief_open']} open, {chk['n_belief_close']} close")
        print(f" Action tokens: S={chk['n_silent']} O={chk['n_observe']} A={chk['n_alert']}")
        print(f" --- Generated text (first 500 chars) ---")
        print(f" {gen_text[:500]}")
        print(f" --- End ---")

    _print_summary(results, n_skipped)


if __name__ == "__main__":
    main()
|
|