| |
| """ |
Optimize LKAlert post-processing hyper-parameters (per-class T, temporal
smoothing, non-ego bias, global ALERT bias) to maximize PolicyScore subject to
recall constraints — without any retraining.

Pipeline
────────
  1. Extract raw LKAlert logits for train + val splits (cached to disk).
  2. Fit per-class temperature on train logits.
  3. Grid-search (smooth_alpha, non_ego_alpha, alert_bias) on val.
  4. Pick best operating point under constraint recall_ego ≥ R_min.
  5. Print a before/after table vs baselines (read from existing all_results.json).
  6. Emit driver-intuitive metrics (crashes caught / FA per hour / seconds saved).
  7. Save best config + metrics JSON.

Outputs
───────
  eval_results/paper_comparison/postproc_best.json   — chosen config + metrics
  eval_results/paper_comparison/logits_cache/        — cached logits
  eval_results/paper_comparison/driver_metrics.json  — intuitive metrics

Usage
─────
    python -m training.Policy.optimize_postproc \
        --sft_checkpoint   checkpoints/SFT/sft_v2/best \
        --lkalert_ckpt     checkpoints/Policy/policy_warmstart_v3/best \
        --label_dir        data/policy_labels \
        --belief_cache_dir data/belief_cache \
        --baseline_json    eval_results/paper_comparison/all_results.json \
        --output_dir       eval_results/paper_comparison \
        --recall_min       0.65
| """ |
from __future__ import annotations

import argparse
import json
import logging
import sys
from itertools import product
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Put the directory two levels above this file's folder on sys.path so the
# absolute `training.*` imports below resolve when the file is run directly.
# (Presumably this is the repository root -- confirm against the repo layout.)
ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(ROOT))

from training.Policy.policy_model import PolicyModel
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
from training.Policy.postproc import (
    fit_per_class_temperature,
    apply_per_class_temperature,
    temporal_smooth,
    non_ego_bias as apply_non_ego_bias,
    compute_metrics,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("Policy.optimize_postproc")
|
|
|
|
| |
@torch.no_grad()
def extract_logits(
    sft_ckpt: Path, policy_ckpt: Path,
    label_json: Path, cache_pt: Path,
    split: str, device,
) -> Dict[str, np.ndarray]:
    """Run PolicyModel once over a split and collect logits + metadata as numpy arrays.

    Args:
        sft_ckpt: Checkpoint path handed to the ``PolicyModel`` constructor.
        policy_ckpt: Checkpoint path handed to ``load_policy_checkpoint``.
        label_json: Manifest passed to ``PolicyDataset`` (as a one-element list).
        cache_pt: Belief cache path passed to ``PolicyDataset``.
        split: Split name; forwarded to the dataset and shown in the progress bar.
        device: Device the per-batch input tensors are moved to.

    Returns:
        Dict with keys ``logits`` (float32), ``labels`` (int64), ``categories``,
        ``ttas`` (float32) and ``video_ids``, in loader order (no shuffling).

    Raises:
        ValueError: If the loader yields no batches.
    """
    dataset = PolicyDataset(manifests=[label_json], split=split,
                            belief_cache_path=cache_pt)
    loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=4,
                        collate_fn=policy_collate_fn)

    model = PolicyModel(str(sft_ckpt), use_bf16=True)
    model.load_policy_checkpoint(str(policy_ckpt))
    model.eval()

    logit_chunks, labels, categories, ttas, video_ids = [], [], [], [], []
    for batch in tqdm(loader, desc=f"extract[{split}]"):
        batch_logits = model.forward_cached(
            batch["beliefs"].to(device),
            batch["tta_means"].to(device),
            batch["tta_vars"].to(device),
        )
        # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on bf16 output;
        # upcast to float32 first (a no-op when the logits are already fp32).
        logit_chunks.append(batch_logits.float().cpu().numpy())
        labels.extend(batch["action_labels"].tolist())
        categories.extend(batch["categories"])
        ttas.extend(batch["tta_raws"].tolist())
        video_ids.extend(batch["video_ids"])

    # Drop the model before returning so its memory can be reclaimed.
    del model
    torch.cuda.empty_cache()

    if not logit_chunks:
        raise ValueError(f"No samples found for split {split!r} ({label_json})")

    return {
        "logits": np.concatenate(logit_chunks).astype(np.float32),
        "labels": np.array(labels, dtype=np.int64),
        "categories": np.array(categories),
        "ttas": np.array(ttas, dtype=np.float32),
        "video_ids": np.array(video_ids),
    }
