VLAlert / training /Policy /threshold_analysis.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
13.2 kB
#!/usr/bin/env python3
"""
Small improvements: fine-grained threshold search, TTA-conditioned analysis,
and v3 + v5_mono ensemble.
Runs entirely on CPU using belief cache + lightweight policy heads.
No full model loading needed.
Usage:
python -m training.Policy.threshold_analysis \
--label_dir data/policy_labels \
--belief_cache_dir data/belief_cache \
--v3_ckpt checkpoints/Policy/policy_warmstart_v3/best \
--v5_ckpt checkpoints/Policy/policy_warmstart_v5_mono/best \
--output_dir eval_results/threshold_analysis
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from lkalert.models.components import PolicyHead, HierarchicalPolicyHead
# Module-wide log format; INFO level so each analysis' progress and summary lines are emitted.
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)
logger = logging.getLogger(name="Policy.threshold")
def load_val_data(label_dir: str, cache_dir: str):
    """Load the validation labels and the matching cached belief tensors.

    Reads ``<label_dir>/val.json``, which must hold a top-level ``"samples"``
    list whose entries carry ``action_label``, ``category``, ``tta_raw`` and
    ``video_id``. Also reads ``<cache_dir>/val.pt``, a dict with ``beliefs``,
    ``tta_means`` and ``tta_vars`` tensors. The cache is loaded on CPU with
    ``weights_only=True``.

    Returns:
        ``(beliefs, tta_means, tta_vars, labels, cats, ttas, vids)``.
        The first three are the cached tensors. ``labels``, ``cats`` and
        ``ttas`` are numpy arrays built from the JSON samples, and ``vids`` is
        a list of video ids.

    Raises:
        ValueError: if a cached tensor and the label file disagree on the
            number of samples. All downstream analyses index the two in
            lockstep, so a mismatch would otherwise silently misalign
            predictions and labels.
    """
    with open(Path(label_dir) / "val.json", encoding="utf-8") as f:
        samples = json.load(f)["samples"]

    cache = torch.load(Path(cache_dir) / "val.pt", map_location="cpu", weights_only=True)
    beliefs = cache["beliefs"]  # [N, 2048] per the original note
    tta_means = cache["tta_means"]  # [N]
    tta_vars = cache["tta_vars"]  # [N]

    n_samples = len(samples)
    for key, tensor in (("beliefs", beliefs), ("tta_means", tta_means), ("tta_vars", tta_vars)):
        if tensor.shape[0] != n_samples:
            raise ValueError(
                f"Belief cache '{key}' has {tensor.shape[0]} rows but val.json has "
                f"{n_samples} samples; cache and labels are out of sync."
            )

    labels = np.array([s["action_label"] for s in samples])
    cats = np.array([s["category"] for s in samples])
    ttas = np.array([s["tta_raw"] for s in samples])
    vids = [s["video_id"] for s in samples]
    return beliefs, tta_means, tta_vars, labels, cats, ttas, vids
@torch.no_grad()
def get_v3_probs(beliefs, tta_means, tta_vars, ckpt_dir):
    """Run the v3 flat PolicyHead and return softmax probabilities [N, 3].

    The head is built with ``hidden_dim = beliefs.shape[-1]``. Its weights are
    read from ``<ckpt_dir>/policy_head.pt`` and it scores all samples in one
    batch.

    Args:
        beliefs: cached belief tensor; its first dim is the sample count.
        tta_means, tta_vars: per-sample TTA statistics passed to the head.
        ckpt_dir: directory containing ``policy_head.pt`` (a state dict).

    Returns:
        numpy array of softmax probabilities over the last logit dimension.
    """
    head = PolicyHead(hidden_dim=int(beliefs.shape[-1]))
    # The checkpoint is a plain state dict. weights_only=True refuses arbitrary
    # pickled objects, matching how the belief cache is loaded.
    state_dict = torch.load(
        Path(ckpt_dir) / "policy_head.pt", map_location="cpu", weights_only=True
    )
    head.load_state_dict(state_dict)
    head.eval()

    # Every sample is scored with previous action 0. That is SILENT under this
    # file's label convention, where 0 = silent and 2 = alert.
    prev_action = torch.zeros(beliefs.shape[0], dtype=torch.long)
    logits = head(beliefs, tta_means, tta_vars, prev_action)
    return F.softmax(logits, dim=-1).numpy()
@torch.no_grad()
def get_v5_probs(beliefs, tta_means, tta_vars, ckpt_dir):
    """Run the v5 HierarchicalPolicyHead and map its two heads to 3-class probs.

    The head returns an alert logit and a danger logit. After a sigmoid on each:

    * P(SILENT)  = 1 - P(danger)
    * P(OBSERVE) = max(P(danger) - P(alert), 0)
    * P(ALERT)   = P(alert)

    The rows are then renormalised to sum to 1. The two sigmoids are
    independent, so P(alert) may exceed P(danger); in that case the raw sum
    exceeds 1.

    Args:
        beliefs: cached belief tensor; its first dim is the sample count.
        tta_means, tta_vars: per-sample TTA statistics passed to the head.
        ckpt_dir: directory containing ``policy_head.pt`` (a state dict).

    Returns:
        numpy array stacked as [SILENT, OBSERVE, ALERT] on the last axis.
        This is [N, 3] assuming each head returns an [N] logit vector (TODO
        confirm against HierarchicalPolicyHead).
    """
    head = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1]))
    # The checkpoint is a plain state dict. weights_only=True refuses arbitrary
    # pickled objects, matching how the belief cache is loaded.
    state_dict = torch.load(
        Path(ckpt_dir) / "policy_head.pt", map_location="cpu", weights_only=True
    )
    head.load_state_dict(state_dict)
    head.eval()

    # Score every sample with previous action 0 (SILENT in this file's convention).
    prev_action = torch.zeros(beliefs.shape[0], dtype=torch.long)
    alert_logit, danger_logit = head(beliefs, tta_means, tta_vars, prev_action)
    p_alert = torch.sigmoid(alert_logit).numpy()
    p_danger = torch.sigmoid(danger_logit).numpy()

    p_silent = 1.0 - p_danger
    p_observe = np.clip(p_danger - p_alert, 0.0, None)
    probs = np.stack([p_silent, p_observe, p_alert], axis=-1)
    # Renormalise. The floor guards against division by zero.
    probs = probs / probs.sum(axis=-1, keepdims=True).clip(1e-8)
    return probs
def policy_metrics(preds, labels, cats):
    """Score predicted actions with the safety-first PolicyScore (v3).

    Actions are 0 = SILENT, 1 = OBSERVE and 2 = ALERT. Each rate is 0.0 when
    its subgroup is empty.

    Args:
        preds: [N] predicted actions.
        labels: [N] ground-truth actions.
        cats: [N] category strings ("ego_positive", "non_ego", "safe_neg").

    Returns:
        dict with:

        * ``policy_score``: 0.65 * ego_alert_recall + 0.25 * safe_neg_silent_rate
          - 0.15 * safe_neg_alert_rate.
        * ``ego_alert_recall``: ALERT rate on ego_positive samples labelled ALERT.
        * ``non_ego_noalert_rate``: non-ALERT rate on non_ego samples. It is
          reported only and does not enter the score.
        * ``safe_neg_silent_rate``: SILENT rate on safe_neg samples.
        * ``safe_neg_alert_rate``: ALERT rate on safe_neg samples.
    """
    def _rate(group, hit):
        # Fraction of samples in `group` for which `hit` holds; 0.0 for an empty group.
        if not group.any():
            return 0.0
        return float(hit[group].mean())

    is_alert = preds == 2
    is_silent = preds == 0

    ego_alert_group = (cats == "ego_positive") & (labels == 2)
    non_ego_group = cats == "non_ego"
    safe_neg_group = cats == "safe_neg"

    recall = _rate(ego_alert_group, is_alert)
    noalert = _rate(non_ego_group, ~is_alert)
    silent = _rate(safe_neg_group, is_silent)
    false_alert = _rate(safe_neg_group, is_alert)

    return {
        "policy_score": 0.65 * recall + 0.25 * silent - 0.15 * false_alert,
        "ego_alert_recall": recall,
        "non_ego_noalert_rate": noalert,
        "safe_neg_silent_rate": silent,
        "safe_neg_alert_rate": false_alert,
    }
def binary_ap(probs, labels):
    """Return the ALERT-vs-rest average precision, ranking samples by P(ALERT).

    ``probs[:, 2]`` is the score and ``labels == 2`` is the positive class.
    AP is undefined without positives, so 0.0 is returned in that case.
    """
    from sklearn.metrics import average_precision_score

    is_alert = (labels == 2).astype(int)
    if is_alert.sum() <= 0:
        return 0.0
    return float(average_precision_score(is_alert, probs[:, 2]))
# ═══════════════════════════════════════════════════════════════════════════════
# Analysis 1: Fine-grained threshold grid for v5
# ═══════════════════════════════════════════════════════════════════════════════
def fine_threshold_grid(probs_v5, labels, cats, raw_alert, raw_danger,
                        alert_grid=None, danger_grid=None):
    """Grid-search the v5 hierarchical decision thresholds at 0.01 resolution.

    Predictions are built directly from the raw sigmoid outputs of the two v5
    heads:

    * OBSERVE (1) when ``raw_danger > tau_d``;
    * ALERT (2) when ``raw_alert > tau_a``, overriding OBSERVE;
    * SILENT (0) otherwise.

    Each (tau_a, tau_d) pair is scored with ``policy_metrics``. The first pair
    reaching the highest PolicyScore is kept.

    Args:
        probs_v5: [N, 3] v5 class probabilities. Not used by the search, because
            the raw sigmoid outputs are exact while the 3-class probs are
            renormalised. Kept for interface compatibility.
        labels, cats: per-sample action labels and categories for ``policy_metrics``.
        raw_alert: per-sample sigmoid(alert_logit).
        raw_danger: per-sample sigmoid(danger_logit).
        alert_grid: tau_a candidates; default 0.20..0.80 in steps of 0.01.
        danger_grid: tau_d candidates; default 0.10..0.80 in steps of 0.01.

    Returns:
        dict with ``best_tau_alert``, ``best_tau_danger``, ``best_policy_score``
        and the best pair's other metrics, each prefixed with ``best_``.

    Raises:
        ValueError: if either threshold grid is empty.
    """
    # Integer-based grids. np.arange with a float step such as 0.01 may include
    # or drop the stop value depending on rounding, so the bounds would be
    # unreliable.
    if alert_grid is None:
        alert_grid = np.arange(20, 81) / 100.0
    if danger_grid is None:
        danger_grid = np.arange(10, 81) / 100.0

    p_alert = np.asarray(raw_alert)
    p_danger = np.asarray(raw_danger)
    n = len(labels)

    best_score = -np.inf
    best_ta, best_td = 0.5, 0.5
    best_m = None
    for ta in alert_grid:
        alert_mask = p_alert > ta  # independent of tau_d, so computed once per tau_a
        for td in danger_grid:
            preds = np.zeros(n, dtype=int)
            preds[p_danger > td] = 1
            preds[alert_mask] = 2  # ALERT takes precedence over OBSERVE
            m = policy_metrics(preds, labels, cats)
            if m["policy_score"] > best_score:
                best_score = m["policy_score"]
                best_ta, best_td = float(ta), float(td)
                best_m = m

    if best_m is None:
        raise ValueError("alert_grid and danger_grid must both be non-empty")

    logger.info(f"Fine threshold: best tau_a={best_ta:.2f} tau_d={best_td:.2f} "
                f"PolicyScore={best_score:.4f}")
    return {
        "best_tau_alert": round(best_ta, 2),
        "best_tau_danger": round(best_td, 2),
        "best_policy_score": best_score,
        **{f"best_{k}": v for k, v in best_m.items() if k != "policy_score"},
    }
# ═══════════════════════════════════════════════════════════════════════════════
# Analysis 2: TTA-conditioned thresholds
# ═══════════════════════════════════════════════════════════════════════════════
def tta_conditioned_analysis(probs, labels, cats, ttas, min_samples: int = 10):
    """Find, per TTA bucket, the P(ALERT) threshold that best balances recall and silence.

    Samples are grouped into TTA buckets. Each bucket with at least
    ``min_samples`` samples gets a sweep of one threshold ``t`` over 0.10..0.89
    in steps of 0.01. A sample is predicted ALERT when ``P(ALERT) > t`` and
    SILENT otherwise. The first ``t`` that maximises
    ``0.7 * ego_alert_recall + 0.3 * safe_neg_silent_rate`` is kept. This is a
    simplified two-term objective, not the PolicyScore of ``policy_metrics``.

    Args:
        probs: [N, 3] class probabilities; column 2 is P(ALERT).
        labels: [N] action labels (2 == ALERT).
        cats: [N] category strings.
        ttas: [N] raw TTA values. Samples without a TTA carry -1 and fall in
            the ``no_tta`` bucket.
        min_samples: buckets smaller than this are skipped.

    Returns:
        dict mapping each bucket name to ``n_samples``, ``n_alert``,
        ``best_threshold``, ``best_score``, ``mean_p_alert`` and ``std_p_alert``.
    """
    p_alert = probs[:, 2]
    buckets = [
        ("tta_0_2", (0, 2)),
        ("tta_2_4", (2, 4)),
        ("tta_4_6", (4, 6)),
        ("tta_6_inf", (6, np.inf)),  # the previous cap of 100 silently dropped larger TTAs
        ("no_tta", (-2, -0.5)),  # safe_neg and non_ego with tta=-1
    ]
    # Integer-based grid, so the endpoints do not depend on float-step rounding.
    thresholds = np.arange(10, 90) / 100.0

    results = {}
    for name, (lo, hi) in buckets:
        mask = (ttas >= lo) & (ttas < hi)
        n = int(mask.sum())
        if n < min_samples:
            continue
        sub_labels = labels[mask]
        sub_cats = cats[mask]
        sub_palert = p_alert[mask]

        # These groups depend only on the bucket, not on the threshold.
        ego_alert = (sub_cats == "ego_positive") & (sub_labels == 2)
        safe_neg = sub_cats == "safe_neg"

        best_t, best_s = 0.5, -1.0
        for t in thresholds:
            is_alert = sub_palert > t
            recall = float(is_alert[ego_alert].mean()) if ego_alert.any() else 0.0
            silent = float((~is_alert[safe_neg]).mean()) if safe_neg.any() else 0.0
            score = 0.7 * recall + 0.3 * silent
            if score > best_s:
                best_s, best_t = score, float(t)

        results[name] = {
            "n_samples": n,
            "n_alert": int((sub_labels == 2).sum()),
            "best_threshold": round(best_t, 2),
            "best_score": best_s,
            "mean_p_alert": float(sub_palert.mean()),
            "std_p_alert": float(sub_palert.std()),
        }
    return results
# ═══════════════════════════════════════════════════════════════════════════════
# Analysis 3: Ensemble v3 + v5
# ═══════════════════════════════════════════════════════════════════════════════
def ensemble_analysis(probs_v3, probs_v5, labels, cats):
    """Sweep linear blends ``(1 - w5) * probs_v3 + w5 * probs_v5`` for w5 = 0.0..1.0.

    For each weight, the blend's argmax is scored with ``policy_metrics`` and
    ``binary_ap`` is computed on the blended probabilities. Results are keyed
    ``"w3=<w3>_w5=<w5>"``. An extra ``"best"`` entry holds the config with the
    highest PolicyScore (the first one on ties) together with its metrics.
    """
    sweep = {}
    for w5 in np.arange(0.0, 1.01, 0.1):
        w3 = 1.0 - w5
        blended = w3 * probs_v3 + w5 * probs_v5
        metrics = policy_metrics(blended.argmax(axis=1), labels, cats)
        ap = binary_ap(blended, labels)
        sweep[f"w3={w3:.1f}_w5={w5:.1f}"] = {**metrics, "binary_ap": ap}
        if abs(w5 - 0.5) < 0.01:
            logger.info(f" Ensemble 50/50: PolicyScore={metrics['policy_score']:.4f} AP={ap:.4f}")

    # Pick the top-scoring weighting; max() keeps the earliest entry on ties.
    best_key, best_entry = max(sweep.items(), key=lambda item: item[1]["policy_score"])
    sweep["best"] = {"config": best_key, **best_entry}
    logger.info(f" Best ensemble: {best_key} PolicyScore={best_entry['policy_score']:.4f}")
    return sweep
def main():
    """Command-line entry point: run all threshold analyses on the val split.

    Steps:

    1. Load the validation labels and belief cache.
    2. Score them with the v3 and v5 policy heads.
    3. Run the fine-grained v5 threshold grid, the TTA-bucketed threshold sweep
       and the v3/v5 ensemble sweep.
    4. Write everything to ``<output_dir>/threshold_analysis.json`` and log a
       summary.
    """
    parser = argparse.ArgumentParser("threshold_analysis")
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default="data/belief_cache")
    parser.add_argument("--v3_ckpt", default="checkpoints/Policy/policy_warmstart_v3/best")
    parser.add_argument("--v5_ckpt", default="checkpoints/Policy/policy_warmstart_v5_mono/best")
    parser.add_argument("--output_dir", default="eval_results/threshold_analysis")
    args = parser.parse_args()

    logger.info("Loading val data...")
    beliefs, tta_means, tta_vars, labels, cats, ttas, _vids = load_val_data(
        args.label_dir, args.belief_cache_dir
    )

    # Get predictions from both models. Each AP is computed once and reused in the summary.
    logger.info("Running v3 PolicyHead...")
    probs_v3 = get_v3_probs(beliefs, tta_means, tta_vars, args.v3_ckpt)
    m_v3 = policy_metrics(probs_v3.argmax(axis=1), labels, cats)
    ap_v3 = binary_ap(probs_v3, labels)
    logger.info(f" v3 PolicyScore={m_v3['policy_score']:.4f} AP={ap_v3:.4f}")

    logger.info("Running v5 HierarchicalPolicyHead...")
    # The fine threshold grid needs the raw sigmoid outputs of both v5 heads.
    # get_v5_probs only returns the renormalised 3-class probs, so the head is
    # also run here.
    head_v5 = HierarchicalPolicyHead(hidden_dim=int(beliefs.shape[-1]))
    # Plain state dict: weights_only=True refuses arbitrary pickled objects.
    sd = torch.load(Path(args.v5_ckpt) / "policy_head.pt", map_location="cpu", weights_only=True)
    head_v5.load_state_dict(sd)
    head_v5.eval()
    with torch.no_grad():
        prev = torch.zeros(beliefs.shape[0], dtype=torch.long)
        al, dl = head_v5(beliefs, tta_means, tta_vars, prev)
        raw_alert = torch.sigmoid(al).numpy()
        raw_danger = torch.sigmoid(dl).numpy()
    probs_v5 = get_v5_probs(beliefs, tta_means, tta_vars, args.v5_ckpt)
    m_v5 = policy_metrics(probs_v5.argmax(axis=1), labels, cats)
    ap_v5 = binary_ap(probs_v5, labels)
    logger.info(f" v5 PolicyScore={m_v5['policy_score']:.4f} AP={ap_v5:.4f}")

    all_results = {}

    # ── 1) Fine-grained threshold grid ──
    logger.info("\n=== Fine-grained threshold grid (v5) ===")
    all_results["fine_threshold"] = fine_threshold_grid(
        probs_v5, labels, cats, raw_alert, raw_danger
    )

    # ── 2) TTA-conditioned analysis ──
    logger.info("\n=== TTA-conditioned threshold analysis ===")
    all_results["tta_conditioned"] = tta_conditioned_analysis(probs_v5, labels, cats, ttas)
    for bucket, info in all_results["tta_conditioned"].items():
        logger.info(f" {bucket}: n={info['n_samples']} n_alert={info['n_alert']} "
                    f"best_t={info['best_threshold']} mean_p={info['mean_p_alert']:.3f}")

    # ── 3) Ensemble ──
    logger.info("\n=== Ensemble analysis (v3 + v5) ===")
    all_results["ensemble"] = ensemble_analysis(probs_v3, probs_v5, labels, cats)

    # ── 4) Summary ──
    all_results["summary"] = {
        "v3_policy_score": m_v3["policy_score"],
        "v3_binary_ap": ap_v3,
        "v5_policy_score": m_v5["policy_score"],
        "v5_binary_ap": ap_v5,
        "v5_fine_threshold_score": all_results["fine_threshold"]["best_policy_score"],
        "ensemble_best_score": all_results["ensemble"]["best"]["policy_score"],
        "ensemble_best_ap": all_results["ensemble"]["best"]["binary_ap"],
    }

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "threshold_analysis.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    logger.info(f"\nResults saved to {out_path}")

    logger.info("\n" + "=" * 60)
    logger.info("SUMMARY")
    logger.info("=" * 60)
    for k, v in all_results["summary"].items():
        logger.info(f" {k}: {v:.4f}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()