#!/usr/bin/env python3 """ Small improvements: fine-grained threshold search, TTA-conditioned analysis, and v3 + v5_mono ensemble. Runs entirely on CPU using belief cache + lightweight policy heads. No full model loading needed. Usage: python -m training.Policy.threshold_analysis \ --label_dir data/policy_labels \ --belief_cache_dir data/belief_cache \ --v3_ckpt checkpoints/Policy/policy_warmstart_v3/best \ --v5_ckpt checkpoints/Policy/policy_warmstart_v5_mono/best \ --output_dir eval_results/threshold_analysis """ from __future__ import annotations import argparse import json import logging from collections import defaultdict from pathlib import Path import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm import sys sys.path.insert(0, str(Path(__file__).resolve().parents[2])) from lkalert.models.components import PolicyHead, HierarchicalPolicyHead logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("Policy.threshold") def load_val_data(label_dir: str, cache_dir: str): """Load val labels and belief cache.""" with open(Path(label_dir) / "val.json") as f: data = json.load(f) samples = data["samples"] cache = torch.load(Path(cache_dir) / "val.pt", map_location="cpu", weights_only=True) beliefs = cache["beliefs"] # [N, 2048] tta_means = cache["tta_means"] # [N] tta_vars = cache["tta_vars"] # [N] labels = np.array([s["action_label"] for s in samples]) cats = np.array([s["category"] for s in samples]) ttas = np.array([s["tta_raw"] for s in samples]) vids = [s["video_id"] for s in samples] return beliefs, tta_means, tta_vars, labels, cats, ttas, vids @torch.no_grad() def get_v3_probs(beliefs, tta_means, tta_vars, ckpt_dir): """Forward through v3 PolicyHead → softmax probs [N, 3].""" head = PolicyHead(hidden_dim=int(beliefs.shape[-1])) sd = torch.load(Path(ckpt_dir) / "policy_head.pt", map_location="cpu") head.load_state_dict(sd) head.eval() B = beliefs.shape[0] prev_action = torch.zeros(B, dtype=torch.long) logits = head(beliefs, tta_means, tta_vars, prev_action) return F.softmax(logits, dim=-1).numpy() @torch.no_grad() def get_v5_probs(beliefs, tta_means, tta_vars, ckpt_dir): """Forward through v5 HierarchicalPolicyHead → 3-class probs [N, 3].""" head = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1])) sd = torch.load(Path(ckpt_dir) / "policy_head.pt", map_location="cpu") head.load_state_dict(sd) head.eval() B = beliefs.shape[0] prev_action = torch.zeros(B, dtype=torch.long) alert_logit, danger_logit = head(beliefs, tta_means, tta_vars, prev_action) p_alert = torch.sigmoid(alert_logit).numpy() p_danger = torch.sigmoid(danger_logit).numpy() p_silent = 1.0 - p_danger p_observe = np.clip(p_danger - p_alert, 0.0, None) probs = np.stack([p_silent, p_observe, p_alert], axis=-1) probs = probs / probs.sum(axis=-1, keepdims=True).clip(1e-8) return probs def policy_metrics(preds, labels, cats): """Compute PolicyScore and sub-metrics.""" ego_mask = cats == "ego_positive" ne_mask = cats == "non_ego" sn_mask = cats == "safe_neg" ego_alert_mask = ego_mask & (labels == 2) ego_recall = float((preds[ego_alert_mask] == 2).mean()) if ego_alert_mask.sum() > 0 else 0.0 ne_noalert = float((preds[ne_mask] != 2).mean()) if ne_mask.sum() > 0 else 0.0 sn_silent = float((preds[sn_mask] == 0).mean()) if sn_mask.sum() > 0 else 0.0 sn_alert = float((preds[sn_mask] == 2).mean()) if sn_mask.sum() > 0 else 0.0 # PolicyScore v3 (safety-first): 0.65 * ego_recall + 0.25 * safe_silent - 0.15 * safe_alert score = 0.65 * ego_recall + 0.25 * sn_silent - 0.15 * sn_alert return { "policy_score": score, "ego_alert_recall": ego_recall, "non_ego_noalert_rate": ne_noalert, "safe_neg_silent_rate": sn_silent, "safe_neg_alert_rate": sn_alert, } def binary_ap(probs, labels): """Compute binary AP from P(ALERT).""" from sklearn.metrics import average_precision_score true = (labels == 2).astype(int) return float(average_precision_score(true, probs[:, 2])) if true.sum() > 0 else 0.0 # ═══════════════════════════════════════════════════════════════════════════════ # Analysis 1: Fine-grained threshold grid for v5 # ═══════════════════════════════════════════════════════════════════════════════ def fine_threshold_grid(probs_v5, labels, cats, raw_alert, raw_danger): """ For v5 hierarchical: search tau_a x tau_d at 0.01 resolution. probs_v5 is 3-class probs, but we reconstruct p_alert and p_danger. """ # Reconstruct from 3-class probs (approximate) p_alert = raw_alert p_danger = raw_danger best_score = -1 best_ta, best_td = 0.5, 0.5 results_grid = [] for ta in np.arange(0.20, 0.81, 0.01): for td in np.arange(0.10, 0.81, 0.01): preds = np.zeros(len(labels), dtype=int) preds[p_danger > td] = 1 preds[p_alert > ta] = 2 m = policy_metrics(preds, labels, cats) if m["policy_score"] > best_score: best_score = m["policy_score"] best_ta, best_td = ta, td best_m = m logger.info(f"Fine threshold: best tau_a={best_ta:.2f} tau_d={best_td:.2f} " f"PolicyScore={best_score:.4f}") return { "best_tau_alert": round(best_ta, 2), "best_tau_danger": round(best_td, 2), "best_policy_score": best_score, **{f"best_{k}": v for k, v in best_m.items() if k != "policy_score"}, } # ═══════════════════════════════════════════════════════════════════════════════ # Analysis 2: TTA-conditioned thresholds # ═══════════════════════════════════════════════════════════════════════════════ def tta_conditioned_analysis(probs, labels, cats, ttas): """Analyze how optimal thresholds vary with TTA.""" p_alert = probs[:, 2] buckets = [ ("tta_0_2", (0, 2)), ("tta_2_4", (2, 4)), ("tta_4_6", (4, 6)), ("tta_6_inf", (6, 100)), ("no_tta", (-2, -0.5)), # safe_neg and non_ego with tta=-1 ] results = {} for name, (lo, hi) in buckets: mask = (ttas >= lo) & (ttas < hi) n = mask.sum() if n < 10: continue sub_labels = labels[mask] sub_cats = cats[mask] sub_palert = p_alert[mask] # Find best threshold for this bucket best_t, best_s = 0.5, -1 for t in np.arange(0.1, 0.9, 0.01): preds = np.where(sub_palert > t, 2, 0).astype(int) ego_alert = (sub_cats == "ego_positive") & (sub_labels == 2) recall = float((preds[ego_alert] == 2).mean()) if ego_alert.sum() > 0 else 0.0 sn = sub_cats == "safe_neg" silent = float((preds[sn] == 0).mean()) if sn.sum() > 0 else 0.0 score = 0.7 * recall + 0.3 * silent if score > best_s: best_s = score best_t = t results[name] = { "n_samples": int(n), "n_alert": int((sub_labels == 2).sum()), "best_threshold": round(best_t, 2), "mean_p_alert": float(sub_palert.mean()), "std_p_alert": float(sub_palert.std()), } return results # ═══════════════════════════════════════════════════════════════════════════════ # Analysis 3: Ensemble v3 + v5 # ═══════════════════════════════════════════════════════════════════════════════ def ensemble_analysis(probs_v3, probs_v5, labels, cats): """Weighted ensemble of v3 and v5 probabilities.""" results = {} for w5 in np.arange(0.0, 1.01, 0.1): w3 = 1.0 - w5 ens = w3 * probs_v3 + w5 * probs_v5 preds = ens.argmax(axis=1) m = policy_metrics(preds, labels, cats) ap = binary_ap(ens, labels) key = f"w3={w3:.1f}_w5={w5:.1f}" results[key] = {**m, "binary_ap": ap} if abs(w5 - 0.5) < 0.01: logger.info(f" Ensemble 50/50: PolicyScore={m['policy_score']:.4f} AP={ap:.4f}") # Find best best_key = max(results, key=lambda k: results[k]["policy_score"]) results["best"] = {"config": best_key, **results[best_key]} logger.info(f" Best ensemble: {best_key} PolicyScore={results[best_key]['policy_score']:.4f}") return results def main(): parser = argparse.ArgumentParser("threshold_analysis") parser.add_argument("--label_dir", default="data/policy_labels") parser.add_argument("--belief_cache_dir", default="data/belief_cache") parser.add_argument("--v3_ckpt", default="checkpoints/Policy/policy_warmstart_v3/best") parser.add_argument("--v5_ckpt", default="checkpoints/Policy/policy_warmstart_v5_mono/best") parser.add_argument("--output_dir", default="eval_results/threshold_analysis") args = parser.parse_args() logger.info("Loading val data...") beliefs, tta_means, tta_vars, labels, cats, ttas, vids = load_val_data( args.label_dir, args.belief_cache_dir ) # Get predictions from both models logger.info("Running v3 PolicyHead...") probs_v3 = get_v3_probs(beliefs, tta_means, tta_vars, args.v3_ckpt) m_v3 = policy_metrics(probs_v3.argmax(axis=1), labels, cats) logger.info(f" v3 PolicyScore={m_v3['policy_score']:.4f} AP={binary_ap(probs_v3, labels):.4f}") logger.info("Running v5 HierarchicalPolicyHead...") # Also get raw sigmoid outputs for threshold analysis head_v5 = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1])) sd = torch.load(Path(args.v5_ckpt) / "policy_head.pt", map_location="cpu") head_v5.load_state_dict(sd) head_v5.eval() with torch.no_grad(): prev = torch.zeros(beliefs.shape[0], dtype=torch.long) al, dl = head_v5(beliefs, tta_means, tta_vars, prev) raw_alert = torch.sigmoid(al).numpy() raw_danger = torch.sigmoid(dl).numpy() probs_v5 = get_v5_probs(beliefs, tta_means, tta_vars, args.v5_ckpt) m_v5 = policy_metrics(probs_v5.argmax(axis=1), labels, cats) logger.info(f" v5 PolicyScore={m_v5['policy_score']:.4f} AP={binary_ap(probs_v5, labels):.4f}") all_results = {} # ── 1) Fine-grained threshold grid ── logger.info("\n=== Fine-grained threshold grid (v5) ===") all_results["fine_threshold"] = fine_threshold_grid( probs_v5, labels, cats, raw_alert, raw_danger ) # ── 2) TTA-conditioned analysis ── logger.info("\n=== TTA-conditioned threshold analysis ===") all_results["tta_conditioned"] = tta_conditioned_analysis(probs_v5, labels, cats, ttas) for bucket, info in all_results["tta_conditioned"].items(): logger.info(f" {bucket}: n={info['n_samples']} n_alert={info['n_alert']} " f"best_t={info['best_threshold']} mean_p={info['mean_p_alert']:.3f}") # ── 3) Ensemble ── logger.info("\n=== Ensemble analysis (v3 + v5) ===") all_results["ensemble"] = ensemble_analysis(probs_v3, probs_v5, labels, cats) # ── 4) Summary ── all_results["summary"] = { "v3_policy_score": m_v3["policy_score"], "v3_binary_ap": binary_ap(probs_v3, labels), "v5_policy_score": m_v5["policy_score"], "v5_binary_ap": binary_ap(probs_v5, labels), "v5_fine_threshold_score": all_results["fine_threshold"]["best_policy_score"], "ensemble_best_score": all_results["ensemble"]["best"]["policy_score"], "ensemble_best_ap": all_results["ensemble"]["best"]["binary_ap"], } out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) with open(out_dir / "threshold_analysis.json", "w") as f: json.dump(all_results, f, indent=2) logger.info(f"\nResults saved to {out_dir / 'threshold_analysis.json'}") logger.info("\n" + "=" * 60) logger.info("SUMMARY") logger.info("=" * 60) for k, v in all_results["summary"].items(): logger.info(f" {k}: {v:.4f}") logger.info("=" * 60) if __name__ == "__main__": main()