#!/usr/bin/env python3 """ Quick eval: the v3 PolicyHead trained with --merge_observe alert (checkpoints/Policy/policy_binary_obs2alert), evaluated on: • all val • Nexar-only val • DADA-only val The binary head was trained on the FULL dataset with OBSERVE labels remapped to ALERT. Architecturally it is still the v3 3-class PolicyHead (output [B,3]), but effectively class 1 never appears in its training targets. """ from __future__ import annotations import argparse import json from collections import Counter from pathlib import Path import numpy as np import torch import torch.nn.functional as F from sklearn.metrics import average_precision_score from torch.utils.data import DataLoader from tqdm import tqdm import sys sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from lkalert.models.components import PolicyHead from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn def safe_ap(y_true, y_score): n_pos = int(np.sum(y_true)) if n_pos == 0 or n_pos == len(y_true): return None return float(average_precision_score(y_true, y_score)) def main(): ap = argparse.ArgumentParser() ap.add_argument("--ckpt", default="checkpoints/Policy/policy_binary_obs2alert/best") ap.add_argument("--label_dir", default="data/policy_labels") ap.add_argument("--cache_dir", default="data/belief_cache") ap.add_argument("--batch_size", type=int, default=512) args = ap.parse_args() ckpt_dir = Path(args.ckpt) val_ds = PolicyDataset( manifests=[Path(args.label_dir) / "val.json"], split="val", belief_cache_path=Path(args.cache_dir) / "val.pt", ) loader = DataLoader( val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=policy_collate_fn, num_workers=2, pin_memory=True, ) device = "cuda" if torch.cuda.is_available() else "cpu" hidden_dim = val_ds._cache["beliefs"].shape[-1] head = PolicyHead(hidden_dim=hidden_dim).to(device) head.load_state_dict(torch.load(ckpt_dir / "policy_head.pt", map_location=device)) head.eval() all_probs = [] prev_action = torch.zeros(1, dtype=torch.long, device=device) with torch.no_grad(): for batch in tqdm(loader, desc="Eval", ncols=80): belief = batch["beliefs"].to(device) tm = batch["tta_means"].to(device) tv = batch["tta_vars"].to(device) pa = torch.zeros(belief.size(0), dtype=torch.long, device=device) logits = head(belief, tm, tv, pa) all_probs.append(F.softmax(logits, dim=-1).cpu().numpy()) probs = np.concatenate(all_probs, axis=0) # [N, 3] labels = np.array([s["action_label"] for s in val_ds.samples], dtype=np.int64) sources = np.array([s.get("source", "?") for s in val_ds.samples], dtype=object) subsets = { "all": np.ones_like(labels, dtype=bool), "nexar": sources == "nexar", "dada": sources == "dada", } print("\n" + "═" * 100) print(f" policy_binary_obs2alert (v3 head, OBSERVE→ALERT at train time)") print(f" score: p_alert = softmax[:,2] (effective binary: SILENT vs ALERT∪OBSERVE)") print("═" * 100) print(f"{'subset':<8}{'n':>7} {'strict_AP':>10} {'merged_AP':>10} " f"{'observe_AP':>11} {'class_dist':<22}") print("─" * 100) out = {} for name, mask in subsets.items(): y = labels[mask] p = probs[mask] n = int(mask.sum()) if n == 0: continue strict = safe_ap((y == 2).astype(int), p[:, 2]) merged = safe_ap((y >= 1).astype(int), p[:, 1] + p[:, 2]) obs = safe_ap((y == 1).astype(int), p[:, 1]) cd = dict(sorted(Counter(int(v) for v in y).items())) fmt = lambda v: "— " if v is None else f"{v:.4f}" print(f"{name:<8}{n:>7} {fmt(strict):>10} {fmt(merged):>10} " f"{fmt(obs):>11} {str(cd):<22}") out[name] = {"n": n, "strict_ap": strict, "merged_ap": merged, "observe_ap": obs, "class_dist": cd} print("═" * 100) print(" ref — Nexar 2025 winner (MViT-V2-S + 3LC): strict_AP = 0.898 on nexar") print("═" * 100 + "\n") Path("eval_results").mkdir(exist_ok=True) Path("eval_results/binary_head_nexar.json").write_text(json.dumps(out, indent=2)) print("Saved -> eval_results/binary_head_nexar.json") if __name__ == "__main__": main()