VLAlert / training /Policy /eval_binary_head_nexar.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
4.46 kB
#!/usr/bin/env python3
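"""
Quick evaluation of the v3 PolicyHead trained with --merge_observe alert
(checkpoints/Policy/policy_binary_obs2alert) on three slices of the val split:
• all val
• Nexar-only val
• DADA-only val
The binary head was trained on the FULL dataset with OBSERVE labels remapped
to ALERT. Architecturally it is still the v3 3-class PolicyHead (output [B,3]),
but class 1 (OBSERVE) never appears in its training targets.
For each slice the script reports average precision for three tasks:
- strict: ALERT vs. rest, scored by p[:,2]
- merged: OBSERVE∪ALERT vs. SILENT, scored by p[:,1]+p[:,2]
- observe: OBSERVE vs. rest, scored by p[:,1]
The results are written to eval_results/binary_head_nexar.json.
"""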
from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from lkalert.models.components import PolicyHead
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
def safe_ap(y_true, y_score):
    """Average precision of ``y_score`` against binary ``y_true``, or None.

    AP is reported as None when it is undefined or uninformative:
    * no positive labels at all (this includes empty input), or
    * every label is positive.

    Otherwise the result of sklearn's ``average_precision_score`` is returned
    as a plain Python float (so it serialises cleanly to JSON).

    NOTE(review): positives are counted with ``np.sum``, so ``y_true`` is
    assumed to contain only 0/1 values.
    """
    positives = int(np.sum(y_true))
    # Both extremes make AP meaningless (undefined or trivially 1.0).
    if positives in (0, len(y_true)):
        return None
    score = average_precision_score(y_true, y_score)
    return float(score)
def _fmt_ap(value):
    """Format an AP value for the results table; ``"— "`` marks an undefined AP."""
    return "— " if value is None else f"{value:.4f}"


def _predict_probs(head, loader, device):
    """Run ``head`` over every batch of ``loader`` and return softmax probabilities.

    Rows are returned in loader order, stacked along axis 0.

    Raises:
        RuntimeError: if the loader yields no batches (np.concatenate would
            otherwise fail with an unhelpful message).
    """
    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval", ncols=80):
            belief = batch["beliefs"].to(device)
            tm = batch["tta_means"].to(device)
            tv = batch["tta_vars"].to(device)
            # The previous-action input is fixed to index 0 for every sample.
            # NOTE(review): the meaning of action 0 is defined by PolicyHead; confirm.
            pa = torch.zeros(belief.size(0), dtype=torch.long, device=device)
            logits = head(belief, tm, tv, pa)
            chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
    if not chunks:
        raise RuntimeError("validation loader yielded no batches; nothing to evaluate")
    return np.concatenate(chunks, axis=0)


def main():
    """Evaluate the merged-OBSERVE binary policy head on the val split.

    Loads ``policy_head.pt`` from ``--ckpt`` and runs it over the cached val
    beliefs. It then prints a per-subset (all / nexar / dada) table of
    strict, merged and observe AP plus the label distribution. The same numbers
    are written to ``eval_results/binary_head_nexar.json``.

    Raises:
        RuntimeError: if the loader yields no batches, or if the number of
            predictions differs from ``len(val_ds.samples)`` (labels are
            matched to predictions by position).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", default="checkpoints/Policy/policy_binary_obs2alert/best")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--cache_dir", default="data/belief_cache")
    parser.add_argument("--batch_size", type=int, default=512)
    args = parser.parse_args()
    ckpt_dir = Path(args.ckpt)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    val_ds = PolicyDataset(
        manifests=[Path(args.label_dir) / "val.json"],
        split="val",
        belief_cache_path=Path(args.cache_dir) / "val.pt",
    )
    # shuffle=False is required: predictions are matched to val_ds.samples by position.
    loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=policy_collate_fn, num_workers=2,
        # Pinned memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=(device == "cuda"),
    )

    # NOTE(review): reads PolicyDataset's private cache; assumes the last dim
    # of the cached beliefs is the head's hidden size.
    hidden_dim = val_ds._cache["beliefs"].shape[-1]
    head = PolicyHead(hidden_dim=hidden_dim).to(device)
    head.load_state_dict(torch.load(ckpt_dir / "policy_head.pt", map_location=device))
    head.eval()

    probs = _predict_probs(head, loader, device)  # [N, 3]

    labels = np.array([s["action_label"] for s in val_ds.samples], dtype=np.int64)
    sources = np.array([s.get("source", "?") for s in val_ds.samples], dtype=object)
    if probs.shape[0] != labels.shape[0]:
        raise RuntimeError(
            f"got {probs.shape[0]} predictions but {labels.shape[0]} samples; "
            "loader output must align 1:1 with val_ds.samples"
        )

    subsets = {
        "all": np.ones_like(labels, dtype=bool),
        "nexar": sources == "nexar",
        "dada": sources == "dada",
    }

    print("\n" + "═" * 100)
    print(" policy_binary_obs2alert (v3 head, OBSERVE→ALERT at train time)")
    print(" score: p_alert = softmax[:,2] (effective binary: SILENT vs ALERT∪OBSERVE)")
    print("═" * 100)
    print(f"{'subset':<8}{'n':>7} {'strict_AP':>10} {'merged_AP':>10} "
          f"{'observe_AP':>11} {'class_dist':<22}")
    print("─" * 100)

    out = {}
    for name, mask in subsets.items():
        n = int(mask.sum())
        if n == 0:
            continue
        y = labels[mask]
        p = probs[mask]
        strict = safe_ap((y == 2).astype(int), p[:, 2])            # ALERT vs rest
        merged = safe_ap((y >= 1).astype(int), p[:, 1] + p[:, 2])  # OBSERVE∪ALERT vs SILENT
        obs = safe_ap((y == 1).astype(int), p[:, 1])               # OBSERVE vs rest
        cd = dict(sorted(Counter(int(v) for v in y).items()))
        print(f"{name:<8}{n:>7} {_fmt_ap(strict):>10} {_fmt_ap(merged):>10} "
              f"{_fmt_ap(obs):>11} {str(cd):<22}")
        out[name] = {"n": n, "strict_ap": strict, "merged_ap": merged,
                     "observe_ap": obs, "class_dist": cd}

    print("═" * 100)
    print(" ref — Nexar 2025 winner (MViT-V2-S + 3LC): strict_AP = 0.898 on nexar")
    print("═" * 100 + "\n")

    out_path = Path("eval_results") / "binary_head_nexar.json"
    out_path.parent.mkdir(exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2))
    print("Saved -> eval_results/binary_head_nexar.json")


if __name__ == "__main__":
    main()