#!/usr/bin/env python3 """ Optimize LKAlert post-processing hyper-parameters (per-class T, temporal smoothing, non-ego bias, global ALERT bias) to maximize PolicyScore subject to recall constraints — without any retraining. Pipeline ──────── 1. Extract raw LKAlert logits for train + val splits (cached to disk). 2. Fit per-class temperature on train logits. 3. Grid-search (smooth_alpha, non_ego_alpha, alert_bias) on val. 4. Pick best operating point under constraint recall_ego ≥ R_min. 5. Print a before/after table vs baselines (read from existing all_results.json). 6. Emit driver-intuitive metrics (crashes caught / FA per hour / seconds saved). 7. Save best config + metrics JSON. Outputs ─────── eval_results/paper_comparison/postproc_best.json — chosen config + metrics eval_results/paper_comparison/logits_cache/ — cached logits eval_results/paper_comparison/driver_metrics.json — intuitive metrics Usage ───── python -m training.Policy.optimize_postproc \ --sft_checkpoint checkpoints/SFT/sft_v2/best \ --lkalert_ckpt checkpoints/Policy/policy_warmstart_v3/best \ --label_dir data/policy_labels \ --belief_cache_dir data/belief_cache \ --baseline_json eval_results/paper_comparison/all_results.json \ --output_dir eval_results/paper_comparison \ --recall_min 0.65 """ from __future__ import annotations import argparse import json import logging import sys from itertools import product from pathlib import Path from typing import Dict, List import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm import tqdm ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(ROOT)) from training.Policy.policy_model import PolicyModel from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn from training.Policy.postproc import ( fit_per_class_temperature, apply_per_class_temperature, temporal_smooth, non_ego_bias as apply_non_ego_bias, compute_metrics, ) logging.basicConfig(level=logging.INFO, 
format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("Policy.optimize_postproc") # ────────────────────────────── Logits extraction ───────────────────────────── @torch.no_grad() def extract_logits( sft_ckpt: Path, policy_ckpt: Path, label_json: Path, cache_pt: Path, split: str, device, ) -> Dict[str, np.ndarray]: """Run PolicyModel once; collect logits + meta into numpy arrays.""" ds = PolicyDataset(manifests=[label_json], split=split, belief_cache_path=cache_pt) loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=4, collate_fn=policy_collate_fn) model = PolicyModel(str(sft_ckpt), use_bf16=True) model.load_policy_checkpoint(str(policy_ckpt)) model.eval() L, lbl, cat, tta, vids = [], [], [], [], [] for b in tqdm(loader, desc=f"extract[{split}]"): lg = model.forward_cached( b["beliefs"].to(device), b["tta_means"].to(device), b["tta_vars"].to(device), ) L.append(lg.cpu().numpy()) lbl += b["action_labels"].tolist() cat += b["categories"] tta += b["tta_raws"].tolist() vids += b["video_ids"] del model torch.cuda.empty_cache() return { "logits": np.concatenate(L).astype(np.float32), "labels": np.array(lbl, dtype=np.int64), "categories": np.array(cat), "ttas": np.array(tta, dtype=np.float32), "video_ids": np.array(vids), } def load_or_extract(cache_dir: Path, split: str, sft_ckpt, policy_ckpt, label_dir, belief_dir, device, force: bool = False) -> Dict[str, np.ndarray]: cache_dir.mkdir(parents=True, exist_ok=True) cache_file = cache_dir / f"{split}_lkalert.npz" if cache_file.exists() and not force: logger.info(f"Loading cached logits from {cache_file}") npz = np.load(cache_file, allow_pickle=True) return {k: npz[k] for k in npz.files} data = extract_logits( Path(sft_ckpt), Path(policy_ckpt), Path(label_dir) / f"{split}.json", Path(belief_dir) / f"{split}.pt", split, device, ) np.savez_compressed(cache_file, **data) logger.info(f"Saved logits cache → {cache_file} (shape={data['logits'].shape})") return data # 
# ────────────────────────────── Grid search ───────────────────────────────────

