VLAlert / training /Policy /postproc.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
9.91 kB
#!/usr/bin/env python3
"""
Post-processing transforms for PolicyModel logits.
Three composable operations (all numpy-based, CPU-cheap):
1. Per-class temperature scaling β€” L-BFGS-fit on train split (3 params)
2. Temporal smoothing β€” EMA over per-video sorted-by-TTA sequence
3. Non-ego bias β€” reduces ALERT when OBSERVE probability dominates
Applied in order: raw_logits β†’ per-class-T β†’ smoothing β†’ non-ego-bias β†’ (+global_bias)
"""
from __future__ import annotations
from collections import defaultdict
from typing import Sequence, Tuple
import numpy as np
import torch
import torch.nn.functional as F
# ─────────────────────────── 1. Per-class temperature ─────────────────────────
def fit_per_class_temperature(
    logits: np.ndarray,
    labels: np.ndarray,
    init: Tuple[float, float, float] = (1.0, 1.0, 1.0),
    max_iter: int = 200,
    device: str = "cpu",
) -> np.ndarray:
    """
    Fit 3 positive temperatures T = (T_SILENT, T_OBSERVE, T_ALERT)
    by minimizing cross-entropy on (logits / T) with L-BFGS.

    T is parameterized as exp(log_T) so it stays positive. Inside the
    objective it is clamped to [1e-3, 1e3] for numerical safety; the
    returned temperatures are clamped more tightly to [0.1, 10].

    logits   : [N, 3] float (one column per entry of ``init``)
    labels   : [N] int in {0, 1, 2}
    init     : starting temperatures; must be finite and strictly positive
    max_iter : L-BFGS iteration budget
    device   : torch device the fit runs on
    Returns T : [3] float32, each entry in [0.1, 10].
        With no samples (N == 0) there is nothing to fit, so ``init``
        clamped to [0.1, 10] is returned instead of NaN temperatures.
    Raises ValueError if ``init`` contains a non-positive or non-finite
        value (its log would be NaN/-inf and poison the whole fit).
    """
    init_arr = np.asarray(init, dtype=np.float64)
    if not np.all(np.isfinite(init_arr)) or np.any(init_arr <= 0):
        raise ValueError(
            f"init temperatures must be finite and > 0, got {tuple(init_arr.tolist())}"
        )

    lg = torch.as_tensor(logits, dtype=torch.float32, device=device)
    lb = torch.as_tensor(labels, dtype=torch.long, device=device)
    if lb.numel() == 0:
        # Mean cross-entropy over zero samples is NaN; keep the prior instead.
        return np.clip(init_arr, 0.1, 10.0).astype(np.float32)

    log_T = torch.tensor(np.log(init_arr), dtype=torch.float32,
                         device=device, requires_grad=True)
    opt = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter,
                            tolerance_grad=1e-7, line_search_fn="strong_wolfe")

    def closure():
        opt.zero_grad()
        T = log_T.exp().clamp(min=1e-3, max=1e3)
        cal = lg / T.unsqueeze(0)  # [N, 3] / [1, 3] broadcast
        loss = F.cross_entropy(cal, lb)
        loss.backward()
        return loss

    opt.step(closure)
    T = log_T.detach().exp().clamp(min=0.1, max=10.0).cpu().numpy()
    return T.astype(np.float32)
def apply_per_class_temperature(logits: np.ndarray, T: Sequence[float]) -> np.ndarray:
    """
    Scale each logit column by the inverse of its class temperature.

    logits : [N, 3] float
    T      : per-class temperatures, broadcast across all rows
    Returns a new float32 array of the same shape as ``logits``.
    """
    temps = np.asarray(T, dtype=np.float32)
    scaled = logits / temps[np.newaxis, :]
    return scaled.astype(np.float32)
