| |
"""
Post-processing transforms for PolicyModel logits.

Three composable operations (numpy-based and CPU-cheap; only the temperature fit uses torch):
1. Per-class temperature scaling — L-BFGS-fit on train split (3 params)
2. Temporal smoothing — EMA over per-video sorted-by-TTA sequence
3. Non-ego bias — reduces ALERT when OBSERVE probability dominates

Applied in order: raw_logits → per-class-T → smoothing → non-ego-bias → (+global_bias)
"""
| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import Sequence, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| |
def fit_per_class_temperature(
    logits: np.ndarray,
    labels: np.ndarray,
    init: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    max_iter: int = 200,
    device: str = "cpu",
) -> np.ndarray:
    """
    Fit 3 positive temperatures T = (T_SILENT, T_OBSERVE, T_ALERT)
    by minimizing cross-entropy on (logits / T) with L-BFGS.

    The temperatures are optimized in log-space so they stay positive.

    logits   : [N, 3] float
    labels   : [N] int in {0, 1, 2}
    init     : starting temperatures (must be positive, since log(init) is taken)
    max_iter : maximum number of L-BFGS iterations
    device   : torch device on which the fit runs
    Returns T : [3] float32, clipped to [0.1, 10.0]

    With no samples (N == 0) there is nothing to fit; ``init`` (clipped to the
    same output range) is returned instead of optimizing an undefined loss.
    """
    # Range enforced on the returned temperatures.
    t_min, t_max = 0.1, 10.0

    if np.size(labels) == 0:
        # Mean cross-entropy over zero samples is undefined (NaN); running the
        # optimizer on it would hand NaN temperatures back to the caller.
        return np.clip(np.asarray(init, dtype=np.float32), t_min, t_max).astype(np.float32)

    lg = torch.as_tensor(logits, dtype=torch.float32, device=device)
    lb = torch.as_tensor(labels, dtype=torch.long, device=device)
    log_T = torch.tensor(np.log(init), dtype=torch.float32,
                         device=device, requires_grad=True)
    opt = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter,
                            tolerance_grad=1e-7, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        # Loose clamp during optimization: only guards against overflow or
        # division by ~0. It is wider than the range applied to the result.
        T = log_T.exp().clamp(min=1e-3, max=1e3)
        cal = lg / T.unsqueeze(0)
        loss = F.cross_entropy(cal, lb)
        loss.backward()
        return loss

    # A single step() call runs up to max_iter L-BFGS iterations internally.
    opt.step(closure)

    # Final clip to the documented output range. If the optimum lies outside
    # [t_min, t_max], the returned value is the clipped one, not the optimum.
    T = log_T.detach().exp().clamp(min=t_min, max=t_max).cpu().numpy()
    return T.astype(np.float32)
|
|
|
|
def apply_per_class_temperature(logits: np.ndarray, T: Sequence[float]) -> np.ndarray:
    """Divide logits [N,3] by the per-class temperatures T [3] (broadcast over rows).

    The result is returned as float32.
    """
    temps = np.asarray(T, dtype=np.float32)
    # Leading axis of length 1 so each row of `logits` is divided column-wise.
    scaled = logits / temps[np.newaxis, :]
    return scaled.astype(np.float32)