def stratified_split_indices(labels: np.ndarray, frac: float, seed: int = 0):
    """Stratified index split: returns (idx_A, idx_B) where idx_A has `frac` of each class.

    Each class's indices are shuffled with a generator seeded by ``seed`` and
    the first ``round(frac * class_size)`` go to ``idx_A``, the rest to
    ``idx_B``. Empty ``labels`` yields two empty index arrays.
    """
    rng = np.random.default_rng(seed)
    idx_A, idx_B = [], []
    for c in np.unique(labels):
        pool = np.where(labels == c)[0]
        rng.shuffle(pool)
        k = int(round(frac * len(pool)))
        idx_A.append(pool[:k])
        idx_B.append(pool[k:])
    if not idx_A:
        # np.concatenate rejects an empty list; return empty splits instead.
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()
    return np.concatenate(idx_A), np.concatenate(idx_B)


def _calibrate_and_smooth(logits, video_ids, ttas, calib, smooth_alpha):
    """Apply the calibration named by ``calib["kind"]`` and optional EMA smoothing.

    Works on a float32 copy, so the caller's ``logits`` array is not modified
    by this function itself.
    """
    lg = logits.astype(np.float32).copy()
    kind = calib["kind"]
    if kind == "global_T":
        lg = lg / float(calib["T"])
    elif kind == "per_class_T":
        lg = apply_per_class_temperature(lg, calib["T"])
    elif kind != "raw":  # "raw" is the explicit no-op
        raise ValueError(f"Unknown calibration kind: {kind!r}")

    if smooth_alpha < 1.0:  # only when EMA actually mixes past frames
        lg = temporal_smooth(lg, video_ids, ttas,
                             window=1, mode="ema", alpha=smooth_alpha)
    return lg


def _add_biases(lg, non_ego_alpha, alert_bias):
    """Apply the non-ego bias, then add ``alert_bias`` to column 2 of ``lg``.

    May modify ``lg`` in place; pass a copy if the input must be preserved.
    """
    if non_ego_alpha > 0.0:
        lg = apply_non_ego_bias(lg, alpha=non_ego_alpha)
    if alert_bias != 0.0:
        lg[:, 2] = lg[:, 2] + alert_bias
    return lg


def _apply(logits, video_ids, ttas, calib, smooth_alpha, non_ego_alpha, alert_bias):
    """Apply one (calibration, smoothing, non_ego_bias, global_bias) combo.

    Order: calibration → temporal smoothing → non-ego bias → ALERT bias.

    Raises:
        ValueError: If ``calib["kind"]`` is not ``raw``, ``global_T`` or
            ``per_class_T``.
    """
    lg = _calibrate_and_smooth(logits, video_ids, ttas, calib, smooth_alpha)
    return _add_biases(lg, non_ego_alpha, alert_bias)


