#!/usr/bin/env python3
"""
Post-processing transforms for PolicyModel logits.

Three composable operations (applied with numpy; only the temperature fit
uses torch):
  1. Per-class temperature scaling — L-BFGS fit on the train split
     (one temperature per class, 3 params).
  2. Temporal smoothing — causal EMA (or trailing mean) over each video's
     samples, ordered by TTA (largest TTA = earliest).
  3. Non-ego bias — lowers the ALERT logit in proportion to the OBSERVE
     probability.

Applied in order:
    raw_logits → per-class-T → smoothing → non-ego-bias → (+alert_bias on ALERT)
"""
from __future__ import annotations

from collections import defaultdict
from typing import Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F


# ─────────────────────────── 1. Per-class temperature ─────────────────────────

def fit_per_class_temperature(
    logits: np.ndarray,
    labels: np.ndarray,
    init: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    max_iter: int = 200,
    device: str = "cpu",
) -> np.ndarray:
    """
    Fit positive temperatures T = (T_SILENT, T_OBSERVE, T_ALERT) by minimizing
    cross-entropy of the column-wise calibrated logits ``logits / T``.

    logits   : [N, K] float, where K == len(init) (3 by default)
    labels   : [N] int in {0, ..., K-1}
    init     : starting temperatures, all strictly positive and finite; a
               single value (scalar or length 1) fits one shared temperature
    max_iter : L-BFGS iteration budget for the single optimizer step
    device   : torch device used for the fit

    Returns T with the same shape as ``init``, dtype float32.
    While optimizing, T is clamped to [1e-3, 1e3]; the returned values are
    additionally clamped to [0.1, 10.0], so they can differ from the
    unconstrained optimum.  With no samples (N == 0) nothing is fitted and the
    clamped ``init`` is returned.

    Raises ValueError if ``init`` is empty, multi-dimensional, non-finite or
    non-positive, or if the shapes of logits / labels / init disagree.
    """
    init_T = np.asarray(init, dtype=np.float64)
    if init_T.ndim > 1 or init_T.size == 0:
        raise ValueError(f"init must be a scalar or a non-empty 1-D sequence, got {init!r}")
    if not np.all(np.isfinite(init_T)) or np.any(init_T <= 0):
        raise ValueError(f"init temperatures must be positive and finite, got {init!r}")

    logits_np = np.asarray(logits)
    labels_np = np.asarray(labels)
    if logits_np.ndim != 2:
        raise ValueError(f"logits must be 2-D [N, K], got shape {logits_np.shape}")
    if init_T.size not in (1, logits_np.shape[1]):
        raise ValueError(
            f"init has {init_T.size} temperatures but logits has {logits_np.shape[1]} columns"
        )
    if labels_np.shape != (logits_np.shape[0],):
        raise ValueError(
            f"labels must have shape ({logits_np.shape[0]},), got {labels_np.shape}"
        )
    if logits_np.shape[0] == 0:
        # Nothing to fit on: return the (clamped) starting point instead of
        # running the optimizer on an undefined (NaN) mean loss.
        return np.clip(init_T, 0.1, 10.0).astype(np.float32)

    lg = torch.as_tensor(logits_np, dtype=torch.float32, device=device)
    lb = torch.as_tensor(labels_np, dtype=torch.long, device=device)
    # Optimize log(T) so the temperatures stay strictly positive.
    log_T = torch.tensor(np.log(init_T), dtype=torch.float32, device=device, requires_grad=True)
    opt = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter,
                            tolerance_grad=1e-7, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        # Wide clamp keeps T finite and non-zero while L-BFGS explores.
        T = log_T.exp().clamp(min=1e-3, max=1e3)
        cal = lg / T.unsqueeze(0)  # [N, K] / [1, K] → per-column scaling
        loss = F.cross_entropy(cal, lb)
        loss.backward()
        return loss

    opt.step(closure)  # a single LBFGS step runs up to max_iter iterations
    # Tighter output clamp keeps the returned temperatures in a safe range.
    T = log_T.detach().exp().clamp(min=0.1, max=10.0).cpu().numpy()
    return T.astype(np.float32)


