| |
| """ |
| Quick eval: the v3 PolicyHead trained with --merge_observe alert |
| (checkpoints/Policy/policy_binary_obs2alert), evaluated on: |
| • all val |
| • Nexar-only val |
| • DADA-only val |
| |
| The binary head was trained on the FULL dataset with OBSERVE labels remapped |
| to ALERT. Architecturally it is still the v3 3-class PolicyHead (output [B,3]), |
| but effectively class 1 never appears in its training targets. |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import json |
| from collections import Counter |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from sklearn.metrics import average_precision_score |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) |
|
|
| from lkalert.models.components import PolicyHead |
| from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn |
|
|
|
|
def safe_ap(y_true, y_score):
    """Return average precision, or ``None`` when it is not meaningful.

    Args:
        y_true: Target labels; the positive count is taken as ``sum(y_true)``,
            so this assumes 0/1 labels (callers in this file pass 0/1 arrays).
        y_score: Scores for the positive class, aligned with ``y_true``.

    Returns:
        The average precision as a Python ``float``, or ``None`` when the
        positive count is 0 or equals ``len(y_true)`` (only one class present).
    """
    positives = int(np.sum(y_true))
    # Score only when the positive count is neither 0 nor the full length.
    if positives != 0 and positives != len(y_true):
        return float(average_precision_score(y_true, y_score))
    return None
|
|
|
|
def main():
    """Evaluate the ``policy_binary_obs2alert`` head on the validation split.

    Loads ``policy_head.pt`` from ``--ckpt``, runs the head over the cached
    validation beliefs, and reports three average-precision numbers for the
    whole split and for the ``nexar`` / ``dada`` sources:

    * ``strict_ap``  - label == 2 scored with ``softmax[:, 2]``
    * ``merged_ap``  - label >= 1 scored with ``softmax[:, 1] + softmax[:, 2]``
    * ``observe_ap`` - label == 1 scored with ``softmax[:, 1]``

    The table is printed to stdout and the same numbers are written as JSON
    to ``--out`` (default ``eval_results/binary_head_nexar.json``).

    Raises:
        ValueError: if the loader yields no batches, or if the number of
            predictions differs from ``len(val_ds.samples)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="checkpoints/Policy/policy_binary_obs2alert/best")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--cache_dir", default="data/belief_cache")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--out", default="eval_results/binary_head_nexar.json",
                        help="where to write the JSON summary")
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    ckpt_dir = Path(args.ckpt)
    val_ds = PolicyDataset(
        manifests=[Path(args.label_dir) / "val.json"],
        split="val",
        belief_cache_path=Path(args.cache_dir) / "val.pt",
    )
    # shuffle=False keeps predictions in the same order as val_ds.samples,
    # which the label/source masks below rely on.
    # Pinned memory only helps host->GPU copies, so skip it on CPU-only hosts.
    loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=policy_collate_fn, num_workers=2, pin_memory=use_cuda,
    )

    # The head's hidden size is read from the last dim of the cached beliefs.
    hidden_dim = val_ds._cache["beliefs"].shape[-1]
    head = PolicyHead(hidden_dim=hidden_dim).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; if the minimum
    # supported torch version allows it, consider weights_only=True.
    head.load_state_dict(torch.load(ckpt_dir / "policy_head.pt", map_location=device))
    head.eval()

    all_probs = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval", ncols=80):
            belief = batch["beliefs"].to(device)
            tm = batch["tta_means"].to(device)
            tv = batch["tta_vars"].to(device)
            # Previous-action input is fixed to 0 for every sample
            # (presumably SILENT - confirm against PolicyHead).
            pa = torch.zeros(belief.size(0), dtype=torch.long, device=device)
            logits = head(belief, tm, tv, pa)
            all_probs.append(F.softmax(logits, dim=-1).cpu().numpy())
    if not all_probs:
        raise ValueError("validation loader produced no batches; nothing to evaluate")
    probs = np.concatenate(all_probs, axis=0)

    labels = np.array([s["action_label"] for s in val_ds.samples], dtype=np.int64)
    sources = np.array([s.get("source", "?") for s in val_ds.samples], dtype=object)
    if len(probs) != len(labels):
        # The boolean masks below index both arrays, so they must line up.
        raise ValueError(
            f"got {len(probs)} predictions for {len(labels)} samples; "
            "dataset length and val_ds.samples are out of sync"
        )

    subsets = {
        "all": np.ones_like(labels, dtype=bool),
        "nexar": sources == "nexar",
        "dada": sources == "dada",
    }

    def format_ap(value):
        """Render an AP for the table; ``None`` (undefined AP) becomes a dash."""
        return "— " if value is None else f"{value:.4f}"

    print("\n" + "═" * 100)
    print(" policy_binary_obs2alert (v3 head, OBSERVE→ALERT at train time)")
    print(" score: p_alert = softmax[:,2] (effective binary: SILENT vs ALERT∪OBSERVE)")
    print("═" * 100)
    print(f"{'subset':<8}{'n':>7} {'strict_AP':>10} {'merged_AP':>10} "
          f"{'observe_AP':>11} {'class_dist':<22}")
    print("─" * 100)

    out = {}
    for name, mask in subsets.items():
        n = int(mask.sum())
        if n == 0:
            # Subsets with no samples are left out of both table and JSON.
            continue
        y = labels[mask]
        p = probs[mask]
        strict = safe_ap((y == 2).astype(int), p[:, 2])
        merged = safe_ap((y >= 1).astype(int), p[:, 1] + p[:, 2])
        obs = safe_ap((y == 1).astype(int), p[:, 1])
        # Label histogram with plain-int keys so it is JSON-serialisable.
        cd = dict(sorted(Counter(int(v) for v in y).items()))
        print(f"{name:<8}{n:>7} {format_ap(strict):>10} {format_ap(merged):>10} "
              f"{format_ap(obs):>11} {str(cd):<22}")
        out[name] = {"n": n, "strict_ap": strict, "merged_ap": merged,
                     "observe_ap": obs, "class_dist": cd}

    print("═" * 100)
    print(" ref — Nexar 2025 winner (MViT-V2-S + 3LC): strict_AP = 0.898 on nexar")
    print("═" * 100 + "\n")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2))
    print(f"Saved -> {args.out}")
|
|
|
|
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|