def grid_search(val: Dict[str, np.ndarray], calib_opts: List[dict],
                smooth_alphas=None, non_ego_alphas=None,
                biases=None) -> List[dict]:
    """
    Sweep (calibration × smooth_alpha × non_ego_alpha × alert_bias).

    Args:
        val: Dict with ``logits``, ``labels``, ``categories``, ``ttas`` and
            ``video_ids`` arrays.
        calib_opts: Calibration dicts with ``name``, ``kind`` and, where
            applicable, ``T``.
        smooth_alphas: EMA alphas to try; default ``[1.0, 0.8, 0.6, 0.4]``
            (1.0 = no smoothing).
        non_ego_alphas: Non-ego bias strengths; default
            ``[0.0, 0.5, 1.0, 1.5, 2.0, 3.0]``.
        biases: ALERT-logit biases; default 26 values from -1.0 to 1.5.

    Returns list of {cfg, metrics} dicts — no filtering applied here.
    """
    if smooth_alphas is None:
        smooth_alphas = [1.0, 0.8, 0.6, 0.4]   # 1.0 = no smoothing
    if non_ego_alphas is None:
        non_ego_alphas = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    if biases is None:
        biases = np.round(np.linspace(-1.0, 1.5, 26), 3).tolist()

    total = len(calib_opts) * len(smooth_alphas) * len(non_ego_alphas) * len(biases)
    logger.info(f" Grid size: {total} configurations")

    results = []
    for calib in calib_opts:
        for sa in smooth_alphas:
            # Calibration + smoothing do not depend on the two bias knobs, so
            # compute them once per (calib, sa) instead of once per config.
            # Assumes the postproc helpers are deterministic functions of
            # their inputs, so the result matches `_apply` for every config.
            base = _calibrate_and_smooth(val["logits"], val["video_ids"],
                                         val["ttas"], calib, sa)
            for na in non_ego_alphas:
                # Copy first: the helper may modify its argument in place.
                if na > 0.0:
                    biased = apply_non_ego_bias(base.copy(), alpha=na)
                else:
                    biased = base
                for ab in biases:
                    lg = biased.copy()
                    if ab != 0.0:
                        lg[:, 2] = lg[:, 2] + ab
                    m = compute_metrics(lg, val["labels"], val["categories"],
                                        val["ttas"], val["video_ids"])
                    results.append({
                        "cfg": {"calib": calib["name"],
                                "smooth_alpha": sa,
                                "non_ego_alpha": na,
                                "alert_bias": ab},
                        "metrics": m,
                    })
    return results


def filter_configs(results: List[dict], recall_min: float,
                   fa_max: float, burden_max: float) -> List[dict]:
    """Keep results meeting all three constraints (recall, FA, burden)."""
    return [r for r in results
            if r["metrics"]["ego_alert_recall"] >= recall_min
            and r["metrics"]["safe_neg_alert_leak"] <= fa_max
            and r["metrics"]["burden_non_ego"] <= burden_max]


def pick_recommended(all_cfgs: List[dict], recall_min: float,
                     fa_max: float, burden_max: float,
                     raw_metrics: dict) -> Dict[str, dict]:
    """
    Return up to three recommended configs:
      balanced    : highest PolicyScore under all three constraints
      low_fa      : under recall_min, minimize FA (FA dominance vs baselines)
      high_recall : under fa_max & burden_max, maximize recall (capture more crashes)

    Each value is a {cfg, metrics} dict taken from ``all_cfgs``. A variant
    whose constraint set has no qualifying config is simply absent from the
    returned dict (which may therefore be empty). ``raw_metrics`` is accepted
    for interface compatibility and is not used.
    """
    out: Dict[str, dict] = {}

    # balanced — all constraints + max PolicyScore (first wins on ties)
    kept = filter_configs(all_cfgs, recall_min, fa_max, burden_max)
    if kept:
        out["balanced"] = max(kept, key=lambda r: r["metrics"]["policy_score"])

    # low_fa — just recall constraint, pick lowest FA, tiebreak by recall
    fa_pool = [r for r in all_cfgs
               if r["metrics"]["ego_alert_recall"] >= recall_min]
    if fa_pool:
        out["low_fa"] = min(
            fa_pool,
            key=lambda r: (r["metrics"]["safe_neg_alert_leak"],
                           -r["metrics"]["ego_alert_recall"]))

    # high_recall — FA & burden bounded, pick max recall, tiebreak by lower FA
    hr_pool = [r for r in all_cfgs
               if r["metrics"]["safe_neg_alert_leak"] <= fa_max
               and r["metrics"]["burden_non_ego"] <= burden_max]
    if hr_pool:
        out["high_recall"] = min(
            hr_pool,
            key=lambda r: (-r["metrics"]["ego_alert_recall"],
                           r["metrics"]["safe_neg_alert_leak"]))

    return out


# ────────────────────────────── Driver-intuitive metrics ──────────────────────