def apply_per_class_temperature(logits: np.ndarray, T: Sequence[float]) -> np.ndarray:
    """
    Divide each logit column by its class temperature.

    logits : [N, K]
    T      : [K] temperatures, e.g. from fit_per_class_temperature
    Returns a new float32 array of shape [N, K] (logits ÷ T, broadcast).
    """
    return (logits / np.asarray(T, dtype=np.float32)[None, :]).astype(np.float32)
# Class-column indices used by the transforms and metrics below; the order
# matches the temperature vector T = (T_SILENT, T_OBSERVE, T_ALERT).
_SILENT, _OBSERVE, _ALERT = 0, 1, 2


# ─────────────────────────── 2. Temporal smoothing ────────────────────────────

def temporal_smooth(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    window: int = 3,
    mode: str = "ema",  # "ema" | "mean"
    alpha: float = 0.5,  # EMA weight on current frame
) -> np.ndarray:
    """
    Causally smooth logits along each video's time axis.

    For each video, the samples with tta >= 0 are ordered by tta DESC
    (largest tta = earliest in time), and each output row depends only on the
    current and earlier rows.  Samples with a negative tta (e.g. -1 for
    non_ego / safe_neg) have no meaningful temporal order and are copied
    through unchanged.  Videos with fewer than two timed samples are also
    copied through unchanged.

    logits    : [N, C]
    video_ids : length-N sequence of video identifiers
    ttas      : [N] float; negative values ⇒ sample is not smoothed
    window    : mode="mean" only — trailing window length (>= 1; 1 ⇒ identity)
    mode      : "ema" or "mean"
    alpha     : mode="ema" only — weight of the current sample; the running
                average of earlier samples gets 1 - alpha (alpha=1 ⇒ identity)

    Returns a new array; the input is not modified.
    Raises ValueError for an unknown mode, or window < 1 with mode="mean".
    """
    # Validate up front so bad arguments fail even when no video has enough
    # timed samples to reach the smoothing loop.
    if mode not in ("ema", "mean"):
        raise ValueError(f"Unknown mode: {mode}")
    if mode == "mean" and window < 1:
        raise ValueError(f"window must be >= 1 for mode='mean', got {window}")

    out = logits.copy()

    # Group row indices per video; only rows with tta >= 0 have a time axis.
    by_vid: dict[str, list[int]] = defaultdict(list)
    for i, v in enumerate(video_ids):
        if ttas[i] >= 0:
            by_vid[v].append(i)

    for idxs in by_vid.values():
        if len(idxs) < 2:
            continue  # nothing to smooth against
        # Largest tta first = earliest in time.  sorted() is stable, so rows
        # with equal tta keep their input order.
        idxs_sorted = sorted(idxs, key=lambda j: -float(ttas[j]))
        seq = np.stack([logits[j] for j in idxs_sorted])  # [n, C]

        if mode == "mean":
            # Trailing mean over up to `window` rows ending at `pos`.
            sm = seq.copy()
            for pos in range(1, len(seq)):
                lo = max(0, pos - window + 1)
                sm[pos] = seq[lo:pos + 1].mean(axis=0)
        else:  # "ema"
            sm = np.empty_like(seq)
            sm[0] = seq[0]
            for pos in range(1, len(seq)):
                sm[pos] = alpha * seq[pos] + (1.0 - alpha) * sm[pos - 1]

        for k, j in enumerate(idxs_sorted):
            out[j] = sm[k]
    return out


# ─────────────────────────── 3. Non-ego bias ──────────────────────────────────

