File size: 6,854 Bytes
1e05592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python
"""Test SFT model generation quality.

Loads the trained VLAlert SFT checkpoint and generates responses
on a few val samples to check:
1. Does the model produce [Analysis] + [Safety Assessment] format?
2. Are <|BELIEF|> tokens present and meaningful?
3. Are action tokens correct relative to GT?
4. Is the reasoning diverse (not template-like)?

Usage:
    python tools/test_sft_generation.py --ckpt checkpoints/vlalert_sft_a/best
"""
import sys, json, torch, argparse
from pathlib import Path

# Base directory against which relative paths in this script are resolved.
# NOTE(review): "PROJECT_ROOT" is a literal relative path here — presumably a
# placeholder to be replaced with the real repository root; confirm.
ROOT = Path("PROJECT_ROOT")
# Put ROOT first on the import search path so imports below look there first.
sys.path.insert(0, str(ROOT))

# Conv3d patch
import torch.nn as nn
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
def _fast(self, hs):
    """Replacement forward for the vision patch embedding.

    On the first call where ``self.proj`` is still an ``nn.Conv3d``, the
    convolution is swapped for an ``nn.Linear`` whose weight is the conv
    kernel flattened to ``(out_channels, in_channels * kt * kh * kw)``.
    The input is then flattened to rows of ``in_features`` values when
    needed and pushed through the linear layer.

    Args:
        hs: Input tensor; reshaped to ``(-1, in_features)`` when it has more
            than two dimensions or its last dimension is not ``in_features``.

    Returns:
        The output of ``self.proj`` applied to ``hs`` cast to the projection
        weight's dtype.
    """
    # Captured before any swap; the replacement layer is moved to the same dtype.
    target_dtype = self.proj.weight.dtype
    conv = self.proj
    if isinstance(conv, nn.Conv3d):
        out_dim = conv.out_channels
        k_t, k_h, k_w = conv.kernel_size
        in_dim = conv.in_channels * k_t * k_h * k_w
        with_bias = conv.bias is not None

        linear = nn.Linear(in_dim, out_dim, bias=with_bias)
        flat_kernel = conv.weight.detach().reshape(out_dim, in_dim).contiguous()
        linear.weight.data.copy_(flat_kernel)
        if with_bias:
            linear.bias.data.copy_(conv.bias.detach())
        linear.to(device=conv.weight.device, dtype=conv.weight.dtype)
        # Cache the converted layer so later calls skip the conversion.
        self.proj = linear

    width = self.proj.in_features
    needs_flatten = hs.dim() > 2 or hs.shape[-1] != width
    if needs_flatten:
        hs = hs.reshape(-1, width)
    return self.proj(hs.to(dtype=target_dtype))
# Monkey-patch at import time: Qwen3VLVisionPatchEmbed.forward now resolves to _fast.
Qwen3VLVisionPatchEmbed.forward = _fast


def _resolve_under_root(path_str):
    """Return ``path_str`` unchanged if absolute, otherwise joined under ROOT."""
    return path_str if Path(path_str).is_absolute() else str(ROOT / path_str)


def _read_jsonl_lines(path):
    """Return the non-blank lines of a UTF-8 JSONL file.

    Splits on ``"\\n"`` only (not ``str.splitlines``) so that unusual line
    separators embedded inside JSON strings cannot cut a record in half.
    An empty file yields an empty list.
    """
    raw = Path(path).read_text(encoding="utf-8")
    return [ln for ln in raw.split("\n") if ln.strip()]


def _pad_token_id(tokenizer):
    """Return the tokenizer's pad id, falling back to the EOS id only when it is None.

    An explicit ``is None`` test is used because a pad id of 0 is valid and
    must not be treated as missing.
    """
    pad_id = tokenizer.pad_token_id
    return tokenizer.eos_token_id if pad_id is None else pad_id


def _analyze_generation(gen_text):
    """Count the expected format markers in one generated response.

    Returns:
        A dict with the section flags (``has_analysis``, ``has_assessment``),
        the marker counts (``n_belief_open``, ``n_belief_close``, ``n_silent``,
        ``n_observe``, ``n_alert``) and the three derived verdicts
        (``format_ok``, ``has_belief``, ``has_action``).
    """
    has_analysis = "[Analysis]" in gen_text
    has_assessment = "[Safety Assessment]" in gen_text
    n_belief_open = gen_text.count("<|BELIEF|>")
    n_belief_close = gen_text.count("</|BELIEF|>")
    n_silent = gen_text.count("<|SILENT|>")
    n_observe = gen_text.count("<|OBSERVE|>")
    n_alert = gen_text.count("<|ALERT|>")
    return {
        "has_analysis": has_analysis,
        "has_assessment": has_assessment,
        "n_belief_open": n_belief_open,
        "n_belief_close": n_belief_close,
        "n_silent": n_silent,
        "n_observe": n_observe,
        "n_alert": n_alert,
        "format_ok": has_analysis and has_assessment and n_belief_open >= 1,
        "has_belief": n_belief_open >= 1 and n_belief_close >= 1,
        "has_action": (n_silent + n_observe + n_alert) >= 1,
    }