|
|
|
|
def load_or_extract(cache_dir: Path, split: str,
                    sft_ckpt, policy_ckpt, label_dir, belief_dir, device,
                    force: bool = False) -> Dict[str, np.ndarray]:
    """Return logits + metadata for ``split``, reading the on-disk cache when possible.

    The cache lives at ``cache_dir / f"{split}_lkalert.npz"``. When it is
    missing, or ``force`` is true, the arrays are produced by
    ``extract_logits`` from ``label_dir/{split}.json`` and
    ``belief_dir/{split}.pt`` and then written to the cache.

    Args:
        cache_dir: Directory holding the ``.npz`` caches (created if absent).
        split: Split name used to build the cache, label and belief file names.
        sft_ckpt, policy_ckpt: Checkpoint paths forwarded to ``extract_logits``.
        label_dir, belief_dir: Directories containing the per-split inputs.
        device: Device forwarded to ``extract_logits``.
        force: Ignore an existing cache file and re-extract.

    Returns:
        Dict of numpy arrays keyed like the ``extract_logits`` result.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / f"{split}_lkalert.npz"

    if cache_file.exists() and not force:
        logger.info(f"Loading cached logits from {cache_file}")
        # NpzFile keeps the archive open and loads members lazily; read every
        # array inside the `with` so the file handle is closed afterwards.
        # NOTE(security): allow_pickle=True can execute code embedded in the
        # file -- only point this at caches written by this script.
        with np.load(cache_file, allow_pickle=True) as npz:
            return {key: npz[key] for key in npz.files}

    data = extract_logits(
        Path(sft_ckpt), Path(policy_ckpt),
        Path(label_dir) / f"{split}.json",
        Path(belief_dir) / f"{split}.pt",
        split, device,
    )
    np.savez_compressed(cache_file, **data)
    logger.info(f"Saved logits cache → {cache_file} (shape={data['logits'].shape})")
    return data
|
|
|
|
| |
def stratified_split_indices(labels: np.ndarray, frac: float, seed: int = 0):
    """Split sample indices into two disjoint groups, stratified by class.

    For every distinct value in ``labels`` the matching indices are shuffled
    with a generator seeded by ``seed``; the first ``round(frac * count)`` go
    to the first group and the rest to the second.

    Args:
        labels: 1-D array of class labels.
        frac: Fraction of each class assigned to the first group, in [0, 1].
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(idx_A, idx_B)`` index arrays. They are grouped by class and shuffled
        within each class, i.e. NOT sorted in original sample order.

    Raises:
        ValueError: If ``frac`` is outside [0, 1].
    """
    if not 0.0 <= frac <= 1.0:
        # A negative frac would make pool[:k] slice from the end of the pool.
        raise ValueError(f"frac must be in [0, 1], got {frac}")

    labels = np.asarray(labels)
    if labels.size == 0:
        # np.concatenate rejects an empty list, so answer the empty case directly.
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()

    rng = np.random.default_rng(seed)
    idx_A, idx_B = [], []
    for cls in np.unique(labels):
        pool = np.where(labels == cls)[0]
        rng.shuffle(pool)
        k = int(round(frac * len(pool)))
        idx_A.append(pool[:k])
        idx_B.append(pool[k:])
    return np.concatenate(idx_A), np.concatenate(idx_B)
|
|
|
|
| def _apply(logits, video_ids, ttas, calib, smooth_alpha, non_ego_alpha, alert_bias): |
| """Apply one (calibration, smoothing, non_ego_bias, global_bias) combo.""" |
| lg = logits.astype(np.float32).copy() |
|
|
| kind = calib["kind"] |
| if kind == "global_T": |
| lg = lg / float(calib["T"]) |
| elif kind == "per_class_T": |
| lg = apply_per_class_temperature(lg, calib["T"]) |
| |
|
|
| if smooth_alpha < 1.0: |
| lg = temporal_smooth(lg, video_ids, ttas, |
| window=1, mode="ema", alpha=smooth_alpha) |
| if non_ego_alpha > 0.0: |
| lg = apply_non_ego_bias(lg, alpha=non_ego_alpha) |
| if alert_bias != 0.0: |
| lg[:, 2] = lg[:, 2] + alert_bias |
| return lg |
|
|
|
|
def grid_search(val: Dict[str, np.ndarray], calib_opts: List[dict]) -> List[dict]:
    """Sweep calibration × smooth_alpha × non_ego_alpha × alert_bias.

    Args:
        val: Dict with ``logits``, ``labels``, ``categories``, ``ttas`` and
            ``video_ids`` arrays for the samples to evaluate on.
        calib_opts: Calibration option dicts; each needs ``name`` and ``kind``
            (plus ``T`` where the kind requires it).

    Returns:
        One ``{"cfg": ..., "metrics": ...}`` dict per grid point, in sweep
        order (calibration outermost, alert bias innermost). No filtering is
        applied here.
    """
    smooth_alphas = [1.0, 0.8, 0.6, 0.4]
    non_ego_alphas = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    biases = np.round(np.linspace(-1.0, 1.5, 26), 3).tolist()

    total = len(calib_opts) * len(smooth_alphas) * len(non_ego_alphas) * len(biases)
    logger.info(f" Grid size: {total} configurations")

    results = []
    # product() iterates its last argument fastest, which reproduces the order
    # of the equivalent four nested loops.
    for calib, sa, na, ab in product(calib_opts, smooth_alphas, non_ego_alphas, biases):
        lg = _apply(val["logits"], val["video_ids"], val["ttas"],
                    calib, sa, na, ab)
        metrics = compute_metrics(lg, val["labels"], val["categories"],
                                  val["ttas"], val["video_ids"])
        results.append({
            "cfg": {
                "calib": calib["name"],
                "smooth_alpha": sa,
                "non_ego_alpha": na,
                "alert_bias": ab,
            },
            "metrics": metrics,
        })
    return results
|
|
|
|
def filter_configs(results: List[dict],
                   recall_min: float, fa_max: float, burden_max: float) -> List[dict]:
    """Keep the results that satisfy all three operating constraints.

    A result is kept when its ``ego_alert_recall`` is at least ``recall_min``,
    its ``safe_neg_alert_leak`` is at most ``fa_max`` and its
    ``burden_non_ego`` is at most ``burden_max`` (bounds are inclusive).
    Input order is preserved.
    """
    kept = []
    for result in results:
        metrics = result["metrics"]
        if (metrics["ego_alert_recall"] >= recall_min
                and metrics["safe_neg_alert_leak"] <= fa_max
                and metrics["burden_non_ego"] <= burden_max):
            kept.append(result)
    return kept
|
|
|
|
def pick_recommended(all_cfgs: List[dict],
                     recall_min: float, fa_max: float, burden_max: float,
                     raw_metrics: dict) -> Dict[str, dict]:
    """Return up to three recommended operating points.

    Variants (inserted in this order when they exist):
      balanced    : highest ``policy_score`` among configs meeting all three
                    constraints.
      low_fa      : among configs with recall >= ``recall_min``, the lowest
                    ``safe_neg_alert_leak`` (ties broken by higher recall).
      high_recall : among configs with FA <= ``fa_max`` and burden <=
                    ``burden_max``, the highest recall (ties broken by lower FA).

    Args:
        all_cfgs: ``{"cfg", "metrics"}`` dicts, e.g. from ``grid_search``.
        recall_min, fa_max, burden_max: Inclusive constraint bounds.
        raw_metrics: Accepted for interface compatibility; not used here.

    Returns:
        Dict mapping variant name to the chosen ``{"cfg", "metrics"}`` dict.
        A variant with no qualifying config is omitted (the key is absent, not
        None), so the result may be empty. Remaining ties go to the config
        that appears first in ``all_cfgs``.
    """
    def meets_recall(r: dict) -> bool:
        return r["metrics"]["ego_alert_recall"] >= recall_min

    def within_noise_budget(r: dict) -> bool:
        m = r["metrics"]
        return m["safe_neg_alert_leak"] <= fa_max and m["burden_non_ego"] <= burden_max

    fa_pool = [r for r in all_cfgs if meets_recall(r)]
    hr_pool = [r for r in all_cfgs if within_noise_budget(r)]
    # All three constraints = the noise-budget pool narrowed by recall.
    balanced_pool = [r for r in hr_pool if meets_recall(r)]

    out: Dict[str, dict] = {}
    # max()/min() return the first extreme element, matching what a stable
    # sort followed by [0] would select, without sorting the whole pool.
    if balanced_pool:
        out["balanced"] = max(balanced_pool,
                              key=lambda r: r["metrics"]["policy_score"])
    if fa_pool:
        out["low_fa"] = min(fa_pool,
                            key=lambda r: (r["metrics"]["safe_neg_alert_leak"],
                                           -r["metrics"]["ego_alert_recall"]))
    if hr_pool:
        out["high_recall"] = min(hr_pool,
                                 key=lambda r: (-r["metrics"]["ego_alert_recall"],
                                                r["metrics"]["safe_neg_alert_leak"]))
    return out
|
|
|
|
| |
def driver_metrics(m: dict, fps: float = 1.0, window_s: float = 1.0) -> dict:
    """Translate raw metrics into figures a reviewer/engineer can grasp.

    Assumes each sample covers ``window_s`` seconds of driving (default 1 s),
    so a per-sample rate times ``3600 / window_s`` is a per-hour count.

    Returned keys:
      - crashes_caught_per_100 = 100 × ego_alert_recall
      - false_alerts_per_hour  = safe_neg_alert_leak × (3600 / window_s)
      - non_ego_noise_per_hour = burden_non_ego × (3600 / window_s)
      - silent_when_safe       = 100 × safe_neg_silent (percent)
      - mean_lead_seconds      = lead_time_mean
      - alert_lead_seconds     = alert_lead_time_mean
      - warning_coverage_pct   = 100 × lead_time_coverage
      - alert_coverage_pct     = 100 × alert_lead_time_coverage

    Args:
        m: Metrics dict containing the eight source keys named above.
        fps: Accepted for interface compatibility; not used in the computation.
        window_s: Seconds of driving represented by one sample; must be > 0.

    Raises:
        ValueError: If ``window_s`` is not positive.
        KeyError: If ``m`` lacks one of the required keys.
    """
    if window_s <= 0:
        raise ValueError(f"window_s must be positive, got {window_s}")

    samples_per_hour = 3600.0 / window_s
    return {
        "crashes_caught_per_100": round(100.0 * m["ego_alert_recall"], 1),
        "false_alerts_per_hour": round(m["safe_neg_alert_leak"] * samples_per_hour, 0),
        "non_ego_noise_per_hour": round(m["burden_non_ego"] * samples_per_hour, 0),
        "silent_when_safe": round(100.0 * m["safe_neg_silent"], 1),
        "mean_lead_seconds": round(m["lead_time_mean"], 2),
        "alert_lead_seconds": round(m["alert_lead_time_mean"], 2),
        "warning_coverage_pct": round(100.0 * m["lead_time_coverage"], 1),
        "alert_coverage_pct": round(100.0 * m["alert_lead_time_coverage"], 1),
    }
|
|
|
|
| |
# (metric key, column label) pairs for the comparison tables, in column order.
# The arrow marks the preferred direction: ↑ higher is better, ↓ lower is better.
METRICS_KEYS = [
    ("policy_score", "PolicyScore↑"),
    ("ego_alert_recall", "EgoRec↑"),
    ("safe_neg_alert_leak", "FA↓"),
    ("burden_non_ego", "Burden↓"),
    ("binary_ap", "AP↑"),
    ("binary_f1", "F1↑"),
    ("lead_time_mean", "Lead(s)↑"),
    ("lead_time_coverage", "Cov↑"),
]
|
|
|
|
def print_row(name: str, m: dict, star: bool = False):
    """Print one table row: a 30-wide method label followed by the METRICS_KEYS columns.

    Args:
        name: Row label.
        m: Metrics dict; keys missing from it are printed as 0.0.
        star: Prefix the label with a star to highlight the row.
    """
    # Both prefixes are two characters wide so names line up across rows.
    label = ("★ " + name) if star else ("  " + name)
    three_decimal_keys = ("safe_neg_alert_leak", "burden_non_ego",
                          "binary_ap", "binary_f1", "policy_score")

    cells = [f"{label:<30}"]
    for key, _ in METRICS_KEYS:
        value = m.get(key, 0.0)
        # Rates and scores get three decimals; everything else (e.g. lead
        # time in seconds) gets two.
        is_rate = ("coverage" in key or "recall" in key or "noalert" in key
                   or key in three_decimal_keys)
        cells.append(f"{value:>8.3f}" if is_rate else f"{value:>8.2f}")
    print(" ".join(cells))
|
|
|
|
def print_comparison(before: dict, after: dict, baselines: Dict[str, dict]):
    """Print the before/after metrics table followed by the baseline rows.

    Args:
        before: Metrics of the raw (un-post-processed) LKAlert model.
        after: Metrics of the optimized configuration (printed starred).
        baselines: Mapping of method name to metrics dict; entries whose name
            contains "LKAlert" are skipped.
    """
    rule = "=" * 104
    print("\n" + rule)
    print("POST-PROCESSING OPTIMIZATION RESULTS (val set)")
    print(rule)

    header = f"{'Method':<30}" + " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    separator = "-" * len(header)
    print(header)
    print(separator)
    print_row("LKAlert (baseline, raw)", before)
    print_row("LKAlert (optimized)", after, star=True)
    print(separator)

    for name, metrics in baselines.items():
        if "LKAlert" in name:
            continue
        print_row(name, metrics)
    print(rule)
|
|
|
|
def print_driver_metrics(before: dict, after: dict,
                         baselines: Dict[str, dict]):
    """Print the driver-intuitive metrics table.

    Args:
        before: Raw LKAlert metrics (converted with ``driver_metrics``).
        after: Optimized LKAlert metrics (converted, printed starred).
        baselines: Mapping of method name to metrics dict; entries whose name
            contains "LKAlert" are skipped. Baseline dicts are remapped onto
            the keys ``driver_metrics`` expects, with missing values as 0.
    """
    rule = "=" * 90
    print("\n" + rule)
    print("DRIVER-INTUITIVE METRICS (assume each sample = 1 s of driving)")
    print(rule)

    cols = [
        ("crashes_caught_per_100", "Crashes/100"),
        ("alert_lead_seconds", "AlertLead(s)"),
        ("alert_coverage_pct", "AlertCov(%)"),
        ("false_alerts_per_hour", "FA/hour"),
        ("non_ego_noise_per_hour", "Noise/hour"),
        ("silent_when_safe", "Silent(%)"),
    ]
    hdr = f"{'Method':<30}" + " ".join(f"{title:>12}" for _, title in cols)
    print(hdr)
    print("-" * len(hdr))

    def _row(name, dm, star=False):
        # Both prefixes are two characters wide so names line up across rows.
        label = ("★ " + name) if star else ("  " + name)
        print(f"{label:<30}" + " ".join(
            f"{dm.get(key, 0.0):>12}" for key, _ in cols))

    _row("LKAlert (baseline, raw)", driver_metrics(before))
    _row("LKAlert (optimized)", driver_metrics(after), star=True)

    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        # Baselines carry a single observe lead time / coverage figure, which
        # fills both the general and the ALERT-only slots; their non-ego
        # burden is fixed at 0.0 here.
        dm = driver_metrics({
            "ego_alert_recall": m.get("ego_alert_recall", 0),
            "safe_neg_alert_leak": m.get("safe_neg_alert_leak", 0),
            "burden_non_ego": 0.0,
            "safe_neg_silent": m.get("safe_neg_silent", 0),
            "lead_time_mean": m.get("observe_lead_time_s", 0),
            "alert_lead_time_mean": m.get("observe_lead_time_s", 0),
            "lead_time_coverage": m.get("observe_coverage", 0),
            "alert_lead_time_coverage": m.get("observe_coverage", 0),
        })
        _row(name, dm)
    print(rule)
|
|
|
|
| |
def main():
    """CLI entry point: extract logits, grid-search post-processing, report and save.

    Writes ``postproc_best.json`` and ``driver_metrics.json`` into
    ``--output_dir`` and keeps extracted logits under ``logits_cache/`` there.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--lkalert_ckpt", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default="data/belief_cache")
    parser.add_argument("--baseline_json",
                        default="eval_results/paper_comparison/all_results.json")
    parser.add_argument("--output_dir",
                        default="eval_results/paper_comparison")
    parser.add_argument("--recall_min", type=float, default=0.65,
                        help="Minimum ego-alert recall when picking the best config")
    parser.add_argument("--fa_max", type=float, default=0.30,
                        help="Maximum allowed false-alarm rate on safe_neg samples "
                             "(default 0.30 — comparable to strongest baseline)")
    parser.add_argument("--burden_max", type=float, default=0.25,
                        help="Maximum allowed non-ego alert burden "
                             "(default 0.25 — keeps driver noise bounded)")
    parser.add_argument("--force_extract", action="store_true",
                        help="Re-extract logits even if cached")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    cache_dir = out_dir / "logits_cache"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Step 1: logits for both splits (served from cache when present) ──
    logger.info("── Step 1: extracting LKAlert logits (train + val) ──")
    train = load_or_extract(cache_dir, "train",
                            args.sft_checkpoint, args.lkalert_ckpt,
                            args.label_dir, args.belief_cache_dir,
                            device, force=args.force_extract)
    val = load_or_extract(cache_dir, "val",
                          args.sft_checkpoint, args.lkalert_ckpt,
                          args.label_dir, args.belief_cache_dir,
                          device, force=args.force_extract)

    # ── Step 2: per-class temperatures ──
    logger.info("── Step 2: fitting per-class temperatures ──")
    # The train-fit temperatures are logged for reference only; the grid below
    # uses the temperatures fitted on the 20% val slice.
    T_pc_train = fit_per_class_temperature(train["logits"], train["labels"])
    logger.info(f" T(train-fit) = (SIL={T_pc_train[0]:.3f}, "
                f"OBS={T_pc_train[1]:.3f}, ALT={T_pc_train[2]:.3f})")

    # Fit on a stratified 20% of val and search on the other 80%, so the
    # search is not run on the samples the temperatures were fitted to.
    fit_idx, eval_idx = stratified_split_indices(val["labels"], frac=0.2)
    T_pc_valfit = fit_per_class_temperature(
        val["logits"][fit_idx], val["labels"][fit_idx])
    logger.info(f" T(val20%-fit) = (SIL={T_pc_valfit[0]:.3f}, "
                f"OBS={T_pc_valfit[1]:.3f}, ALT={T_pc_valfit[2]:.3f}) "
                f"(fit={len(fit_idx)}, eval={len(eval_idx)})")

    # ── Step 3: reference metrics of the untouched logits on the full val set ──
    raw = compute_metrics(val["logits"], val["labels"], val["categories"],
                          val["ttas"], val["video_ids"])
    logger.info(f" Raw val PolicyScore = {raw['policy_score']:.4f} "
                f"Recall = {raw['ego_alert_recall']:.3f} "
                f"FA = {raw['safe_neg_alert_leak']:.3f}")

    # ── Step 4: grid search on the held-out 80% of val ──
    eval_view = {
        "logits": val["logits"][eval_idx],
        "labels": val["labels"][eval_idx],
        "categories": val["categories"][eval_idx],
        "ttas": val["ttas"][eval_idx],
        "video_ids": val["video_ids"][eval_idx],
    }
    calib_opts = [
        {"name": "raw", "kind": "raw"},
        {"name": "global_T=0.6", "kind": "global_T", "T": 0.6},
        {"name": "global_T=0.8", "kind": "global_T", "T": 0.8},
        {"name": "global_T=1.2", "kind": "global_T", "T": 1.2},
        {"name": "global_T=1.5", "kind": "global_T", "T": 1.5},
        {"name": "per_class_val20%", "kind": "per_class_T", "T": T_pc_valfit},
    ]
    logger.info(f"── Step 4: grid search on 80% val ({len(eval_idx)} samples) ──")
    logger.info(f" Constraints: recall ≥ {args.recall_min}, "
                f"FA ≤ {args.fa_max}, burden ≤ {args.burden_max}")
    all_cfgs = grid_search(eval_view, calib_opts)

    recs_80 = pick_recommended(all_cfgs,
                               args.recall_min, args.fa_max, args.burden_max,
                               raw)
    if not recs_80:
        logger.warning("No configuration satisfies any constraint set; "
                       "falling back to the closest match.")
        all_cfgs_sorted = sorted(all_cfgs,
                                 key=lambda r: -r["metrics"]["policy_score"])
        recs_80 = {"balanced": all_cfgs_sorted[0]}

    # ── Step 5: re-evaluate every recommended config on the FULL val set ──
    name_to_calib = {c["name"]: c for c in calib_opts}
    recommended: Dict[str, dict] = {}
    for tag, rec in recs_80.items():
        cfg = rec["cfg"]
        lg = _apply(val["logits"], val["video_ids"], val["ttas"],
                    name_to_calib[cfg["calib"]],
                    cfg["smooth_alpha"], cfg["non_ego_alpha"],
                    cfg["alert_bias"])
        m = compute_metrics(lg, val["labels"], val["categories"],
                            val["ttas"], val["video_ids"])
        recommended[tag] = {"cfg": cfg, "metrics": m}
        logger.info(f" [{tag}] cfg={cfg} → "
                    f"Policy={m['policy_score']:.4f} Rec={m['ego_alert_recall']:.3f} "
                    f"FA={m['safe_neg_alert_leak']:.3f} Bur={m['burden_non_ego']:.3f}")

    # Headline variant: "balanced" when available, else the first one found.
    best_tag = "balanced" if "balanced" in recommended else next(iter(recommended))
    best = recommended[best_tag]

    # Candidates for the top-10 listing: recall-feasible configs ranked by
    # PolicyScore (metrics here are from the 80% eval split).
    kept = [r for r in all_cfgs if r["metrics"]["ego_alert_recall"] >= args.recall_min]
    kept.sort(key=lambda r: -r["metrics"]["policy_score"])

    # ── Baselines from a previous evaluation run (optional) ──
    baselines = {}
    bj = Path(args.baseline_json)
    if bj.exists():
        baselines = json.loads(bj.read_text(encoding="utf-8"))
    else:
        logger.warning(f"Baseline JSON {bj} not found; skipping cross-model comparison")

    # ── Report 1: recommended operating points vs baselines ──
    print("\n" + "=" * 104)
    print("RECOMMENDED OPERATING POINTS (re-evaluated on FULL val)")
    print("=" * 104)
    # print_row pads the label to 30 characters, so the header column must be
    # 30 wide as well for the metric columns to line up.
    header = f"{'Variant':<30}" + " ".join(f"{lbl:>8}" for _, lbl in METRICS_KEYS)
    print(header)
    print("-" * len(header))
    print_row("raw (current default)", raw)
    for tag in ("low_fa", "balanced", "high_recall"):
        if tag in recommended:
            print_row(f"opt: {tag}", recommended[tag]["metrics"],
                      star=(tag == best_tag))
    print("-" * len(header))
    for name, m in baselines.items():
        if "LKAlert" in name or "Random" in name:
            continue
        print_row(name, m)
    print("=" * 104)

    # ── Report 2: driver-intuitive metrics ──
    print_driver_metrics(raw, best["metrics"], baselines)

    # ── Report 3: top-10 candidates ──
    print("\n" + "=" * 100)
    print("TOP-10 CANDIDATE CONFIGS (ranked by PolicyScore on 80% val eval split)")
    print("=" * 100)
    print(f"{'calib':<20} {'sm_α':>6} {'ne_α':>6} {'bias':>6} "
          f"{'Policy↑':>8} {'Rec↑':>6} {'FA↓':>6} {'Bur↓':>6} {'AP↑':>6} {'F1↑':>6}")
    for r in kept[:10]:
        c = r["cfg"]
        m = r["metrics"]
        print(f"{c['calib']:<20} {c['smooth_alpha']:>6.2f} {c['non_ego_alpha']:>6.2f} "
              f"{c['alert_bias']:>+6.2f} "
              f"{m['policy_score']:>8.4f} {m['ego_alert_recall']:>6.3f} "
              f"{m['safe_neg_alert_leak']:>6.3f} {m['burden_non_ego']:>6.3f} "
              f"{m['binary_ap']:>6.3f} {m['binary_f1']:>6.3f}")

    # ── Save chosen config + metrics ──
    def _json_default(obj):
        # Metric values may be NumPy scalars/arrays, which json cannot encode.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    result = {
        "constraints": {"recall_min": args.recall_min,
                        "fa_max": args.fa_max,
                        "burden_max": args.burden_max},
        "raw_metrics": raw,
        "recommended": {tag: {"cfg": dict(r["cfg"]), "metrics": r["metrics"]}
                        for tag, r in recommended.items()},
        "headline_variant": best_tag,
        "top10_by_policy_score": [{"cfg": r["cfg"],
                                   "policy_score": r["metrics"]["policy_score"],
                                   "recall": r["metrics"]["ego_alert_recall"],
                                   "fa": r["metrics"]["safe_neg_alert_leak"],
                                   "burden": r["metrics"]["burden_non_ego"]}
                                  for r in kept[:10]],
    }
    (out_dir / "postproc_best.json").write_text(
        json.dumps(result, indent=2, default=_json_default))
    logger.info(f"Saved → {out_dir / 'postproc_best.json'}")

    # ── Save driver-intuitive metrics for LKAlert and every baseline ──
    driver = {
        "LKAlert (baseline, raw)": driver_metrics(raw),
        "LKAlert (optimized)": driver_metrics(best["metrics"]),
    }
    for name, m in baselines.items():
        if "LKAlert" in name:
            continue
        driver[name] = driver_metrics({
            "ego_alert_recall": m.get("ego_alert_recall", 0),
            "safe_neg_alert_leak": m.get("safe_neg_alert_leak", 0),
            "burden_non_ego": 0.0,
            "safe_neg_silent": m.get("safe_neg_silent", 0),
            "lead_time_mean": m.get("observe_lead_time_s", 0),
            "alert_lead_time_mean": m.get("observe_lead_time_s", 0),
            "lead_time_coverage": m.get("observe_coverage", 0),
            "alert_lead_time_coverage": m.get("observe_coverage", 0),
        })
    (out_dir / "driver_metrics.json").write_text(
        json.dumps(driver, indent=2, default=_json_default))
    logger.info(f"Saved → {out_dir / 'driver_metrics.json'}")
|
|
|
|
# Run the optimizer only when executed as a script / via `python -m`.
if __name__ == "__main__":
    main()
|
|