def driver_metrics(m: dict, fps: float = 1.0, window_s: float = 1.0) -> dict:
    """
    Translate raw metrics into phrases a reviewer/engineer can grasp.

    Assumes each sample covers `window_s` seconds of driving (default 1 s).
    `fps` is accepted for interface compatibility and is not used.

      - crashes_caught_per_100 = 100 × ego_alert_recall
      - false_alerts_per_hour  = safe_neg_alert_leak × (3600 / window_s)
      - non_ego_noise_per_hour = burden_non_ego × (3600 / window_s)
      - silent_when_safe       = 100 × safe_neg_silent (percent)
      - mean_lead_seconds      = lead_time_mean        (OBSERVE∪ALERT)
      - alert_lead_seconds     = alert_lead_time_mean  (ALERT only)
      - warning_coverage_pct   = 100 × lead_time_coverage
      - alert_coverage_pct     = 100 × alert_lead_time_coverage

    Raises:
        ValueError: If `window_s` is not positive.
    """
    if window_s <= 0:
        raise ValueError(f"window_s must be positive, got {window_s}")
    per_hour = 3600.0 / window_s
    return {
        "crashes_caught_per_100": round(100.0 * m["ego_alert_recall"], 1),
        "false_alerts_per_hour": round(m["safe_neg_alert_leak"] * per_hour, 0),
        "non_ego_noise_per_hour": round(m["burden_non_ego"] * per_hour, 0),
        "silent_when_safe": round(100.0 * m["safe_neg_silent"], 1),
        "mean_lead_seconds": round(m["lead_time_mean"], 2),
        "alert_lead_seconds": round(m["alert_lead_time_mean"], 2),
        "warning_coverage_pct": round(100.0 * m["lead_time_coverage"], 1),
        "alert_coverage_pct": round(100.0 * m["alert_lead_time_coverage"], 1),
    }


def _baseline_driver_metrics(m: dict) -> dict:
    """Map one baselines-JSON entry onto the `driver_metrics` input schema.

    Missing keys default to 0. The baseline's observe lead time / coverage
    fill both the generic and the ALERT-only fields.
    """
    return driver_metrics({
        "ego_alert_recall": m.get("ego_alert_recall", 0),
        "safe_neg_alert_leak": m.get("safe_neg_alert_leak", 0),
        "burden_non_ego": 0.0,  # not in baselines JSON directly
        "safe_neg_silent": m.get("safe_neg_silent", 0),
        "lead_time_mean": m.get("observe_lead_time_s", 0),
        "alert_lead_time_mean": m.get("observe_lead_time_s", 0),
        "lead_time_coverage": m.get("observe_coverage", 0),
        "alert_lead_time_coverage": m.get("observe_coverage", 0),
    })


def _to_jsonable(obj):
    """Recursively convert NumPy scalars/arrays to plain Python for `json.dumps`."""
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


# ────────────────────────────── Reports ───────────────────────────────────────

# (metrics-dict key, column header) pairs, in table column order.
METRICS_KEYS = [
    ("policy_score", "PolicyScore↑"),
    ("ego_alert_recall", "EgoRec↑"),
    ("safe_neg_alert_leak", "FA↓"),
    ("burden_non_ego", "Burden↓"),
    ("binary_ap", "AP↑"),
    ("binary_f1", "F1↑"),
    ("lead_time_mean", "Lead(s)↑"),
    ("lead_time_coverage", "Cov↑"),
]


def print_row(name: str, m: dict, star: bool = False):
    """Print one table row: a 30-wide name cell, then one cell per METRICS_KEYS entry.

    Missing metrics print as 0. Rates/scores use 3 decimals; everything else
    (e.g. lead time in seconds) uses 2.
    """
    label = ("★ " + name) if star else (" " + name)
    cells = [f"{label:<30}"]
    for k, _ in METRICS_KEYS:
        v = m.get(k, 0.0)
        if "coverage" in k or "recall" in k or "noalert" in k \
                or k in ("safe_neg_alert_leak", "burden_non_ego",
                         "binary_ap", "binary_f1", "policy_score"):
            cells.append(f"{v:>8.3f}")
        else:
            cells.append(f"{v:>8.2f}")
    print(" ".join(cells))