def _softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over axis 1, max-subtracted for numerical stability."""
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    return p


def non_ego_bias(logits: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """
    Reduce the ALERT logit in proportion to the OBSERVE probability.

    Uses only information available at inference (no ground-truth leakage).

    Mechanism:
        p = softmax(logits)
        logits[:, ALERT] -= alpha * p[:, OBSERVE]

    alpha is the magnitude of the nudge (intended range [0, 2]; not enforced).
    Always returns a new array, including when alpha == 0, so the result
    never aliases the input.
    """
    out = logits.copy()
    if alpha == 0.0:
        return out
    p = _softmax_rows(logits)
    out[:, _ALERT] = out[:, _ALERT] - alpha * p[:, _OBSERVE]
    return out


# ─────────────────────────── All-in-one helper ────────────────────────────────

def apply_postproc(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    T_per_class: Sequence[float] | None = None,
    smooth_window: int = 1,  # mode="mean": 1 ⇒ no smoothing; unused by "ema"
    smooth_mode: str = "ema",
    smooth_alpha: float = 0.5,
    non_ego_alpha: float = 0.0,
    alert_bias: float = 0.0,
) -> np.ndarray:
    """
    Run the post-processing chain on raw logits and return float32 logits.

    Order: per-class temperature → temporal smoothing → non-ego bias →
    constant ALERT bias.  Stage gating:
      * temperature : applied only when T_per_class is not None;
      * smoothing   : "mean" runs only when smooth_window > 1, while "ema"
                      ALWAYS runs (the window is unused).  The defaults
                      therefore apply an EMA with alpha=0.5; pass
                      smooth_alpha=1.0 for an identity EMA.  Any other mode
                      is skipped when smooth_window <= 1 and rejected by
                      temporal_smooth otherwise;
      * non-ego bias: applied only when non_ego_alpha > 0;
      * alert bias  : added to the ALERT column when non-zero.

    The input array is not modified.
    """
    lg = logits.astype(np.float32)  # astype copies by default; input untouched
    if T_per_class is not None:
        lg = apply_per_class_temperature(lg, T_per_class)
    if smooth_window > 1 or smooth_mode == "ema":
        lg = temporal_smooth(lg, video_ids, ttas, window=smooth_window,
                             mode=smooth_mode, alpha=smooth_alpha)
    if non_ego_alpha > 0.0:
        lg = non_ego_bias(lg, alpha=non_ego_alpha)
    if alert_bias != 0.0:
        lg[:, _ALERT] = lg[:, _ALERT] + alert_bias
    return lg


# ─────────────────────────── Metric computation ───────────────────────────────

def compute_metrics(
    logits: np.ndarray,  # [N, 3]
    labels: np.ndarray,  # [N] in {0,1,2}
    cats: np.ndarray,  # [N] str
    ttas: np.ndarray,  # [N] float
    video_ids: Sequence[str],
) -> dict:
    """
    Mirror of baseline.comparison.eval_all.compute_all_metrics.

    Predictions are argmax(logits).  Category masks come from ``cats``
    ("ego_positive", "safe_neg", "non_ego").  Returned keys:
      policy_score          0.65*ego_alert_recall + 0.25*safe_neg_silent
                            - 0.15*safe_neg_alert_leak
      ego_alert_recall      ALERT-labelled ego_positive rows predicted ALERT
      non_ego_noalert       non_ego rows not predicted ALERT
      safe_neg_silent       safe_neg rows predicted SILENT
      safe_neg_alert_leak   safe_neg rows predicted ALERT
      burden_non_ego        non_ego rows predicted ALERT
      binary_ap / _precision / _recall / _f1
                            ALERT vs. rest; AP is scored on softmax P(ALERT)
      lead_time_mean        per ego video, the largest tta predicted OBSERVE
                            or ALERT, averaged over videos that have one
      lead_time_coverage    share of ego videos with such a prediction
      alert_lead_time_*     the same, counting ALERT predictions only
      n_samples, n_ego_videos
    Every ratio is 0.0 when its denominator is zero.
    scikit-learn is imported lazily for the AP.
    """
    # Coerce array-likes: comparing a plain list with a string yields a single
    # False instead of an elementwise mask, which would silently zero every
    # category-based metric.
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    cats = np.asarray(cats)
    ttas = np.asarray(ttas)

    preds = logits.argmax(axis=1)
    # Softmax is only needed for P(ALERT), the score used by binary AP.
    prob_alert = _softmax_rows(logits)[:, _ALERT]

    def _r(n, d):
        """Safe ratio: 0.0 when the denominator is zero."""
        return float(n) / float(d) if d > 0 else 0.0

    ego_mask = cats == "ego_positive"
    safe_mask = cats == "safe_neg"
    ne_mask = cats == "non_ego"
    pred_alert = preds == _ALERT
    true_alert = labels == _ALERT

    ego_alert = _r((pred_alert & ego_mask & true_alert).sum(), (ego_mask & true_alert).sum())
    non_ego_noalert = _r((~pred_alert & ne_mask).sum(), ne_mask.sum())
    safe_silent = _r(((preds == _SILENT) & safe_mask).sum(), safe_mask.sum())
    fa = _r((pred_alert & safe_mask).sum(), safe_mask.sum())
    burden_non_ego = _r((pred_alert & ne_mask).sum(), ne_mask.sum())

    # PolicyScore v3 (safety-first): 0.65*ego_recall + 0.25*safe_silent - 0.15*safe_alert
    policy_score = 0.65 * ego_alert + 0.25 * safe_silent - 0.15 * fa

    # binary (ALERT vs. rest) AP / precision / recall / F1
    from sklearn.metrics import average_precision_score

    binary_true = true_alert.astype(int)
    try:
        ap = float(average_precision_score(binary_true, prob_alert))
    except Exception as exc:  # best effort: keep the metrics usable, but say so
        import warnings

        warnings.warn(
            f"binary AP could not be computed ({exc!r}); reporting 0.0",
            RuntimeWarning,
            stacklevel=2,
        )
        ap = 0.0

    tp = int((pred_alert & true_alert).sum())
    fp = int((pred_alert & ~true_alert).sum())
    fn = int((~pred_alert & true_alert).sum())
    prec = _r(tp, tp + fp)
    rec = _r(tp, tp + fn)
    f1 = _r(2 * prec * rec, prec + rec)

    # Lead time: per ego video, the earliest (largest-tta) sample predicted
    # OBSERVE or ALERT, and separately the earliest predicted ALERT.
    by_video: dict[str, list] = defaultdict(list)
    for i in np.flatnonzero(ego_mask):
        by_video[video_ids[i]].append((float(ttas[i]), int(preds[i])))

    lead_times = []
    alert_lead_times = []
    for items in by_video.values():
        items.sort(key=lambda x: -x[0])  # largest tta (earliest) first; stable
        earliest_any = next((t for t, pr in items if pr in (_OBSERVE, _ALERT)), None)
        earliest_alert = next((t for t, pr in items if pr == _ALERT), None)
        if earliest_any is not None:
            lead_times.append(earliest_any)
        if earliest_alert is not None:
            alert_lead_times.append(earliest_alert)
    n_videos = len(by_video)

    mean_lead = float(np.mean(lead_times)) if lead_times else 0.0
    mean_lead_alrt = float(np.mean(alert_lead_times)) if alert_lead_times else 0.0
    cov_any = _r(len(lead_times), n_videos)
    cov_alert = _r(len(alert_lead_times), n_videos)

    return {
        "policy_score": policy_score,
        "ego_alert_recall": ego_alert,
        "non_ego_noalert": non_ego_noalert,
        "safe_neg_silent": safe_silent,
        "safe_neg_alert_leak": fa,
        "burden_non_ego": burden_non_ego,
        "binary_ap": ap,
        "binary_precision": prec,
        "binary_recall": rec,
        "binary_f1": f1,
        "lead_time_mean": mean_lead,
        "lead_time_coverage": cov_any,
        "alert_lead_time_mean": mean_lead_alrt,
        "alert_lead_time_coverage": cov_alert,
        "n_samples": int(len(labels)),
        "n_ego_videos": n_videos,
    }