# ─────────────────────────── 2. Temporal smoothing ────────────────────────────
def temporal_smooth(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    window: int = 3,
    mode: str = "ema",  # "ema" | "mean"
    alpha: float = 0.5,  # EMA weight on current frame
) -> np.ndarray:
    """
    Causally smooth logits along each video's time axis.

    For each video, its samples with tta >= 0 are sorted by tta DESC
    (largest tta = earliest in time first) and smoothed causally: each sample
    only sees itself and earlier samples. Samples with negative tta (e.g. -1
    for non_ego / safe_neg) have no meaningful temporal order and are kept
    as-is, as are videos with fewer than two time-ordered samples.

    logits   : [N, 3]; not modified (a smoothed copy is returned)
    video_ids: list[str] length N
    ttas     : [N] float; negative values ⇒ no smoothing
    window   : for mode="mean", the sliding window size (must be >= 1)
    mode     : "ema" or "mean"
    alpha    : for mode="ema", weight of the current frame (the running
               average of the past gets 1 - alpha)
    Raises ValueError for an unknown ``mode`` or, in mode="mean", a
    ``window`` < 1 (which would average an empty slice into NaN). Both are
    checked up front, so bad settings fail even when no video needs smoothing.
    """
    if mode not in ("ema", "mean"):
        raise ValueError(f"Unknown mode: {mode}")
    if mode == "mean" and window < 1:
        raise ValueError(f"window must be >= 1 for mode='mean', got {window}")

    out = logits.copy()

    # Group sample indices per video; only samples with a time axis take part.
    by_vid: dict[str, list[int]] = defaultdict(list)
    for i, v in enumerate(video_ids):
        if ttas[i] >= 0:
            by_vid[v].append(i)

    for idxs in by_vid.values():
        if len(idxs) < 2:
            continue  # a single sample is unchanged by either smoother
        # largest tta first = earliest in time
        idxs_sorted = sorted(idxs, key=lambda i: -float(ttas[i]))
        seq = logits[idxs_sorted]  # [n, 3] copy in temporal order

        if mode == "mean":
            smoothed = seq.copy()
            for pos in range(1, len(seq)):
                lo = max(0, pos - window + 1)
                smoothed[pos] = seq[lo:pos + 1].mean(axis=0)
        else:  # mode == "ema"
            smoothed = np.empty_like(seq)
            smoothed[0] = seq[0]
            for pos in range(1, len(seq)):
                smoothed[pos] = alpha * seq[pos] + (1.0 - alpha) * smoothed[pos - 1]

        out[idxs_sorted] = smoothed
    return out
