| |
| """ |
| Small improvements: fine-grained threshold search, TTA-conditioned analysis, |
| and v3 + v5_mono ensemble. |
| |
| Runs entirely on CPU using belief cache + lightweight policy heads. |
| No full model loading needed. |
| |
| Usage: |
| python -m training.Policy.threshold_analysis \ |
| --label_dir data/policy_labels \ |
| --belief_cache_dir data/belief_cache \ |
| --v3_ckpt checkpoints/Policy/policy_warmstart_v3/best \ |
| --v5_ckpt checkpoints/Policy/policy_warmstart_v5_mono/best \ |
| --output_dir eval_results/threshold_analysis |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parents[2])) |
|
|
| from lkalert.models.components import PolicyHead, HierarchicalPolicyHead |
|
|
# Configure root logging once at import so every helper below emits
# timestamped INFO-level messages when this script is run.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("Policy.threshold")
|
|
|
|
def load_val_data(label_dir: str, cache_dir: str):
    """Load the validation labels and the matching belief cache.

    Args:
        label_dir: Directory containing ``val.json``. Its ``"samples"`` list
            must provide ``action_label``, ``category``, ``tta_raw`` and
            ``video_id`` for every sample.
        cache_dir: Directory containing ``val.pt`` with ``"beliefs"``,
            ``"tta_means"`` and ``"tta_vars"`` entries.

    Returns:
        Tuple ``(beliefs, tta_means, tta_vars, labels, cats, ttas, vids)``:
        the three cached tensors as stored, numpy arrays of action labels,
        categories and raw TTA values (in sample order), and a list of
        video ids.

    Raises:
        ValueError: if the number of cached beliefs differs from the number
            of labelled samples. Downstream metrics index model predictions
            with boolean masks built from the labels, so a mismatch would
            otherwise only surface later as an opaque indexing error.
    """
    # JSON is UTF-8 by spec; do not depend on the platform's default encoding.
    with open(Path(label_dir) / "val.json", encoding="utf-8") as f:
        samples = json.load(f)["samples"]

    # weights_only=True restricts unpickling to tensors and plain containers.
    cache = torch.load(Path(cache_dir) / "val.pt", map_location="cpu", weights_only=True)
    beliefs = cache["beliefs"]
    tta_means = cache["tta_means"]
    tta_vars = cache["tta_vars"]

    n_cache = int(beliefs.shape[0])
    if n_cache != len(samples):
        raise ValueError(
            f"val.pt holds {n_cache} beliefs but val.json lists {len(samples)} "
            "samples; the belief cache and the labels are out of sync"
        )

    labels = np.array([s["action_label"] for s in samples])
    cats = np.array([s["category"] for s in samples])
    ttas = np.array([s["tta_raw"] for s in samples])
    vids = [s["video_id"] for s in samples]

    return beliefs, tta_means, tta_vars, labels, cats, ttas, vids
|
|
|
|
@torch.no_grad()
def get_v3_probs(beliefs, tta_means, tta_vars, ckpt_dir):
    """Run the v3 ``PolicyHead`` over cached beliefs and return softmax probs.

    Args:
        beliefs: Cached belief tensor. Its last dimension is used as the
            head's ``hidden_dim`` and its first dimension as the batch size.
        tta_means: TTA means, forwarded unchanged to the head.
        tta_vars: TTA variances, forwarded unchanged to the head.
        ckpt_dir: Directory containing ``policy_head.pt`` (a state dict).

    Returns:
        numpy array of softmax probabilities over the last dimension of the
        head's logits. Callers treat column 0 as SILENT and column 2 as ALERT.
    """
    head = PolicyHead(hidden_dim=int(beliefs.shape[-1]))
    # A state dict only contains tensors, so restrict unpickling accordingly
    # (the same safeguard used for the belief cache).
    state_dict = torch.load(
        Path(ckpt_dir) / "policy_head.pt", map_location="cpu", weights_only=True
    )
    head.load_state_dict(state_dict)
    head.eval()

    # Every sample is scored as if the previous action were class 0;
    # presumably no per-sample action history is available here — confirm.
    prev_action = torch.zeros(beliefs.shape[0], dtype=torch.long)
    logits = head(beliefs, tta_means, tta_vars, prev_action)
    return F.softmax(logits, dim=-1).numpy()
|
|
|
|
| @torch.no_grad() |
| def get_v5_probs(beliefs, tta_means, tta_vars, ckpt_dir): |
| """Forward through v5 HierarchicalPolicyHead β 3-class probs [N, 3].""" |
| head = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1])) |
| sd = torch.load(Path(ckpt_dir) / "policy_head.pt", map_location="cpu") |
| head.load_state_dict(sd) |
| head.eval() |
|
|
| B = beliefs.shape[0] |
| prev_action = torch.zeros(B, dtype=torch.long) |
| alert_logit, danger_logit = head(beliefs, tta_means, tta_vars, prev_action) |
|
|
| p_alert = torch.sigmoid(alert_logit).numpy() |
| p_danger = torch.sigmoid(danger_logit).numpy() |
| p_silent = 1.0 - p_danger |
| p_observe = np.clip(p_danger - p_alert, 0.0, None) |
|
|
| probs = np.stack([p_silent, p_observe, p_alert], axis=-1) |
| probs = probs / probs.sum(axis=-1, keepdims=True).clip(1e-8) |
| return probs |
|
|
|
|
def policy_metrics(preds, labels, cats):
    """Score discrete policy predictions (0=SILENT, 1=OBSERVE, 2=ALERT).

    Args:
        preds: numpy array of predicted action ids, one per sample.
        labels: numpy array of ground-truth action ids.
        cats: numpy array of sample categories ("ego_positive", "non_ego",
            "safe_neg", ...).

    Returns:
        Dict with ``policy_score`` = 0.65 * ego ALERT recall
        + 0.25 * safe-negative SILENT rate - 0.15 * safe-negative ALERT rate,
        plus the individual rates. ``non_ego_noalert_rate`` is reported but
        does not enter the score. Any rate whose population is empty is 0.0.
    """
    def _rate(hits, population):
        # Mean of the boolean ``hits`` over ``population``; 0.0 when empty.
        if population.sum() > 0:
            return float(hits[population].mean())
        return 0.0

    ego_alert_pop = (cats == "ego_positive") & (labels == 2)
    non_ego_pop = cats == "non_ego"
    safe_neg_pop = cats == "safe_neg"

    predicted_alert = preds == 2
    recall = _rate(predicted_alert, ego_alert_pop)
    non_ego_quiet = _rate(preds != 2, non_ego_pop)
    safe_silent = _rate(preds == 0, safe_neg_pop)
    safe_alert = _rate(predicted_alert, safe_neg_pop)

    return {
        "policy_score": 0.65 * recall + 0.25 * safe_silent - 0.15 * safe_alert,
        "ego_alert_recall": recall,
        "non_ego_noalert_rate": non_ego_quiet,
        "safe_neg_silent_rate": safe_silent,
        "safe_neg_alert_rate": safe_alert,
    }
|
|
|
|
def binary_ap(probs, labels):
    """Average precision of P(ALERT) for the binary target ``labels == 2``.

    Column 2 of ``probs`` is used as the ranking score. Returns 0.0 when no
    sample is labelled ALERT, because AP is undefined without positives.
    """
    from sklearn.metrics import average_precision_score

    is_alert = (labels == 2).astype(int)
    if is_alert.sum() <= 0:
        return 0.0
    return float(average_precision_score(is_alert, probs[:, 2]))
|
|
|
|
| |
| |
| |
|
|
def fine_threshold_grid(probs_v5, labels, cats, raw_alert, raw_danger):
    """Grid-search the v5 hierarchical decision thresholds at 0.01 resolution.

    A sample is predicted ALERT (2) if ``raw_alert > tau_a``, otherwise
    OBSERVE (1) if ``raw_danger > tau_d``, otherwise SILENT (0). tau_a is
    searched over 0.20..0.80 and tau_d over 0.10..0.80 (both inclusive), and
    the pair with the highest PolicyScore (``policy_metrics``) is kept. Ties
    keep the first pair found: smallest tau_a, then smallest tau_d.

    Args:
        probs_v5: Unused; kept for call compatibility. The search works
            directly on ``raw_alert`` / ``raw_danger``.
        labels: Ground-truth action labels.
        cats: Sample categories.
        raw_alert: Per-sample P(alert) from the v5 alert head.
        raw_danger: Per-sample P(danger) from the v5 danger head.

    Returns:
        Dict with ``best_tau_alert``, ``best_tau_danger``,
        ``best_policy_score`` and the other sub-metrics of the best
        configuration, each prefixed with ``best_``.
    """
    # Integer-based grids: a float step such as np.arange(0.20, 0.81, 0.01)
    # can include a stray point past the intended 0.80 endpoint because of
    # floating-point rounding in the length computation.
    alert_grid = np.arange(20, 81) / 100.0
    danger_grid = np.arange(10, 81) / 100.0

    n = len(labels)
    best_score = -np.inf
    best_ta, best_td = 0.5, 0.5
    best_m = {}
    for ta in alert_grid:
        is_alert = raw_alert > ta  # independent of tau_d: compute once
        for td in danger_grid:
            preds = np.zeros(n, dtype=int)
            preds[raw_danger > td] = 1
            preds[is_alert] = 2  # ALERT overrides OBSERVE
            m = policy_metrics(preds, labels, cats)
            if m["policy_score"] > best_score:
                best_score = m["policy_score"]
                best_ta, best_td = ta, td
                best_m = m

    logger.info(f"Fine threshold: best tau_a={best_ta:.2f} tau_d={best_td:.2f} "
                f"PolicyScore={best_score:.4f}")
    return {
        "best_tau_alert": round(float(best_ta), 2),
        "best_tau_danger": round(float(best_td), 2),
        "best_policy_score": best_score,
        **{f"best_{k}": v for k, v in best_m.items() if k != "policy_score"},
    }
|
|
|
|
| |
| |
| |
|
|
def tta_conditioned_analysis(probs, labels, cats, ttas):
    """Find the best single ALERT threshold separately for each TTA bucket.

    Samples are grouped by ``ttas`` into half-open buckets [lo, hi). Buckets
    with fewer than 10 samples are skipped. Within a bucket, a sample is
    predicted ALERT (2) when P(ALERT) = ``probs[:, 2]`` exceeds the threshold,
    and SILENT (0) otherwise. Thresholds 0.10..0.89 are scanned and scored
    with 0.7 * ego-ALERT recall + 0.3 * safe-negative SILENT rate. Note this
    objective is simpler than PolicyScore (``policy_metrics``): it has no
    OBSERVE class and no safe-negative alert penalty. Ties keep the lowest
    threshold.

    Args:
        probs: Array of per-class probabilities; only column 2 is used.
        labels: Ground-truth action labels.
        cats: Sample categories.
        ttas: Raw TTA values, one per sample.

    Returns:
        Dict keyed by bucket name with ``n_samples``, ``n_alert``,
        ``best_threshold`` and the mean/std of P(ALERT) in the bucket.
    """
    p_alert = probs[:, 2]

    buckets = [
        ("tta_0_2", (0, 2)),
        ("tta_2_4", (2, 4)),
        ("tta_4_6", (4, 6)),
        # Unbounded above so that very large TTA values are not dropped.
        ("tta_6_inf", (6, np.inf)),
        # NOTE(review): presumably a negative sentinel in tta_raw marks
        # "no TTA" — confirm against the label generator.
        ("no_tta", (-2, -0.5)),
    ]
    # Integer-based grid avoids float-step np.arange endpoint surprises.
    thresholds = np.arange(10, 90) / 100.0

    results = {}
    for name, (lo, hi) in buckets:
        mask = (ttas >= lo) & (ttas < hi)
        n = int(mask.sum())
        if n < 10:
            continue

        sub_labels = labels[mask]
        sub_cats = cats[mask]
        sub_palert = p_alert[mask]

        # These populations do not depend on the threshold: build them once.
        ego_alert = (sub_cats == "ego_positive") & (sub_labels == 2)
        safe_neg = sub_cats == "safe_neg"
        has_ego_alert = bool(ego_alert.any())
        has_safe_neg = bool(safe_neg.any())

        best_t, best_s = 0.5, -1.0
        for t in thresholds:
            fires = sub_palert > t
            recall = float(fires[ego_alert].mean()) if has_ego_alert else 0.0
            silent = float((~fires[safe_neg]).mean()) if has_safe_neg else 0.0
            score = 0.7 * recall + 0.3 * silent
            if score > best_s:
                best_s = score
                best_t = t

        results[name] = {
            "n_samples": n,
            "n_alert": int((sub_labels == 2).sum()),
            "best_threshold": round(float(best_t), 2),
            "mean_p_alert": float(sub_palert.mean()),
            "std_p_alert": float(sub_palert.std()),
        }

    return results
|
|
|
|
| |
| |
| |
|
|
def ensemble_analysis(probs_v3, probs_v5, labels, cats):
    """Sweep linear blends ``(1 - w5) * v3 + w5 * v5`` of class probabilities.

    w5 goes from 0.0 to 1.0 in steps of 0.1. Each blend is scored with
    PolicyScore on its argmax predictions and with binary AP on its ALERT
    column. The returned dict maps ``"w3=.._w5=.."`` keys to those metrics.
    It also contains a ``"best"`` entry that names the highest-PolicyScore
    configuration under ``"config"`` and repeats its metrics.
    """
    sweep = {}
    for w5 in np.arange(0.0, 1.01, 0.1):
        w3 = 1.0 - w5
        blended = w3 * probs_v3 + w5 * probs_v5
        metrics = policy_metrics(blended.argmax(axis=1), labels, cats)
        ap = binary_ap(blended, labels)
        sweep[f"w3={w3:.1f}_w5={w5:.1f}"] = {**metrics, "binary_ap": ap}

        # The weights come from a float arange, so match 0.5 with a tolerance.
        if abs(w5 - 0.5) < 0.01:
            logger.info(f" Ensemble 50/50: PolicyScore={metrics['policy_score']:.4f} AP={ap:.4f}")

    winner = max(sweep, key=lambda cfg: sweep[cfg]["policy_score"])
    sweep["best"] = {"config": winner, **sweep[winner]}
    logger.info(f" Best ensemble: {winner} PolicyScore={sweep[winner]['policy_score']:.4f}")

    return sweep
|
|
|
|
def main():
    """Command-line entry point: run every threshold/ensemble analysis on val.

    Loads the validation labels and belief cache. Scores them with the v3
    PolicyHead and the v5 HierarchicalPolicyHead, then runs the fine
    threshold grid, the TTA-conditioned analysis and the v3/v5 ensemble
    sweep. Writes all results to ``<output_dir>/threshold_analysis.json``
    and logs a summary.
    """
    parser = argparse.ArgumentParser("threshold_analysis")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default="data/belief_cache")
    parser.add_argument("--v3_ckpt", default="checkpoints/Policy/policy_warmstart_v3/best")
    parser.add_argument("--v5_ckpt", default="checkpoints/Policy/policy_warmstart_v5_mono/best")
    parser.add_argument("--output_dir", default="eval_results/threshold_analysis")
    args = parser.parse_args()

    logger.info("Loading val data...")
    beliefs, tta_means, tta_vars, labels, cats, ttas, _vids = load_val_data(
        args.label_dir, args.belief_cache_dir
    )

    logger.info("Running v3 PolicyHead...")
    probs_v3 = get_v3_probs(beliefs, tta_means, tta_vars, args.v3_ckpt)
    m_v3 = policy_metrics(probs_v3.argmax(axis=1), labels, cats)
    ap_v3 = binary_ap(probs_v3, labels)
    logger.info(f" v3 PolicyScore={m_v3['policy_score']:.4f} AP={ap_v3:.4f}")

    logger.info("Running v5 HierarchicalPolicyHead...")
    # The threshold grid needs the raw sigmoid outputs of both v5 heads,
    # which get_v5_probs folds into 3-class probabilities, so the head is
    # also run here. weights_only=True: a state dict only holds tensors.
    head_v5 = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1]))
    sd = torch.load(
        Path(args.v5_ckpt) / "policy_head.pt", map_location="cpu", weights_only=True
    )
    head_v5.load_state_dict(sd)
    head_v5.eval()
    with torch.no_grad():
        prev = torch.zeros(beliefs.shape[0], dtype=torch.long)
        al, dl = head_v5(beliefs, tta_means, tta_vars, prev)
        raw_alert = torch.sigmoid(al).numpy()
        raw_danger = torch.sigmoid(dl).numpy()

    probs_v5 = get_v5_probs(beliefs, tta_means, tta_vars, args.v5_ckpt)
    m_v5 = policy_metrics(probs_v5.argmax(axis=1), labels, cats)
    ap_v5 = binary_ap(probs_v5, labels)
    logger.info(f" v5 PolicyScore={m_v5['policy_score']:.4f} AP={ap_v5:.4f}")

    all_results = {}

    logger.info("\n=== Fine-grained threshold grid (v5) ===")
    all_results["fine_threshold"] = fine_threshold_grid(
        probs_v5, labels, cats, raw_alert, raw_danger
    )

    logger.info("\n=== TTA-conditioned threshold analysis ===")
    all_results["tta_conditioned"] = tta_conditioned_analysis(probs_v5, labels, cats, ttas)
    for bucket, info in all_results["tta_conditioned"].items():
        logger.info(f" {bucket}: n={info['n_samples']} n_alert={info['n_alert']} "
                    f"best_t={info['best_threshold']} mean_p={info['mean_p_alert']:.3f}")

    logger.info("\n=== Ensemble analysis (v3 + v5) ===")
    all_results["ensemble"] = ensemble_analysis(probs_v3, probs_v5, labels, cats)

    # AP values are reused from above instead of being recomputed.
    all_results["summary"] = {
        "v3_policy_score": m_v3["policy_score"],
        "v3_binary_ap": ap_v3,
        "v5_policy_score": m_v5["policy_score"],
        "v5_binary_ap": ap_v5,
        "v5_fine_threshold_score": all_results["fine_threshold"]["best_policy_score"],
        "ensemble_best_score": all_results["ensemble"]["best"]["policy_score"],
        "ensemble_best_ap": all_results["ensemble"]["best"]["binary_ap"],
    }

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "threshold_analysis.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    logger.info(f"\nResults saved to {out_path}")

    logger.info("\n" + "=" * 60)
    logger.info("SUMMARY")
    logger.info("=" * 60)
    for k, v in all_results["summary"].items():
        logger.info(f" {k}: {v:.4f}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
|
|