def main():
    """Generate on a few validation samples and report output-format statistics.

    Parses the command line, loads the base model plus the PEFT adapter from
    ``--ckpt``, draws ``--n_samples`` records from ``--val_jsonl`` with a fixed
    seed, generates greedily for each, and prints per-sample marker counts
    followed by an aggregate summary. Samples whose frames fail to load are
    reported and left out of the totals.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", default="checkpoints/vlalert_sft_a/best")
    ap.add_argument("--val_jsonl", default="data/cot_corpus_v3/v6_stage_a_val.jsonl")
    ap.add_argument("--n_samples", type=int, default=5)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    args = ap.parse_args()

    ckpt = _resolve_under_root(args.ckpt)
    val_jsonl = _resolve_under_root(args.val_jsonl)

    # Fall back to CPU instead of crashing on hosts without a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    print(f"Loading model from {ckpt}...")
    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel

    base_model = str(ROOT / "models/Qwen3-VL-4B-Instruct")
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

    model = AutoModelForImageTextToText.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, trust_remote_code=True)
    # Match the embedding table to the checkpoint tokenizer's vocabulary size
    # before the adapter weights are attached.
    model.resize_token_embeddings(len(processor.tokenizer))
    model = PeftModel.from_pretrained(model, ckpt).to(device)
    model.eval()
    if device == "cuda":
        print(f"Model loaded. GPU: {torch.cuda.memory_allocated()//1024**2}MB")
    else:
        print("Model loaded on CPU (CUDA not available).")

    # Load frames helper
    from training.VLA.train_vlalert_sft_v3 import load_frames, SYSTEM_PROMPT_V3, user_prompt_v3

    # Load val samples; the fixed seed keeps the drawn subset reproducible.
    lines = _read_jsonl_lines(val_jsonl)
    import random
    random.seed(42)
    n_draw = max(0, min(args.n_samples, len(lines)))
    samples = random.sample(lines, n_draw)

    print(f"\n{'='*80}")
    print(f"  Testing {len(samples)} val samples")
    print(f"{'='*80}")

    results = {"format_ok": 0, "has_belief": 0, "has_action": 0, "total": 0}
    pad_id = _pad_token_id(processor.tokenizer)

    for i, line in enumerate(samples):
        rec = json.loads(line)
        vid = rec["video_id"]
        src = rec["source"]
        gt_actions = rec["actions_per_frame"]
        n_frames = rec.get("n_frames", 8)

        print(f"\n--- Sample {i+1}: {vid} ({src}) ---")
        print(f"GT actions: {gt_actions}")

        # Load frames; a failure skips this sample rather than aborting the run.
        try:
            frames = load_frames(rec["video_path"], rec["frame_indices"], resize_short=336)
        except Exception as e:
            print(f"  [SKIP] frame load error: {e}")
            continue

        # Build prompt (without assistant)
        user_content = [{"type": "image", "image": img} for img in frames]
        user_content.append({"type": "text", "text": user_prompt_v3(n_frames)})
        msgs = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT_V3}]},
            {"role": "user", "content": user_content},
        ]
        text = processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
        inputs = processor(text=[text], images=[frames], return_tensors="pt",
                           padding=True).to(device)

        # Greedy decoding. temperature stays at its neutral value of 1.0;
        # presumably this overrides any sampling temperature stored in the
        # checkpoint's generation config — confirm before removing.
        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                pad_token_id=pad_id,
            )

        # Drop the prompt tokens; keep special tokens so the markers survive decoding.
        prefix_len = inputs["input_ids"].shape[1]
        gen_text = processor.tokenizer.decode(gen[0, prefix_len:], skip_special_tokens=False)

        # Analyze output
        stats = _analyze_generation(gen_text)

        results["total"] += 1
        for key in ("format_ok", "has_belief", "has_action"):
            if stats[key]:
                results[key] += 1

        print(f"  Format: [Analysis]={'✓' if stats['has_analysis'] else '✗'} "
              f"[Safety Assessment]={'✓' if stats['has_assessment'] else '✗'}")
        print(f"  Belief tokens: {stats['n_belief_open']} open, {stats['n_belief_close']} close")
        print(f"  Action tokens: S={stats['n_silent']} O={stats['n_observe']} A={stats['n_alert']}")
        print(f"  --- Generated text (first 500 chars) ---")
        print(f"  {gen_text[:500]}")
        print(f"  --- End ---")

    # Summary
    t = results["total"]
    print(f"\n{'='*80}")
    print(f"  SUMMARY ({t} samples)")
    print(f"{'='*80}")
    print(f"  Format OK ([Analysis]+[Assessment]+belief): {results['format_ok']}/{t}")
    print(f"  Has belief tokens: {results['has_belief']}/{t}")
    print(f"  Has action tokens: {results['has_action']}/{t}")
    if t > 0:
        # Unweighted mean of the three per-sample pass rates.
        score = (results['format_ok'] + results['has_belief'] + results['has_action']) / (3 * t)
        print(f"  Overall quality score: {score:.1%}")
        if score >= 0.8:
            print(f"  → GOOD: Model learned the format well")
        elif score >= 0.5:
            print(f"  → PARTIAL: Format partially learned, may need more training")
        else:
            print(f"  → POOR: Model didn't learn the format")

# Run the generation check only when executed as a script, not on import.
if __name__ == "__main__":
    main()