File size: 6,854 Bytes
1e05592 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | #!/usr/bin/env python
"""Test SFT model generation quality.
Loads the trained VLAlert SFT checkpoint and generates responses
on a few val samples to check:
1. Does the model produce [Analysis] + [Safety Assessment] format?
2. Are <|BELIEF|> tokens present and meaningful?
3. Are action tokens correct relative to GT?
4. Is the reasoning diverse (not template-like)?
Usage:
python tools/test_sft_generation.py --ckpt checkpoints/vlalert_sft_a/best
"""
import sys, json, torch, argparse
from pathlib import Path
# Base directory used to resolve relative --ckpt / --val_jsonl paths and the
# base-model folder. It is prepended to sys.path so that project-local packages
# (e.g. the `training.VLA...` import inside main) can be imported.
# NOTE(review): "PROJECT_ROOT" looks like a placeholder rather than a real
# path — confirm it is substituted before running.
ROOT = Path("PROJECT_ROOT")
sys.path.insert(0, str(ROOT))
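# Conv3d patch: swap the vision patch-embed Conv3d for an nn.Linear built from the same weights (see _fast below).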
import torch.nn as nn
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
def _fast(self, hs):
    """Replacement ``forward`` that applies the patch projection as a Linear.

    On a call where ``self.proj`` is still an ``nn.Conv3d``, the conv is
    replaced (permanently, on ``self``) by an ``nn.Linear`` whose weight is the
    conv kernel flattened to ``(out_channels, in_channels * kT * kH * kW)`` and
    whose bias is copied over when present. Later calls reuse that Linear.

    The input is flattened to rows of ``in_features`` when it is not already a
    2-D tensor of that width, cast to the projection weight's dtype, and passed
    through the Linear.

    NOTE(review): this matches the convolution only if every input row is
    exactly one kernel-sized patch (stride == kernel size, groups == 1) —
    presumably true for this patch embed; confirm against the model config.
    """
    # Read before any swap; the replacement layer is cast to this same dtype.
    target_dtype = self.proj.weight.dtype

    conv = self.proj
    if isinstance(conv, nn.Conv3d):
        k_t, k_h, k_w = conv.kernel_size
        out_features = conv.out_channels
        in_features = conv.in_channels * k_t * k_h * k_w

        has_bias = conv.bias is not None
        linear = nn.Linear(in_features, out_features, bias=has_bias)

        flat_weight = conv.weight.detach().reshape(out_features, in_features)
        linear.weight.data.copy_(flat_weight.contiguous())
        if has_bias:
            linear.bias.data.copy_(conv.bias.detach())

        # nn.Linear is created on the default device/dtype; move it to match.
        linear.to(device=conv.weight.device, dtype=conv.weight.dtype)
        self.proj = linear

    width = self.proj.in_features
    needs_flatten = hs.dim() > 2 or hs.shape[-1] != width
    if needs_flatten:
        hs = hs.reshape(-1, width)
    return self.proj(hs.to(dtype=target_dtype))
# Monkey-patch at import time: forward() on Qwen3VLVisionPatchEmbed now
# dispatches to _fast for every instance of the class.
Qwen3VLVisionPatchEmbed.forward = _fast
def _read_jsonl_lines(path):
    """Return the non-blank lines of the JSONL file at *path*.

    Blank lines (a trailing newline, an empty file, stray separators) are
    dropped so that every returned entry can be passed to ``json.loads``.
    Splitting is done on ``"\\n"`` only, so Unicode line separators that may
    appear inside JSON strings do not break a record apart.
    """
    text = Path(path).read_text(encoding="utf-8")
    return [ln for ln in text.split("\n") if ln.strip()]


def _analyze_generation(gen_text):
    """Count section headers and marker tokens in one generated response.

    Returns a dict with the boolean flags ``has_analysis``, ``has_assessment``,
    ``format_ok``, ``has_belief``, ``has_action`` and the integer counts
    ``n_belief_open``, ``n_belief_close``, ``n_silent``, ``n_observe``,
    ``n_alert``.
    """
    stats = {
        "has_analysis": "[Analysis]" in gen_text,
        "has_assessment": "[Safety Assessment]" in gen_text,
        "n_belief_open": gen_text.count("<|BELIEF|>"),
        "n_belief_close": gen_text.count("</|BELIEF|>"),
        "n_silent": gen_text.count("<|SILENT|>"),
        "n_observe": gen_text.count("<|OBSERVE|>"),
        "n_alert": gen_text.count("<|ALERT|>"),
    }
    n_actions = stats["n_silent"] + stats["n_observe"] + stats["n_alert"]
    # format_ok needs both headers plus at least one opening belief marker;
    # has_belief additionally requires a closing marker.
    stats["format_ok"] = (
        stats["has_analysis"]
        and stats["has_assessment"]
        and stats["n_belief_open"] >= 1
    )
    stats["has_belief"] = (
        stats["n_belief_open"] >= 1 and stats["n_belief_close"] >= 1
    )
    stats["has_action"] = n_actions >= 1
    return stats


def _print_summary(results):
    """Print the aggregate counts and the overall quality verdict."""
    t = results["total"]
    print(f"\n{'='*80}")
    print(f" SUMMARY ({t} samples)")
    print(f"{'='*80}")
    print(f" Format OK ([Analysis]+[Assessment]+belief): {results['format_ok']}/{t}")
    print(f" Has belief tokens: {results['has_belief']}/{t}")
    print(f" Has action tokens: {results['has_action']}/{t}")
    if t > 0:
        # Unweighted mean of the three per-sample pass rates.
        passed = results["format_ok"] + results["has_belief"] + results["has_action"]
        score = passed / (3 * t)
        print(f" Overall quality score: {score:.1%}")
        if score >= 0.8:
            print(" → GOOD: Model learned the format well")
        elif score >= 0.5:
            print(" → PARTIAL: Format partially learned, may need more training")
        else:
            print(" → POOR: Model didn't learn the format")


def main():
    """Generate on a few validation samples and report format statistics.

    Steps: parse CLI arguments, load the base model with the adapter found at
    ``--ckpt``, pick ``--n_samples`` records from ``--val_jsonl`` with a fixed
    random seed (42), run greedy generation for each, count the expected
    section headers / marker tokens in the output, and print a summary.

    Relative ``--ckpt`` and ``--val_jsonl`` paths are resolved against ``ROOT``.
    ``--device`` is optional; when omitted, CUDA is used if available and CPU
    otherwise. Samples whose frames cannot be loaded are skipped and are not
    counted in the summary total.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/vlalert_sft_a/best")
    ap.add_argument("--val_jsonl", default="data/cot_corpus_v3/v6_stage_a_val.jsonl")
    ap.add_argument("--n_samples", type=int, default=5)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    ap.add_argument("--device", default=None,
                    help="torch device; defaults to cuda when available, else cpu")
    args = ap.parse_args()

    ckpt = args.ckpt if Path(args.ckpt).is_absolute() else str(ROOT / args.ckpt)
    val_jsonl = (args.val_jsonl if Path(args.val_jsonl).is_absolute()
                 else str(ROOT / args.val_jsonl))
    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    print(f"Loading model from {ckpt}...")
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel
    base_model = str(ROOT / "models/Qwen3-VL-4B-Instruct")
    # The processor (and its tokenizer) comes from the checkpoint directory.
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # Size the embedding table to the checkpoint tokenizer before the adapter
    # weights are attached.
    model.resize_token_embeddings(len(processor.tokenizer))
    model = PeftModel.from_pretrained(model, ckpt).to(device)
    model.eval()
    if str(device).startswith("cuda"):
        print(f"Model loaded. GPU: {torch.cuda.memory_allocated()//1024**2}MB")
    else:
        print(f"Model loaded on {device}.")

    # Load frames helper
    from training.VLA.train_vlalert_sft_v3 import load_frames, SYSTEM_PROMPT_V3, user_prompt_v3

    # Load val samples (blank lines removed so json.loads never sees "")
    lines = _read_jsonl_lines(val_jsonl)
    import random
    random.seed(42)
    samples = random.sample(lines, min(args.n_samples, len(lines)))

    # Explicit None check: a pad id of 0 is valid and must not fall through
    # to the EOS id the way `pad or eos` would.
    tokenizer = processor.tokenizer
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    print(f"\n{'='*80}")
    print(f" Testing {len(samples)} val samples")
    print(f"{'='*80}")
    results = {"format_ok": 0, "has_belief": 0, "has_action": 0, "total": 0}

    for i, line in enumerate(samples):
        rec = json.loads(line)
        vid = rec["video_id"]
        src = rec["source"]
        gt_actions = rec["actions_per_frame"]
        n_frames = rec.get("n_frames", 8)
        print(f"\n--- Sample {i+1}: {vid} ({src}) ---")
        print(f"GT actions: {gt_actions}")

        # Load frames; a failure here is best-effort — report and move on.
        try:
            frames = load_frames(rec["video_path"], rec["frame_indices"], resize_short=336)
        except Exception as e:
            print(f" [SKIP] frame load error: {e}")
            continue

        # Build prompt (without assistant)
        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": user_prompt_v3(n_frames)})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_V3}]},
            {"role": "user", "content": user_content},
        ]
        text = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=[text], images=[frames], return_tensors="pt",
                           padding=True).to(device)

        # Generate (greedy)
        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                # NOTE(review): kept from the original; presumably overrides a
                # sampling temperature in the checkpoint's generation config.
                temperature=1.0,
                pad_token_id=pad_id,
            )

        # Drop the prompt tokens; decode without skipping special tokens so
        # the <|...|> markers counted below remain in the text.
        prefix_len = inputs["input_ids"].shape[1]
        gen_text = tokenizer.decode(gen[0, prefix_len:], skip_special_tokens=False)

        # Analyze output
        stats = _analyze_generation(gen_text)
        results["total"] += 1
        for key in ("format_ok", "has_belief", "has_action"):
            if stats[key]:
                results[key] += 1

        print(f" Format: [Analysis]={'✓' if stats['has_analysis'] else '✗'} "
              f"[Safety Assessment]={'✓' if stats['has_assessment'] else '✗'}")
        print(f" Belief tokens: {stats['n_belief_open']} open, {stats['n_belief_close']} close")
        print(f" Action tokens: S={stats['n_silent']} O={stats['n_observe']} A={stats['n_alert']}")
        print(" --- Generated text (first 500 chars) ---")
        print(f" {gen_text[:500]}")
        print(" --- End ---")

    # Summary
    _print_summary(results)
# Run the generation check only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|