def print_comparison(before: dict, after: dict, baselines: Dict[str, dict]):
    """Print raw vs optimized LKAlert metrics, followed by non-LKAlert baselines."""
    print("\n" + "=" * 104)
    print("POST-PROCESSING OPTIMIZATION RESULTS (val set)")
    print("=" * 104)
    header = f"{'Method':<30}" + " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    print(header)
    print("-" * len(header))
    print_row("LKAlert (baseline, raw)", before)
    print_row("LKAlert (optimized)", after, star=True)
    print("-" * len(header))
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        print_row(name, m)
    print("=" * 104)


def print_driver_metrics(before: dict, after: dict, baselines: Dict[str, dict]):
    """Print the driver-intuitive metrics table for raw, optimized and baselines."""
    print("\n" + "=" * 90)
    print("DRIVER-INTUITIVE METRICS (assume each sample = 1 s of driving)")
    print("=" * 90)
    cols = [
        ("crashes_caught_per_100", "Crashes/100"),
        ("alert_lead_seconds", "AlertLead(s)"),
        ("alert_coverage_pct", "AlertCov(%)"),
        ("false_alerts_per_hour", "FA/hour"),
        ("non_ego_noise_per_hour", "Noise/hour"),
        ("silent_when_safe", "Silent(%)"),
    ]
    hdr = f"{'Method':<30}" + " ".join(f"{l:>12}" for _, l in cols)
    print(hdr)
    print("-" * len(hdr))

    def _row(name, dm, star=False):
        label = ("★ " + name) if star else (" " + name)
        print(f"{label:<30}" + " ".join(
            f"{dm.get(k, 0.0):>12}" for k, _ in cols))

    _row("LKAlert (baseline, raw)", driver_metrics(before))
    _row("LKAlert (optimized)", driver_metrics(after), star=True)
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        _row(name, _baseline_driver_metrics(m))
    print("=" * 90)


# ────────────────────────────── Main ──────────────────────────────────────────