|
|
|
|
| |
def temporal_smooth(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    window: int = 3,
    mode: str = "ema",
    alpha: float = 0.5,
) -> np.ndarray:
    """
    For each video, sort its samples by tta DESC (earliest = largest tta first)
    and apply causal smoothing. Samples with a negative tta (e.g. tta == -1 for
    non_ego / safe_neg) have no meaningful temporal order and are kept as-is.

    logits   : [N, 3]
    video_ids: list[str] length N
    ttas     : [N] float; negative values → no smoothing
    window   : for mode="mean", the sliding window size (>= 1); unused for "ema"
    mode     : "ema" or "mean"
    alpha    : for mode="ema", weight of current frame (past gets 1-alpha)

    Returns a new array with the same shape as ``logits``; the input is not
    modified.

    Raises ValueError for an unknown ``mode`` or, in "mean" mode, a ``window``
    smaller than 1.
    """
    # Validate up front so a bad argument is reported regardless of the data
    # (previously it surfaced only once some video had >= 2 ordered frames).
    if mode not in ("mean", "ema"):
        raise ValueError(f"Unknown mode: {mode}")
    if mode == "mean" and window < 1:
        # A window < 1 would average an empty slice and produce NaN logits.
        raise ValueError(f"window must be >= 1 for mode='mean', got {window}")

    out = logits.copy()

    # Group row indices by video, keeping only rows with a valid (>= 0) tta.
    by_vid: dict[str, list[int]] = defaultdict(list)
    for i, v in enumerate(video_ids):
        if ttas[i] >= 0:
            by_vid[v].append(i)

    for idxs in by_vid.values():
        if len(idxs) < 2:
            # A single frame has no history to smooth against.
            continue

        # Largest tta first, i.e. chronological order towards the event.
        order = sorted(idxs, key=lambda i: -float(ttas[i]))
        seq = logits[order]

        if mode == "mean":
            sm = seq.copy()
            for pos in range(1, len(seq)):
                lo = max(0, pos - window + 1)
                sm[pos] = seq[lo:pos + 1].mean(axis=0)
        else:  # mode == "ema"
            sm = np.empty_like(seq)
            sm[0] = seq[0]
            for pos in range(1, len(seq)):
                sm[pos] = alpha * seq[pos] + (1.0 - alpha) * sm[pos - 1]

        # Scatter the smoothed rows back to their original positions.
        out[order] = sm
    return out
|
|
|
|
| |
def non_ego_bias(logits: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """
    Reduce ALERT logit when OBSERVE probability dominates.
    Uses only information available at inference (no ground-truth leakage).

    Mechanism:
        p = softmax(logits)
        logits[:, 2] -= alpha * p[:, 1]

    alpha in [0, 2] is the magnitude of the nudge.

    logits : [N, 3]
    Returns a new [N, 3] array; the input is never modified or aliased, also
    when alpha == 0 (in which case the result is an unchanged copy).
    """
    out = logits.copy()
    if alpha == 0.0:
        # Return the copy rather than the input itself so callers can mutate
        # the result safely whatever the value of alpha.
        return out

    # Softmax with the row maximum subtracted for numerical stability.
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)

    # Column 1 = OBSERVE probability, column 2 = ALERT logit.
    out[:, 2] = out[:, 2] - alpha * p[:, 1]
    return out
|
|
|
|
| |
def apply_postproc(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    T_per_class: Sequence[float] | None = None,
    smooth_window: int = 1,
    smooth_mode: str = "ema",
    smooth_alpha: float = 0.5,
    non_ego_alpha: float = 0.0,
    alert_bias: float = 0.0,
) -> np.ndarray:
    """Run the post-processing chain on a float32 copy of ``logits``.

    Stages, in order (each one is optional):
      1. per-class temperature, when ``T_per_class`` is given
      2. temporal smoothing, when ``smooth_window > 1`` or ``smooth_mode == "ema"``
         (so with the default arguments EMA smoothing is applied)
      3. non-ego bias, when ``non_ego_alpha > 0``
      4. constant ``alert_bias`` added to column 2, when it is non-zero

    The input array is left untouched; a float32 array is returned.
    """
    result = logits.astype(np.float32).copy()

    # Stage 1: per-class temperature scaling.
    if T_per_class is not None:
        result = apply_per_class_temperature(result, T_per_class)

    # Stage 2: temporal smoothing.
    wants_smoothing = smooth_mode == "ema" or smooth_window > 1
    if wants_smoothing:
        result = temporal_smooth(
            result,
            video_ids,
            ttas,
            window=smooth_window,
            mode=smooth_mode,
            alpha=smooth_alpha,
        )

    # Stage 3: non-ego bias on the ALERT logit.
    if non_ego_alpha > 0.0:
        result = non_ego_bias(result, alpha=non_ego_alpha)

    # Stage 4: global additive bias on the ALERT logit.
    if alert_bias != 0.0:
        result[:, 2] += alert_bias

    return result
