File size: 4,461 Bytes
1e05592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
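"""
Quick eval of the v3 PolicyHead trained with ``--merge_observe alert``
(checkpoints/Policy/policy_binary_obs2alert) on three validation subsets:
  • all val
  • Nexar-only val
  • DADA-only val

The binary head was trained on the FULL dataset with OBSERVE labels remapped
to ALERT. Architecturally it is still the v3 3-class PolicyHead (output [B,3]),
but class 1 (OBSERVE) never appears in its training targets, so it behaves as
an effective SILENT-vs-ALERT classifier.

For each subset the script reports strict (ALERT-only), merged (OBSERVE∪ALERT)
and OBSERVE-only average precision, prints them as a table, and writes them
as JSON.
"""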

from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from tqdm import tqdm

import sys
# Prepend the directory two levels above this file's folder (presumably the
# repo root — confirm if the script is moved) to sys.path so the project-local
# `lkalert` / `training` imports below resolve when run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from lkalert.models.components import PolicyHead
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn


def safe_ap(y_true, y_score):
    """Return average precision, or ``None`` when it is degenerate.

    AP is only informative when ``y_true`` contains both positives and
    negatives; if every label is 0 (including empty input) or every label is
    1, ``None`` is returned instead of a number.

    Args:
        y_true: binary (0/1) ground-truth labels.
        y_score: scores for the positive class, aligned with ``y_true``.

    Returns:
        The AP as a Python ``float``, or ``None`` for a single-class input.
    """
    positives = int(np.sum(y_true))
    if positives == 0:
        return None  # no positive examples at all
    if positives == len(y_true):
        return None  # no negative examples at all
    return float(average_precision_score(y_true, y_score))


def _collect_probs(head, loader, device):
    """Run ``head`` over every batch of ``loader`` and return softmax probabilities.

    The previous-action input is fed as all zeros for every sample; this eval
    does not replay the policy's own past actions.

    Returns:
        np.ndarray of shape [N, C] in loader order, where C is the head's
        output width (3 for the v3 PolicyHead). An empty loader yields [0, 3].
    """
    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval", ncols=80):
            belief = batch["beliefs"].to(device)
            tta_means = batch["tta_means"].to(device)
            tta_vars = batch["tta_vars"].to(device)
            prev_action = torch.zeros(belief.size(0), dtype=torch.long, device=device)
            logits = head(belief, tta_means, tta_vars, prev_action)
            chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
    if not chunks:
        # np.concatenate([]) raises; keep the [N, 3] contract for empty splits.
        return np.empty((0, 3), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


def _fmt_ap(value):
    """Format an AP value for the table; ``None`` (undefined AP) renders as a dash."""
    return "—   " if value is None else f"{value:.4f}"


def main():
    """Evaluate the binary (OBSERVE→ALERT) policy head on val subsets.

    Loads PolicyHead weights from ``<ckpt>/policy_head.pt``, scores every val
    sample, and reports per-subset (all / nexar / dada) average precision:

    * strict_AP  — ALERT (label 2) vs rest, scored by p[:, 2]
    * merged_AP  — OBSERVE∪ALERT (label >= 1) vs SILENT, scored by p[:, 1] + p[:, 2]
    * observe_AP — OBSERVE (label 1) vs rest, scored by p[:, 1]

    An AP is ``None`` (printed as "—") when the subset has a single class for
    that task. Results are printed as a table and written as JSON to ``--out``.

    Raises:
        RuntimeError: if the number of predictions differs from the number of
            val samples (predictions are matched to labels by position).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="checkpoints/Policy/policy_binary_obs2alert/best")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--cache_dir", default="data/belief_cache")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--out", default="eval_results/binary_head_nexar.json",
                        help="path of the per-subset metrics JSON to write")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_dir = Path(args.ckpt)

    val_ds = PolicyDataset(
        manifests=[Path(args.label_dir) / "val.json"],
        split="val",
        belief_cache_path=Path(args.cache_dir) / "val.pt",
    )
    # shuffle=False is load-bearing: predictions are matched to val_ds.samples by position.
    loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=policy_collate_fn, num_workers=2,
        pin_memory=(device == "cuda"),  # pinning only helps host→GPU copies
    )

    hidden_dim = val_ds._cache["beliefs"].shape[-1]
    head = PolicyHead(hidden_dim=hidden_dim).to(device)
    head.load_state_dict(torch.load(ckpt_dir / "policy_head.pt", map_location=device))
    head.eval()

    probs = _collect_probs(head, loader, device)  # [N, 3]

    labels  = np.array([s["action_label"] for s in val_ds.samples], dtype=np.int64)
    sources = np.array([s.get("source", "?") for s in val_ds.samples], dtype=object)
    if len(probs) != len(labels):
        raise RuntimeError(
            f"got {len(probs)} predictions for {len(labels)} val samples; "
            "loader output must align 1:1 with val_ds.samples"
        )

    subsets = {
        "all":   np.ones_like(labels, dtype=bool),
        "nexar": sources == "nexar",
        "dada":  sources == "dada",
    }

    print("\n" + "═" * 100)
    print("  policy_binary_obs2alert (v3 head, OBSERVE→ALERT at train time)")
    print("  score: p_alert = softmax[:,2]   (effective binary: SILENT vs ALERT∪OBSERVE)")
    print("═" * 100)
    print(f"{'subset':<8}{'n':>7}  {'strict_AP':>10}  {'merged_AP':>10}  "
          f"{'observe_AP':>11}  {'class_dist':<22}")
    print("─" * 100)

    out = {}
    for name, mask in subsets.items():
        n = int(mask.sum())
        if n == 0:
            continue
        y = labels[mask]
        p = probs[mask]
        strict = safe_ap((y == 2).astype(int), p[:, 2])              # ALERT vs rest
        merged = safe_ap((y >= 1).astype(int), p[:, 1] + p[:, 2])    # OBSERVE∪ALERT vs SILENT
        obs    = safe_ap((y == 1).astype(int), p[:, 1])              # OBSERVE vs rest
        cd     = dict(sorted(Counter(int(v) for v in y).items()))
        print(f"{name:<8}{n:>7}  {_fmt_ap(strict):>10}  {_fmt_ap(merged):>10}  "
              f"{_fmt_ap(obs):>11}  {str(cd):<22}")
        out[name] = {"n": n, "strict_ap": strict, "merged_ap": merged,
                     "observe_ap": obs, "class_dist": cd}

    print("═" * 100)
    print("  ref — Nexar 2025 winner (MViT-V2-S + 3LC): strict_AP = 0.898 on nexar")
    print("═" * 100 + "\n")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2))
    print(f"Saved -> {out_path}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()