def main():
    """CLI entry point: extract logits, grid-search post-processing, report, save."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--lkalert_ckpt", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default="data/belief_cache")
    parser.add_argument("--baseline_json",
                        default="eval_results/paper_comparison/all_results.json")
    parser.add_argument("--output_dir", default="eval_results/paper_comparison")
    parser.add_argument("--recall_min", type=float, default=0.65,
                        help="Minimum ego-alert recall when picking the best config")
    parser.add_argument("--fa_max", type=float, default=0.30,
                        help="Maximum allowed false-alarm rate on safe_neg samples "
                             "(default 0.30 — comparable to strongest baseline)")
    parser.add_argument("--burden_max", type=float, default=0.25,
                        help="Maximum allowed non-ego alert burden "
                             "(default 0.25 — keeps driver noise bounded)")
    parser.add_argument("--force_extract", action="store_true",
                        help="Re-extract logits even if cached")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = out_dir / "logits_cache"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── 1. Extract logits ─────────────────────────────────────────────────────
    logger.info("── Step 1: extracting LKAlert logits (train + val) ──")
    train = load_or_extract(cache_dir, "train",
                            args.sft_checkpoint, args.lkalert_ckpt,
                            args.label_dir, args.belief_cache_dir, device,
                            force=args.force_extract)
    val = load_or_extract(cache_dir, "val",
                          args.sft_checkpoint, args.lkalert_ckpt,
                          args.label_dir, args.belief_cache_dir, device,
                          force=args.force_extract)

    # ── 2. Fit per-class T on two candidate calibration sets ─────────────────
    logger.info("── Step 2: fitting per-class temperatures ──")
    # The train-split fit is only logged; it is not part of the grid below.
    T_pc_train = fit_per_class_temperature(train["logits"], train["labels"])
    logger.info(f" T(train-fit) = (SIL={T_pc_train[0]:.3f}, "
                f"OBS={T_pc_train[1]:.3f}, ALT={T_pc_train[2]:.3f})")

    # 20% stratified split of val → used ONLY for calibration fitting
    fit_idx, eval_idx = stratified_split_indices(val["labels"], frac=0.2)
    T_pc_valfit = fit_per_class_temperature(
        val["logits"][fit_idx], val["labels"][fit_idx])
    logger.info(f" T(val20%-fit) = (SIL={T_pc_valfit[0]:.3f}, "
                f"OBS={T_pc_valfit[1]:.3f}, ALT={T_pc_valfit[2]:.3f}) "
                f"(fit={len(fit_idx)}, eval={len(eval_idx)})")

    # ── 3. Baseline (raw) metrics on FULL val ─────────────────────────────────
    raw = compute_metrics(val["logits"], val["labels"], val["categories"],
                          val["ttas"], val["video_ids"])
    logger.info(f" Raw val PolicyScore = {raw['policy_score']:.4f} "
                f"Recall = {raw['ego_alert_recall']:.3f} "
                f"FA = {raw['safe_neg_alert_leak']:.3f}")

    # ── 4. Grid search over calibration × smoothing × non_ego × bias ──────────
    # Eval only on the 80% held-out portion to avoid fit/eval leakage
    eval_view = {
        "logits": val["logits"][eval_idx],
        "labels": val["labels"][eval_idx],
        "categories": val["categories"][eval_idx],
        "ttas": val["ttas"][eval_idx],
        "video_ids": val["video_ids"][eval_idx],
    }
    calib_opts = [
        {"name": "raw", "kind": "raw"},
        {"name": "global_T=0.6", "kind": "global_T", "T": 0.6},
        {"name": "global_T=0.8", "kind": "global_T", "T": 0.8},
        {"name": "global_T=1.2", "kind": "global_T", "T": 1.2},
        {"name": "global_T=1.5", "kind": "global_T", "T": 1.5},
        {"name": "per_class_val20%", "kind": "per_class_T", "T": T_pc_valfit},
    ]
    logger.info(f"── Step 4: grid search on 80% val ({len(eval_idx)} samples) ──")
    logger.info(f" Constraints: recall ≥ {args.recall_min}, "
                f"FA ≤ {args.fa_max}, burden ≤ {args.burden_max}")
    all_cfgs = grid_search(eval_view, calib_opts)

    recs_80 = pick_recommended(all_cfgs, args.recall_min, args.fa_max,
                               args.burden_max, raw)
    if not recs_80:
        logger.warning("No configuration satisfies any constraint set; "
                       "falling back to the closest match.")
        all_cfgs_sorted = sorted(all_cfgs,
                                 key=lambda r: -r["metrics"]["policy_score"])
        recs_80 = {"balanced": all_cfgs_sorted[0]}

    # ── 4b. Re-evaluate each recommended config on FULL val ───────────────────
    name_to_calib = {c["name"]: c for c in calib_opts}
    recommended: Dict[str, dict] = {}
    for tag, r in recs_80.items():
        cfg = r["cfg"]
        lg = _apply(val["logits"], val["video_ids"], val["ttas"],
                    name_to_calib[cfg["calib"]],
                    cfg["smooth_alpha"], cfg["non_ego_alpha"], cfg["alert_bias"])
        m = compute_metrics(lg, val["labels"], val["categories"],
                            val["ttas"], val["video_ids"])
        recommended[tag] = {"cfg": cfg, "metrics": m}
        logger.info(f" [{tag}] cfg={cfg} → "
                    f"Policy={m['policy_score']:.4f} Rec={m['ego_alert_recall']:.3f} "
                    f"FA={m['safe_neg_alert_leak']:.3f} Bur={m['burden_non_ego']:.3f}")

    # Pick "balanced" if present else first key as the headline "optimized" result
    best_tag = "balanced" if "balanced" in recommended else next(iter(recommended))
    best = recommended[best_tag]

    # `kept` retained for top-10 display (under recall only, sorted by PolicyScore)
    kept = [r for r in all_cfgs
            if r["metrics"]["ego_alert_recall"] >= args.recall_min]
    kept.sort(key=lambda r: -r["metrics"]["policy_score"])

    # ── 5. Load baselines JSON for comparison ─────────────────────────────────
    baselines = {}
    bj = Path(args.baseline_json)
    if bj.exists():
        baselines = json.loads(bj.read_text(encoding="utf-8"))
    else:
        logger.warning(f"Baseline JSON {bj} not found; skipping cross-model comparison")

    # ── 6. Print comparison tables ────────────────────────────────────────────
    # Three recommended operating points
    print("\n" + "=" * 104)
    print("RECOMMENDED OPERATING POINTS (re-evaluated on FULL val)")
    print("=" * 104)
    # Name column is 30 wide to line up with the rows emitted by print_row.
    header = f"{'Variant':<30}" + " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    print(header)
    print("-" * len(header))
    print_row("raw (current default)", raw)
    for tag in ("low_fa", "balanced", "high_recall"):
        if tag in recommended:
            print_row(f"opt: {tag}", recommended[tag]["metrics"],
                      star=(tag == best_tag))
    print("-" * len(header))
    for name, m in baselines.items():
        if "LKAlert" in name or "Random" in name:
            continue
        print_row(name, m)
    print("=" * 104)

    # Driver-intuitive metrics use the headline (best) config
    print_driver_metrics(raw, best["metrics"], baselines)

    # ── 6b. Top-10 candidates (on 80% eval split) ─────────────────────────────
    print("\n" + "=" * 100)
    print("TOP-10 CANDIDATE CONFIGS (ranked by PolicyScore on 80% val eval split)")
    print("=" * 100)
    print(f"{'calib':<20} {'sm_α':>6} {'ne_α':>6} {'bias':>6} "
          f"{'Policy↑':>8} {'Rec↑':>6} {'FA↓':>6} {'Bur↓':>6} {'AP↑':>6} {'F1↑':>6}")
    for r in kept[:10]:
        c = r["cfg"]
        m = r["metrics"]
        print(f"{c['calib']:<20} {c['smooth_alpha']:>6.2f} {c['non_ego_alpha']:>6.2f} "
              f"{c['alert_bias']:>+6.2f} "
              f"{m['policy_score']:>8.4f} {m['ego_alert_recall']:>6.3f} "
              f"{m['safe_neg_alert_leak']:>6.3f} {m['burden_non_ego']:>6.3f} "
              f"{m['binary_ap']:>6.3f} {m['binary_f1']:>6.3f}")

    # ── 7. Save outputs ───────────────────────────────────────────────────────
    result = {
        "constraints": {"recall_min": args.recall_min,
                        "fa_max": args.fa_max,
                        "burden_max": args.burden_max},
        "raw_metrics": raw,
        "recommended": {tag: {"cfg": r["cfg"], "metrics": r["metrics"]}
                        for tag, r in recommended.items()},
        "headline_variant": best_tag,
        "top10_by_policy_score": [{"cfg": r["cfg"],
                                   "policy_score": r["metrics"]["policy_score"],
                                   "recall": r["metrics"]["ego_alert_recall"],
                                   "fa": r["metrics"]["safe_neg_alert_leak"],
                                   "burden": r["metrics"]["burden_non_ego"]}
                                  for r in kept[:10]],
    }
    # Metric values may be NumPy scalars, which json cannot encode directly.
    (out_dir / "postproc_best.json").write_text(
        json.dumps(_to_jsonable(result), indent=2), encoding="utf-8")
    logger.info(f"Saved → {out_dir / 'postproc_best.json'}")

    driver = {
        "LKAlert (baseline, raw)": driver_metrics(raw),
        "LKAlert (optimized)": driver_metrics(best["metrics"]),
    }
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        driver[name] = _baseline_driver_metrics(m)
    (out_dir / "driver_metrics.json").write_text(
        json.dumps(_to_jsonable(driver), indent=2), encoding="utf-8")
    logger.info(f"Saved → {out_dir / 'driver_metrics.json'}")


if __name__ == "__main__":
    main()