|
|
|
|
| |
def compute_metrics(
    logits: np.ndarray,
    labels: np.ndarray,
    cats: np.ndarray,
    ttas: np.ndarray,
    video_ids: Sequence[str],
) -> dict:
    """Mirror of baseline.comparison.eval_all.compute_all_metrics.

    Predictions are the row-wise argmax of ``logits``; class index 2 is the
    alert class and index 0 the silent class. Three groups of metrics are
    returned in one flat dict:

    * policy metrics per category (``cats`` values "ego_positive",
      "safe_neg", "non_ego") and the combined ``policy_score``;
    * binary alert-vs-rest AP / precision / recall / F1;
    * per-video lead times over "ego_positive" rows: the largest tta at
      which a video is first predicted as class 1-or-2, and as class 2.

    Every ratio with a zero denominator is reported as 0.0.
    """
    preds = logits.argmax(axis=1)

    # Numerically stable softmax; only the alert column is needed (for AP).
    row_max = logits.max(axis=1, keepdims=True)
    exp_shifted = np.exp(logits - row_max)
    probs = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)
    alert_prob = probs[:, 2]

    def _ratio(num, den):
        # Safe division: an empty denominator yields 0.0 instead of raising.
        if den > 0:
            return float(num) / float(den)
        return 0.0

    is_ego = cats == "ego_positive"
    is_safe = cats == "safe_neg"
    is_non_ego = cats == "non_ego"
    pred_alert = preds == 2
    true_alert = labels == 2

    # --- policy metrics -------------------------------------------------
    ego_true_alert = is_ego & true_alert
    ego_alert = _ratio((pred_alert & ego_true_alert).sum(), ego_true_alert.sum())
    non_ego_noalert = _ratio((~pred_alert & is_non_ego).sum(), is_non_ego.sum())
    safe_silent = _ratio(((preds == 0) & is_safe).sum(), is_safe.sum())
    fa = _ratio((pred_alert & is_safe).sum(), is_safe.sum())
    burden_non_ego = _ratio((pred_alert & is_non_ego).sum(), is_non_ego.sum())

    policy_score = 0.65 * ego_alert + 0.25 * safe_silent - 0.15 * fa

    # --- binary alert-vs-rest metrics -----------------------------------
    from sklearn.metrics import average_precision_score
    binary_true = true_alert.astype(int)
    try:
        ap = float(average_precision_score(binary_true, alert_prob))
    except Exception:
        # Best-effort: any failure computing AP is reported as 0.0.
        ap = 0.0
    tp = int((pred_alert & true_alert).sum())
    fp = int((pred_alert & ~true_alert).sum())
    fn = int((~pred_alert & true_alert).sum())
    prec = _ratio(tp, tp + fp)
    rec = _ratio(tp, tp + fn)
    f1 = _ratio(2 * prec * rec, prec + rec)

    # --- lead-time metrics over ego_positive rows -------------------------
    frames_by_video: dict[str, list] = defaultdict(list)
    for idx in np.where(is_ego)[0]:
        frames_by_video[video_ids[idx]].append((float(ttas[idx]), int(preds[idx])))
    n_videos = len(frames_by_video)

    lead_times = []
    alert_lead_times = []
    for frames in frames_by_video.values():
        # Largest tta first, so the first hit is the earliest warning.
        ordered = sorted(frames, key=lambda rec_: -rec_[0])
        first_any = next((t for t, pr in ordered if pr in (1, 2)), None)
        first_alert = next((t for t, pr in ordered if pr == 2), None)
        if first_any is not None:
            lead_times.append(first_any)
        if first_alert is not None:
            alert_lead_times.append(first_alert)

    mean_lead = float(np.mean(lead_times)) if lead_times else 0.0
    mean_lead_alrt = float(np.mean(alert_lead_times)) if alert_lead_times else 0.0
    cov_any = _ratio(len(lead_times), n_videos)
    cov_alert = _ratio(len(alert_lead_times), n_videos)

    return {
        "policy_score": policy_score,
        "ego_alert_recall": ego_alert,
        "non_ego_noalert": non_ego_noalert,
        "safe_neg_silent": safe_silent,
        "safe_neg_alert_leak": fa,
        "burden_non_ego": burden_non_ego,
        "binary_ap": ap,
        "binary_precision": prec,
        "binary_recall": rec,
        "binary_f1": f1,
        "lead_time_mean": mean_lead,
        "lead_time_coverage": cov_any,
        "alert_lead_time_mean": mean_lead_alrt,
        "alert_lead_time_coverage": cov_alert,
        "n_samples": int(len(labels)),
        "n_ego_videos": n_videos,
    }
|
|