# ─────────────────────────── 3. Non-ego bias ──────────────────────────────────
def non_ego_bias(logits: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """
    Push the ALERT logit down in proportion to the OBSERVE probability.

    Uses only the logits themselves (no ground-truth leakage):
        p = softmax(logits, axis=1)
        logits[:, 2] -= alpha * p[:, 1]
    alpha in [0, 2] is the magnitude of the nudge.

    When alpha == 0 the input object itself is returned unchanged;
    otherwise a modified copy is returned and ``logits`` is left intact.
    """
    if alpha == 0.0:
        return logits
    # Numerically stable row-wise softmax.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs = probs / np.sum(probs, axis=1, keepdims=True)
    biased = logits.copy()
    biased[:, 2] = biased[:, 2] - alpha * probs[:, 1]
    return biased
# ─────────────────────────── All-in-one helper ────────────────────────────────
def apply_postproc(
    logits: np.ndarray,
    video_ids: Sequence[str],
    ttas: np.ndarray,
    T_per_class: Sequence[float] | None = None,
    smooth_window: int = 1,  # sliding-window size; only used by smooth_mode="mean"
    smooth_mode: str = "ema",
    smooth_alpha: float = 0.5,
    non_ego_alpha: float = 0.0,
    alert_bias: float = 0.0,
) -> np.ndarray:
    """
    Run the full post-processing chain on raw PolicyModel logits.

    Order: per-class temperature → temporal smoothing → non-ego bias →
    additive ALERT bias. Works on a float32 copy; the input is not modified.

    logits        : [N, 3] raw logits
    video_ids     : length-N video identifiers (used for smoothing)
    ttas          : [N] float; negative values are never smoothed
    T_per_class   : per-class temperatures, or None to skip scaling
    smooth_window : window for smooth_mode="mean"
    smooth_mode   : "ema" or "mean"
    smooth_alpha  : EMA weight on the current frame
    non_ego_alpha : strength of the non-ego ALERT nudge; applied only if > 0
    alert_bias    : constant added to the ALERT logit (column 2)

    Smoothing gate: temporal smoothing runs whenever smooth_mode == "ema"
    (regardless of smooth_window) or smooth_window > 1. In particular the
    DEFAULT arguments do apply EMA smoothing with alpha=0.5. To disable
    smoothing, pass smooth_mode="mean" with smooth_window=1.

    Returns the post-processed [N, 3] float32 logits.
    """
    # np.array copies by default, so the in-place ALERT-bias edit below never
    # touches the caller's array.
    lg = np.array(logits, dtype=np.float32)
    if T_per_class is not None:
        lg = apply_per_class_temperature(lg, T_per_class)
    if smooth_window > 1 or smooth_mode == "ema":
        lg = temporal_smooth(lg, video_ids, ttas,
                             window=smooth_window,
                             mode=smooth_mode, alpha=smooth_alpha)
    if non_ego_alpha > 0.0:
        lg = non_ego_bias(lg, alpha=non_ego_alpha)
    if alert_bias != 0.0:
        lg[:, 2] = lg[:, 2] + alert_bias
    return lg
# ─────────────────────────── Metric computation ───────────────────────────────
def compute_metrics(
    logits: np.ndarray,  # [N, 3]
    labels: np.ndarray,  # [N] in {0,1,2}
    cats: np.ndarray,  # [N] str
    ttas: np.ndarray,  # [N] float
    video_ids: Sequence[str],
) -> dict:
    """
    Mirror of baseline.comparison.eval_all.compute_all_metrics.

    Predictions are argmax(logits). Class 2 is treated as ALERT and classes
    1 or 2 count as a response for lead time. Samples are grouped by ``cats``
    into "ego_positive", "safe_neg" and "non_ego".

    logits, labels, cats and ttas may also be plain sequences. They are
    converted with np.asarray, so comparisons such as cats == "safe_neg"
    give element-wise masks instead of a single scalar False (which would
    silently zero every metric).

    Returns a dict with:
      policy_score           0.65*ego_alert_recall + 0.25*safe_neg_silent
                             - 0.15*safe_neg_alert_leak
      ego_alert_recall       share of ego_positive label-2 samples predicted 2
      non_ego_noalert        share of non_ego samples not predicted 2
      safe_neg_silent        share of safe_neg samples predicted 0
      safe_neg_alert_leak    share of safe_neg samples predicted 2
      burden_non_ego         share of non_ego samples predicted 2
      binary_ap              class-2-vs-rest average precision on softmax p[:, 2]
                             (0.0 if sklearn raises)
      binary_precision / binary_recall / binary_f1
                             class-2-vs-rest, from argmax predictions
      lead_time_mean         mean over ego videos of the largest tta whose
                             prediction is 1 or 2
      lead_time_coverage     share of ego videos with any such prediction
      alert_lead_time_mean / alert_lead_time_coverage
                             same, counting only predictions of 2
      n_samples, n_ego_videos
    Any ratio with a zero denominator is reported as 0.0.
    """
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    cats = np.asarray(cats)
    ttas = np.asarray(ttas)

    preds = logits.argmax(axis=1)
    # softmax only for prob_alert (used by binary AP)
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    prob_alert = p[:, 2]

    def _r(n, d):
        # safe ratio: 0.0 for an empty denominator
        return float(n) / float(d) if d > 0 else 0.0

    ego_mask = cats == "ego_positive"
    safe_mask = cats == "safe_neg"
    ne_mask = cats == "non_ego"

    ego_alert = _r(((preds == 2) & ego_mask & (labels == 2)).sum(),
                   (ego_mask & (labels == 2)).sum())
    non_ego_noalert = _r(((preds != 2) & ne_mask).sum(), ne_mask.sum())
    safe_silent = _r(((preds == 0) & safe_mask).sum(), safe_mask.sum())
    fa = _r(((preds == 2) & safe_mask).sum(), safe_mask.sum())
    burden_non_ego = _r(((preds == 2) & ne_mask).sum(), ne_mask.sum())

    # PolicyScore v3 (safety-first): 0.65*ego_recall + 0.25*safe_silent - 0.15*safe_alert
    policy_score = 0.65 * ego_alert + 0.25 * safe_silent - 0.15 * fa

    # binary AP / F1 (class 2 vs rest)
    from sklearn.metrics import average_precision_score
    binary_true = (labels == 2).astype(int)
    try:
        ap = float(average_precision_score(binary_true, prob_alert))
    except Exception:
        # Best-effort: degenerate input (e.g. NaN scores) reports AP as 0.0
        # instead of aborting the whole metric computation.
        ap = 0.0
    tp = int(((preds == 2) & (labels == 2)).sum())
    fp = int(((preds == 2) & (labels != 2)).sum())
    fn = int(((preds != 2) & (labels == 2)).sum())
    prec = _r(tp, tp + fp)
    rec = _r(tp, tp + fn)
    f1 = _r(2 * prec * rec, prec + rec)

    # lead time (OBSERVE ∪ ALERT response, per ego video)
    by_video: dict[str, list] = defaultdict(list)
    for i in np.where(ego_mask)[0]:
        by_video[video_ids[i]].append((float(ttas[i]), int(preds[i])))

    lead_times = []
    alert_lead_times = []
    n_videos = len(by_video)
    for items in by_video.values():
        items.sort(key=lambda x: -x[0])  # earliest (largest tta) first
        earliest_any = next((t for t, pr in items if pr in (1, 2)), None)
        earliest_alert = next((t for t, pr in items if pr == 2), None)
        if earliest_any is not None:
            lead_times.append(earliest_any)
        if earliest_alert is not None:
            alert_lead_times.append(earliest_alert)

    mean_lead = float(np.mean(lead_times)) if lead_times else 0.0
    mean_lead_alrt = float(np.mean(alert_lead_times)) if alert_lead_times else 0.0
    cov_any = _r(len(lead_times), n_videos)
    cov_alert = _r(len(alert_lead_times), n_videos)

    return {
        "policy_score": policy_score,
        "ego_alert_recall": ego_alert,
        "non_ego_noalert": non_ego_noalert,
        "safe_neg_silent": safe_silent,
        "safe_neg_alert_leak": fa,
        "burden_non_ego": burden_non_ego,
        "binary_ap": ap,
        "binary_precision": prec,
        "binary_recall": rec,
        "binary_f1": f1,
        "lead_time_mean": mean_lead,
        "lead_time_coverage": cov_any,
        "alert_lead_time_mean": mean_lead_alrt,
        "alert_lead_time_coverage": cov_alert,
        "n_samples": int(len(labels)),
        "n_ego_videos": n_videos,
    }