VLAlert / training /Policy /optimize_postproc.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
25.6 kB
#!/usr/bin/env python3
"""
Optimize LKAlert post-processing hyper-parameters (per-class T, temporal
smoothing, non-ego bias, global ALERT bias) to maximize PolicyScore subject to
recall, false-alarm and non-ego-burden constraints — without any retraining.
Pipeline
────────
1. Extract raw LKAlert logits for train + val splits (cached to disk).
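2. Fit per-class temperatures on train logits (logged for reference) and on a
stratified 20% slice of val (used as a calibration candidate).
3. Grid-search (calibration, smooth_alpha, non_ego_alpha, alert_bias) on the
remaining 80% of val.
4. Pick recommended operating points (balanced / low_fa / high_recall) under
recall_ego ≥ R_min plus FA and non-ego-burden caps; re-evaluate them on full val.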
5. Print a before/after table vs baselines (read from existing all_results.json).
6. Emit driver-intuitive metrics (crashes caught / FA per hour / seconds saved).
7. Save best config + metrics JSON.
Outputs
───────
eval_results/paper_comparison/postproc_best.json — chosen config + metrics
eval_results/paper_comparison/logits_cache/ — cached logits
eval_results/paper_comparison/driver_metrics.json — intuitive metrics
Usage
─────
python -m training.Policy.optimize_postproc \
--sft_checkpoint checkpoints/SFT/sft_v2/best \
--lkalert_ckpt checkpoints/Policy/policy_warmstart_v3/best \
--label_dir data/policy_labels \
--belief_cache_dir data/belief_cache \
--baseline_json eval_results/paper_comparison/all_results.json \
--output_dir eval_results/paper_comparison \
--recall_min 0.65
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from itertools import product
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))
from training.Policy.policy_model import PolicyModel
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
from training.Policy.postproc import (
fit_per_class_temperature,
apply_per_class_temperature,
temporal_smooth,
non_ego_bias as apply_non_ego_bias,
compute_metrics,
)
# Root handler: "<time> <LEVEL> <message>" lines at INFO and above.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Named logger shared by the functions in this script.
logger = logging.getLogger("Policy.optimize_postproc")
# ────────────────────────────── Logits extraction ─────────────────────────────
@torch.no_grad()
def extract_logits(
    sft_ckpt: Path, policy_ckpt: Path,
    label_json: Path, cache_pt: Path,
    split: str, device,
) -> Dict[str, np.ndarray]:
    """Run PolicyModel once over a split; collect logits + meta into numpy arrays.

    Args:
        sft_ckpt: SFT checkpoint used to build ``PolicyModel``.
        policy_ckpt: checkpoint loaded via ``model.load_policy_checkpoint``.
        label_json: label manifest for ``split``.
        cache_pt: belief cache file passed to ``PolicyDataset``.
        split: dataset split name (e.g. "train", "val").
        device: device the batch tensors are moved to.

    Returns:
        dict with "logits" (float32, batches concatenated along axis 0),
        "labels" (int64), "categories", "ttas" (float32) and "video_ids",
        all in loader order (shuffle=False).

    Raises:
        ValueError: if the split yields no batches.
    """
    ds = PolicyDataset(manifests=[label_json], split=split,
                       belief_cache_path=cache_pt)
    loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=4,
                        collate_fn=policy_collate_fn)
    model = PolicyModel(str(sft_ckpt), use_bf16=True)
    model.load_policy_checkpoint(str(policy_ckpt))
    model.eval()
    # NOTE(review): the model is never moved to `device` here; this presumably
    # relies on PolicyModel placing itself on the right device — confirm.

    logit_chunks, labels, categories, ttas, video_ids = [], [], [], [], []
    for b in tqdm(loader, desc=f"extract[{split}]"):
        lg = model.forward_cached(
            b["beliefs"].to(device),
            b["tta_means"].to(device),
            b["tta_vars"].to(device),
        )
        # Upcast before .numpy(): NumPy has no bfloat16, so a bf16 output
        # (possible with use_bf16=True) cannot be converted directly.
        # .float() is a no-op for tensors that are already float32.
        logit_chunks.append(lg.float().cpu().numpy())
        labels += b["action_labels"].tolist()
        categories += b["categories"]
        ttas += b["tta_raws"].tolist()
        video_ids += b["video_ids"]

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not logit_chunks:
        # np.concatenate([]) would fail with an opaque message.
        raise ValueError(f"No samples for split={split!r} (labels: {label_json})")

    return {
        "logits": np.concatenate(logit_chunks).astype(np.float32),
        "labels": np.array(labels, dtype=np.int64),
        "categories": np.array(categories),
        "ttas": np.array(ttas, dtype=np.float32),
        "video_ids": np.array(video_ids),
    }
def load_or_extract(cache_dir: Path, split: str,
                    sft_ckpt, policy_ckpt, label_dir, belief_dir, device,
                    force: bool = False) -> Dict[str, np.ndarray]:
    """Return logits + meta for ``split``, from cache when available.

    The cache file is ``cache_dir / f"{split}_lkalert.npz"``. On a miss (or
    when ``force`` is set) logits are extracted with ``extract_logits`` from
    ``label_dir/{split}.json`` and ``belief_dir/{split}.pt``, then saved there.

    The cache file name does not encode the checkpoints. After changing either
    checkpoint, pass ``force=True`` so stale logits are not reused.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / f"{split}_lkalert.npz"
    if cache_file.exists() and not force:
        logger.info(f"Loading cached logits from {cache_file}")
        # The context manager closes the NpzFile handle. Indexing each key
        # inside the block materializes the arrays before the file closes.
        # allow_pickle may be needed for object-dtype arrays (e.g. mixed-type
        # ids). Only load caches this script wrote itself.
        with np.load(cache_file, allow_pickle=True) as npz:
            return {k: npz[k] for k in npz.files}
    data = extract_logits(
        Path(sft_ckpt), Path(policy_ckpt),
        Path(label_dir) / f"{split}.json",
        Path(belief_dir) / f"{split}.pt",
        split, device,
    )
    np.savez_compressed(cache_file, **data)
    logger.info(f"Saved logits cache → {cache_file} (shape={data['logits'].shape})")
    return data
# ────────────────────────────── Grid search ───────────────────────────────────
def stratified_split_indices(labels: np.ndarray, frac: float, seed: int = 0):
    """Stratified index split of ``labels``.

    Returns ``(idx_A, idx_B)``, where ``idx_A`` holds ``round(frac * n_c)``
    shuffled indices of each class ``c`` (``n_c`` samples) and ``idx_B`` the
    rest. The split is deterministic for a given ``seed``. Empty ``labels``
    yields two empty integer arrays instead of raising.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    idx_A, idx_B = [], []
    for c in np.unique(labels):
        pool = np.where(labels == c)[0]
        rng.shuffle(pool)
        k = int(round(frac * len(pool)))
        idx_A.append(pool[:k])
        idx_B.append(pool[k:])
    if not idx_A:
        # np.concatenate([]) raises; an empty input simply has empty splits.
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()
    return np.concatenate(idx_A), np.concatenate(idx_B)
def _apply(logits, video_ids, ttas, calib, smooth_alpha, non_ego_alpha, alert_bias):
"""Apply one (calibration, smoothing, non_ego_bias, global_bias) combo."""
lg = logits.astype(np.float32).copy()
kind = calib["kind"]
if kind == "global_T":
lg = lg / float(calib["T"])
elif kind == "per_class_T":
lg = apply_per_class_temperature(lg, calib["T"])
# kind == "raw": no-op
if smooth_alpha < 1.0: # only when EMA actually mixes past frames
lg = temporal_smooth(lg, video_ids, ttas,
window=1, mode="ema", alpha=smooth_alpha)
if non_ego_alpha > 0.0:
lg = apply_non_ego_bias(lg, alpha=non_ego_alpha)
if alert_bias != 0.0:
lg[:, 2] = lg[:, 2] + alert_bias
return lg
def grid_search(val: Dict[str, np.ndarray], calib_opts: List[dict],
                smooth_alphas: List[float] | None = None,
                non_ego_alphas: List[float] | None = None,
                biases: List[float] | None = None) -> List[dict]:
    """
    Sweep (calibration × smooth_alpha × non_ego_alpha × alert_bias).

    Args:
        val: split dict with "logits", "labels", "categories", "ttas" and
            "video_ids" arrays.
        calib_opts: calibration options, each with a "name" and the fields
            ``_apply`` reads ("kind", optionally "T").
        smooth_alphas: EMA alphas to try (1.0 = no smoothing). Defaults to
            [1.0, 0.8, 0.6, 0.4].
        non_ego_alphas: non-ego bias strengths. Defaults to
            [0.0, 0.5, 1.0, 1.5, 2.0, 3.0].
        biases: additive ALERT biases. Defaults to 26 evenly spaced values
            from -1.0 to 1.5, rounded to 3 decimals.

    Returns:
        List of {cfg, metrics} dicts, one per combination, with calibration as
        the outermost and bias as the innermost axis. No filtering is applied.
    """
    if smooth_alphas is None:
        smooth_alphas = [1.0, 0.8, 0.6, 0.4]  # 1.0 = no smoothing
    if non_ego_alphas is None:
        non_ego_alphas = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    if biases is None:
        biases = np.round(np.linspace(-1.0, 1.5, 26), 3).tolist()
    total = len(calib_opts) * len(smooth_alphas) * len(non_ego_alphas) * len(biases)
    logger.info(f" Grid size: {total} configurations")

    # Hoist the loop-invariant inputs out of the sweep.
    logits, video_ids, ttas = val["logits"], val["video_ids"], val["ttas"]
    labels, categories = val["labels"], val["categories"]

    results = []
    # product() iterates in the same order as the equivalent nested loops.
    for calib, sa, na, ab in product(calib_opts, smooth_alphas, non_ego_alphas, biases):
        lg = _apply(logits, video_ids, ttas, calib, sa, na, ab)
        m = compute_metrics(lg, labels, categories, ttas, video_ids)
        results.append({
            "cfg": {"calib": calib["name"],
                    "smooth_alpha": sa,
                    "non_ego_alpha": na,
                    "alert_bias": ab},
            "metrics": m,
        })
    return results
def filter_configs(results: List[dict],
                   recall_min: float, fa_max: float, burden_max: float) -> List[dict]:
    """Keep the grid results that meet all three operating constraints.

    An entry passes when its metrics have ego_alert_recall >= recall_min,
    safe_neg_alert_leak <= fa_max and burden_non_ego <= burden_max.
    Input order is preserved, and a new list of the same entry objects is
    returned.
    """
    def _feasible(entry: dict) -> bool:
        met = entry["metrics"]
        return (met["ego_alert_recall"] >= recall_min
                and met["safe_neg_alert_leak"] <= fa_max
                and met["burden_non_ego"] <= burden_max)

    return list(filter(_feasible, results))
def pick_recommended(all_cfgs: List[dict],
                     recall_min: float, fa_max: float, burden_max: float,
                     raw_metrics: dict) -> Dict[str, dict]:
    """
    Return up to three recommended configs, keyed by variant name:

      balanced    : highest PolicyScore under all three constraints
      low_fa      : recall >= recall_min only; lowest FA, ties -> higher recall
      high_recall : FA <= fa_max and burden <= burden_max only; highest
                    recall, ties -> lower FA

    Each value is one of the {cfg, metrics} entries of ``all_cfgs``. If no
    entry qualifies for a variant, that key is omitted (no None placeholder),
    so the result may be empty. Any remaining ties resolve to the earliest
    entry in ``all_cfgs``. Keys appear in the order balanced, low_fa,
    high_recall.

    ``raw_metrics`` is accepted for interface compatibility but is unused.
    """
    out: Dict[str, dict] = {}

    # min()/max() return the first extremal entry, which is exactly what the
    # head of a stable sort on the same key would be. This avoids an
    # O(n log n) sort and leaves the pools untouched.

    # balanced — all constraints + max PolicyScore
    balanced_pool = filter_configs(all_cfgs, recall_min, fa_max, burden_max)
    if balanced_pool:
        out["balanced"] = max(balanced_pool,
                              key=lambda r: r["metrics"]["policy_score"])

    # low_fa — just the recall constraint, lowest FA, tiebreak by recall
    fa_pool = [r for r in all_cfgs if r["metrics"]["ego_alert_recall"] >= recall_min]
    if fa_pool:
        out["low_fa"] = min(fa_pool,
                            key=lambda r: (r["metrics"]["safe_neg_alert_leak"],
                                           -r["metrics"]["ego_alert_recall"]))

    # high_recall — FA & burden bounded, max recall, tiebreak by FA
    hr_pool = [r for r in all_cfgs
               if r["metrics"]["safe_neg_alert_leak"] <= fa_max
               and r["metrics"]["burden_non_ego"] <= burden_max]
    if hr_pool:
        out["high_recall"] = min(hr_pool,
                                 key=lambda r: (-r["metrics"]["ego_alert_recall"],
                                                r["metrics"]["safe_neg_alert_leak"]))
    return out
# ────────────────────────────── Driver-intuitive metrics ──────────────────────
def driver_metrics(m: dict, fps: float = 1.0, window_s: float = 1.0) -> dict:
    """
    Translate raw metrics into figures a reviewer/engineer can grasp.

    Assumes each sample covers ``window_s`` seconds of driving (default 1 s),
    so a per-sample rate times 3600 / window_s is a per-hour count.

    Returned keys:
      - crashes_caught_per_100 = 100 × ego_alert_recall (1 decimal)
      - false_alerts_per_hour  = safe_neg_alert_leak × (3600 / window_s)
      - non_ego_noise_per_hour = burden_non_ego × (3600 / window_s)
      - silent_when_safe       = 100 × safe_neg_silent (percent, 1 decimal)
      - mean_lead_seconds      = lead_time_mean (OBSERVE ∪ ALERT)
      - alert_lead_seconds     = alert_lead_time_mean (ALERT only)
      - warning_coverage_pct   = 100 × lead_time_coverage
                                 (percent of collision videos warned)
      - alert_coverage_pct     = 100 × alert_lead_time_coverage
    The per-hour counts are rounded to whole numbers (returned as floats),
    and the lead times to 2 decimals.

    ``fps`` is accepted for backward compatibility but is not used.

    Raises:
        ValueError: if ``window_s`` is not positive.
    """
    if window_s <= 0:
        raise ValueError(f"window_s must be positive, got {window_s!r}")
    per_hour = 3600.0 / window_s
    return {
        "crashes_caught_per_100": round(100.0 * m["ego_alert_recall"], 1),
        "false_alerts_per_hour": round(m["safe_neg_alert_leak"] * per_hour, 0),
        "non_ego_noise_per_hour": round(m["burden_non_ego"] * per_hour, 0),
        "silent_when_safe": round(100.0 * m["safe_neg_silent"], 1),
        "mean_lead_seconds": round(m["lead_time_mean"], 2),
        "alert_lead_seconds": round(m["alert_lead_time_mean"], 2),
        "warning_coverage_pct": round(100.0 * m["lead_time_coverage"], 1),
        "alert_coverage_pct": round(100.0 * m["alert_lead_time_coverage"], 1),
    }
# ────────────────────────────── Reports ───────────────────────────────────────
# Summary-table columns, in display order:
# (key in a compute_metrics() result, column header).
# "↑" marks higher-is-better, "↓" lower-is-better.
METRICS_KEYS = [
    ("policy_score",        "PolicyScore↑"),
    ("ego_alert_recall",    "EgoRec↑"),
    ("safe_neg_alert_leak", "FA↓"),
    ("burden_non_ego",      "Burden↓"),
    ("binary_ap",           "AP↑"),
    ("binary_f1",           "F1↑"),
    ("lead_time_mean",      "Lead(s)↑"),
    ("lead_time_coverage",  "Cov↑"),
]
def print_row(name: str, m: dict, star: bool = False, keys=None):
    """Print one fixed-width table row.

    The row has a 30-char label column followed by one 8-char cell per metric.
    Rate-like metrics (keys containing "coverage"/"recall"/"noalert", plus
    FA, burden, AP, F1 and PolicyScore) get 3 decimals; the rest get 2.

    Args:
        name: row label.
        m: metrics dict. Missing keys print as 0.0.
        star: prefix the label with "★ " to highlight the row.
        keys: optional sequence of (metric_key, header) pairs. Defaults to
            METRICS_KEYS, so existing callers are unchanged.
    """
    if keys is None:
        keys = METRICS_KEYS
    # Both prefixes are two characters wide so starred and plain rows align.
    label = ("★ " + name) if star else ("  " + name)
    three_dp = ("safe_neg_alert_leak", "burden_non_ego",
                "binary_ap", "binary_f1", "policy_score")
    cells = [f"{label:<30}"]
    for k, _ in keys:
        v = m.get(k, 0.0)
        precise = ("coverage" in k or "recall" in k or "noalert" in k
                   or k in three_dp)
        cells.append(f"{v:>8.3f}" if precise else f"{v:>8.2f}")
    print(" ".join(cells))
def print_comparison(before: dict, after: dict, baselines: Dict[str, dict]):
    """Print LKAlert raw vs. optimized metrics, then every non-LKAlert baseline.

    Rows are rendered with print_row over METRICS_KEYS. main() in this file
    prints its own recommended-points table instead of calling this.
    """
    rule = "=" * 104
    columns = " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    header = f"{'Method':<30}" + columns
    divider = "-" * len(header)

    print("\n" + rule)
    print("POST-PROCESSING OPTIMIZATION RESULTS (val set)")
    print(rule)
    print(header)
    print(divider)
    print_row("LKAlert (baseline, raw)", before)
    print_row("LKAlert (optimized)", after, star=True)
    print(divider)
    others = ((n, mets) for n, mets in baselines.items() if "LKAlert" not in n)
    for baseline_name, baseline_metrics in others:
        print_row(baseline_name, baseline_metrics)
    print(rule)
def print_driver_metrics(before: dict, after: dict,
                         baselines: Dict[str, dict]):
    """Print the driver-intuitive table for raw LKAlert, optimized LKAlert and
    every non-LKAlert baseline.

    Rows come from driver_metrics(). Baseline entries only supply the
    observe_* lead-time/coverage fields, which stand in for both the general
    and the ALERT-only lead-time columns; their non-ego burden is shown as 0.
    """
    print("\n" + "=" * 90)
    print("DRIVER-INTUITIVE METRICS (assume each sample = 1 s of driving)")
    print("=" * 90)
    cols = [
        ("crashes_caught_per_100", "Crashes/100"),
        ("alert_lead_seconds", "AlertLead(s)"),
        ("alert_coverage_pct", "AlertCov(%)"),
        ("false_alerts_per_hour", "FA/hour"),
        ("non_ego_noise_per_hour", "Noise/hour"),
        ("silent_when_safe", "Silent(%)"),
    ]
    hdr = f"{'Method':<30}" + " ".join(f"{title:>12}" for _, title in cols)
    print(hdr)
    print("-" * len(hdr))

    def _row(name, dm, star=False):
        # Both prefixes are two characters wide so starred and plain rows align.
        label = ("★ " + name) if star else ("  " + name)
        print(f"{label:<30}" + " ".join(
            f"{dm.get(k, 0.0):>12}" for k, _ in cols))

    _row("LKAlert (baseline, raw)", driver_metrics(before))
    _row("LKAlert (optimized)", driver_metrics(after), star=True)
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        dm = driver_metrics({
            "ego_alert_recall": m.get("ego_alert_recall", 0),
            "safe_neg_alert_leak": m.get("safe_neg_alert_leak", 0),
            "burden_non_ego": 0.0,  # not in baselines JSON directly
            "safe_neg_silent": m.get("safe_neg_silent", 0),
            "lead_time_mean": m.get("observe_lead_time_s", 0),
            "alert_lead_time_mean": m.get("observe_lead_time_s", 0),
            "lead_time_coverage": m.get("observe_coverage", 0),
            "alert_lead_time_coverage": m.get("observe_coverage", 0),
        })
        _row(name, dm)
    print("=" * 90)
# ────────────────────────────── Main ──────────────────────────────────────────
def _json_default(obj):
    """json.dumps fallback: convert NumPy scalars/arrays to built-in types."""
    # np.int64 / np.float32 etc. are not JSON-serializable on their own and
    # may appear in metric dicts.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _baseline_driver_inputs(m: dict) -> dict:
    """Map one baseline entry of the baseline JSON onto driver_metrics() inputs."""
    return {
        "ego_alert_recall": m.get("ego_alert_recall", 0),
        "safe_neg_alert_leak": m.get("safe_neg_alert_leak", 0),
        "burden_non_ego": 0.0,  # not in baselines JSON directly
        "safe_neg_silent": m.get("safe_neg_silent", 0),
        # Only observe_* lead-time fields are read from the baseline entry, so
        # they stand in for both general and ALERT-only lead-time columns.
        "lead_time_mean": m.get("observe_lead_time_s", 0),
        "alert_lead_time_mean": m.get("observe_lead_time_s", 0),
        "lead_time_coverage": m.get("observe_coverage", 0),
        "alert_lead_time_coverage": m.get("observe_coverage", 0),
    }


def _print_recommended(raw: dict, recommended: Dict[str, dict], best_tag: str,
                       baselines: Dict[str, dict]) -> None:
    """Print raw vs. recommended operating points vs. non-LKAlert, non-Random baselines."""
    print("\n" + "=" * 104)
    print("RECOMMENDED OPERATING POINTS (re-evaluated on FULL val)")
    print("=" * 104)
    # Width 30 matches the label column emitted by print_row().
    header = f"{'Variant':<30}" + " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    print(header)
    print("-" * len(header))
    print_row("raw (current default)", raw)
    for tag in ("low_fa", "balanced", "high_recall"):
        if tag in recommended:
            print_row(f"opt: {tag}", recommended[tag]["metrics"],
                      star=(tag == best_tag))
    print("-" * len(header))
    for name, m in baselines.items():
        if "LKAlert" in name or "Random" in name:
            continue
        print_row(name, m)
    print("=" * 104)


def _print_top10(kept: List[dict]) -> None:
    """Print the first ten entries of ``kept`` (recall-feasible, sorted by PolicyScore)."""
    print("\n" + "=" * 100)
    print("TOP-10 CANDIDATE CONFIGS (ranked by PolicyScore on 80% val eval split)")
    print("=" * 100)
    print(f"{'calib':<20} {'sm_α':>6} {'ne_α':>6} {'bias':>6} "
          f"{'Policy↑':>8} {'Rec↑':>6} {'FA↓':>6} {'Bur↓':>6} {'AP↑':>6} {'F1↑':>6}")
    for r in kept[:10]:
        c, m = r["cfg"], r["metrics"]
        print(f"{c['calib']:<20} {c['smooth_alpha']:>6.2f} {c['non_ego_alpha']:>6.2f} "
              f"{c['alert_bias']:>+6.2f} "
              f"{m['policy_score']:>8.4f} {m['ego_alert_recall']:>6.3f} "
              f"{m['safe_neg_alert_leak']:>6.3f} {m['burden_non_ego']:>6.3f} "
              f"{m['binary_ap']:>6.3f} {m['binary_f1']:>6.3f}")


def main():
    """CLI entry point: extract logits, grid-search post-processing, report, save."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--lkalert_ckpt", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default="data/belief_cache")
    parser.add_argument("--baseline_json",
                        default="eval_results/paper_comparison/all_results.json")
    parser.add_argument("--output_dir",
                        default="eval_results/paper_comparison")
    parser.add_argument("--recall_min", type=float, default=0.65,
                        help="Minimum ego-alert recall when picking the best config")
    parser.add_argument("--fa_max", type=float, default=0.30,
                        help="Maximum allowed false-alarm rate on safe_neg samples "
                             "(default 0.30 — comparable to strongest baseline)")
    parser.add_argument("--burden_max", type=float, default=0.25,
                        help="Maximum allowed non-ego alert burden "
                             "(default 0.25 — keeps driver noise bounded)")
    parser.add_argument("--force_extract", action="store_true",
                        help="Re-extract logits even if cached")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = out_dir / "logits_cache"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── 1. Extract logits ─────────────────────────────────────────────────────
    logger.info("── Step 1: extracting LKAlert logits (train + val) ──")
    train = load_or_extract(cache_dir, "train",
                            args.sft_checkpoint, args.lkalert_ckpt,
                            args.label_dir, args.belief_cache_dir,
                            device, force=args.force_extract)
    val = load_or_extract(cache_dir, "val",
                          args.sft_checkpoint, args.lkalert_ckpt,
                          args.label_dir, args.belief_cache_dir,
                          device, force=args.force_extract)

    # ── 2. Fit per-class T on two candidate calibration sets ─────────────────
    logger.info("── Step 2: fitting per-class temperatures ──")
    # NOTE: the train-fit temperatures are only logged; the grid below uses
    # the val-20%-fit temperatures.
    T_pc_train = fit_per_class_temperature(train["logits"], train["labels"])
    logger.info(f" T(train-fit) = (SIL={T_pc_train[0]:.3f}, "
                f"OBS={T_pc_train[1]:.3f}, ALT={T_pc_train[2]:.3f})")
    # 20% stratified split of val → used ONLY for calibration fitting
    fit_idx, eval_idx = stratified_split_indices(val["labels"], frac=0.2)
    T_pc_valfit = fit_per_class_temperature(
        val["logits"][fit_idx], val["labels"][fit_idx])
    logger.info(f" T(val20%-fit) = (SIL={T_pc_valfit[0]:.3f}, "
                f"OBS={T_pc_valfit[1]:.3f}, ALT={T_pc_valfit[2]:.3f}) "
                f"(fit={len(fit_idx)}, eval={len(eval_idx)})")

    # ── 3. Baseline (raw) metrics on FULL val ─────────────────────────────────
    raw = compute_metrics(val["logits"], val["labels"], val["categories"],
                          val["ttas"], val["video_ids"])
    logger.info(f" Raw val PolicyScore = {raw['policy_score']:.4f} "
                f"Recall = {raw['ego_alert_recall']:.3f} "
                f"FA = {raw['safe_neg_alert_leak']:.3f}")

    # ── 4. Grid search over calibration × smoothing × non_ego × bias ──────────
    # Eval only on the 80% held-out portion to avoid fit/eval leakage
    eval_view = {k: val[k][eval_idx]
                 for k in ("logits", "labels", "categories", "ttas", "video_ids")}
    calib_opts = [
        {"name": "raw", "kind": "raw"},
        {"name": "global_T=0.6", "kind": "global_T", "T": 0.6},
        {"name": "global_T=0.8", "kind": "global_T", "T": 0.8},
        {"name": "global_T=1.2", "kind": "global_T", "T": 1.2},
        {"name": "global_T=1.5", "kind": "global_T", "T": 1.5},
        {"name": "per_class_val20%", "kind": "per_class_T", "T": T_pc_valfit},
    ]
    logger.info(f"── Step 4: grid search on 80% val ({len(eval_idx)} samples) ──")
    logger.info(f" Constraints: recall ≥ {args.recall_min}, "
                f"FA ≤ {args.fa_max}, burden ≤ {args.burden_max}")
    all_cfgs = grid_search(eval_view, calib_opts)
    recs_80 = pick_recommended(all_cfgs,
                               args.recall_min, args.fa_max, args.burden_max,
                               raw)
    if not recs_80:
        logger.warning("No configuration satisfies any constraint set; "
                       "falling back to the closest match.")
        # max() returns the first top-scoring entry, same as a stable sort head.
        recs_80 = {"balanced": max(all_cfgs,
                                   key=lambda r: r["metrics"]["policy_score"])}

    # ── 4b. Re-evaluate each recommended config on FULL val ──────────────────
    # NOTE(review): full val includes the 20% slice used to fit the
    # per_class_val20% temperatures, so those configs see some fit/eval overlap.
    name_to_calib = {c["name"]: c for c in calib_opts}
    recommended: Dict[str, dict] = {}
    for tag, r in recs_80.items():
        cfg = r["cfg"]
        lg = _apply(val["logits"], val["video_ids"], val["ttas"],
                    name_to_calib[cfg["calib"]],
                    cfg["smooth_alpha"], cfg["non_ego_alpha"],
                    cfg["alert_bias"])
        m = compute_metrics(lg, val["labels"], val["categories"],
                            val["ttas"], val["video_ids"])
        recommended[tag] = {"cfg": cfg, "metrics": m}
        logger.info(f" [{tag}] cfg={cfg} → "
                    f"Policy={m['policy_score']:.4f} Rec={m['ego_alert_recall']:.3f} "
                    f"FA={m['safe_neg_alert_leak']:.3f} Bur={m['burden_non_ego']:.3f}")

    # Pick "balanced" if present else first key as the headline "optimized" result
    best_tag = "balanced" if "balanced" in recommended else next(iter(recommended))
    best = recommended[best_tag]

    # `kept` retained for top-10 display (under recall only, sorted by PolicyScore)
    kept = [r for r in all_cfgs if r["metrics"]["ego_alert_recall"] >= args.recall_min]
    kept.sort(key=lambda r: -r["metrics"]["policy_score"])

    # ── 5. Load baselines JSON for comparison ─────────────────────────────────
    baselines: Dict[str, dict] = {}
    bj = Path(args.baseline_json)
    if bj.exists():
        baselines = json.loads(bj.read_text(encoding="utf-8"))
    else:
        logger.warning(f"Baseline JSON {bj} not found; skipping cross-model comparison")

    # ── 6. Print comparison tables ────────────────────────────────────────────
    _print_recommended(raw, recommended, best_tag, baselines)
    # Driver-intuitive metrics use the headline (best) config
    print_driver_metrics(raw, best["metrics"], baselines)
    _print_top10(kept)

    # ── 7. Save outputs ───────────────────────────────────────────────────────
    result = {
        "constraints": {"recall_min": args.recall_min,
                        "fa_max": args.fa_max,
                        "burden_max": args.burden_max},
        "raw_metrics": raw,
        "recommended": {tag: {"cfg": dict(r["cfg"]), "metrics": r["metrics"]}
                        for tag, r in recommended.items()},
        "headline_variant": best_tag,
        "top10_by_policy_score": [{"cfg": r["cfg"],
                                   "policy_score": r["metrics"]["policy_score"],
                                   "recall": r["metrics"]["ego_alert_recall"],
                                   "fa": r["metrics"]["safe_neg_alert_leak"],
                                   "burden": r["metrics"]["burden_non_ego"]}
                                  for r in kept[:10]],
    }
    best_path = out_dir / "postproc_best.json"
    best_path.write_text(json.dumps(result, indent=2, default=_json_default),
                         encoding="utf-8")
    logger.info(f"Saved → {best_path}")

    driver = {
        "LKAlert (baseline, raw)": driver_metrics(raw),
        "LKAlert (optimized)": driver_metrics(best["metrics"]),
    }
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        driver[name] = driver_metrics(_baseline_driver_inputs(m))
    driver_path = out_dir / "driver_metrics.json"
    driver_path.write_text(json.dumps(driver, indent=2, default=_json_default),
                           encoding="utf-8")
    logger.info(f"Saved → {driver_path}")


if __name__ == "__main__":
    main()