VLAlert / training /Policy /paper_eval.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
12.8 kB
#!/usr/bin/env python3
"""
Paper Evaluation Script — generates all metric tables needed for the paper.
Compares the following models (all evaluated on the same val set):
1. LKAlert-Binary (obs→alert) ← Ablation baseline
2. LKAlert-v2 (focal 0.1/0.3/0.6, 43% FA)
3. LKAlert-v3 (focal 0.2/0.3/0.5, fixed FA) ← main model
Outputs (LaTeX + JSON):
Table 1: Per-class metrics (precision, recall, F1)
Table 2: Policy score decomposition
Table 3: OBSERVE lead-time advantage
Table 4: False alarm vs recall tradeoff
Usage:
python -m training.Policy.paper_eval \
--sft_checkpoint checkpoints/SFT/sft_v2/best \
--label_dir data/policy_labels \
--belief_cache_dir data/belief_cache \
--models policy_warmstart_v2 checkpoints/Policy/policy_warmstart_v2/best \
policy_warmstart_v3 checkpoints/Policy/policy_warmstart_v3/best \
binary_obs2alert checkpoints/Policy/policy_binary_obs2alert/best \
--output_dir eval_results/paper_tables
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader
from tqdm import tqdm
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from training.Policy.policy_model import PolicyModel
from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn
from training.Policy.warm_start_trainer import compute_policy_score
# Configure root logging at import time: INFO level, "<timestamp> <level> <message>" lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Logger used by this script's evaluation and reporting code.
logger = logging.getLogger("Policy.paper_eval")
# Class id -> action name. The names are also used as the per-class keys of the
# classification report and as column groups in the generated tables.
ACTION_NAMES = {0: "SILENT", 1: "OBSERVE", 2: "ALERT"}
@torch.no_grad()
def evaluate_model(
    model: PolicyModel,
    loader: DataLoader,
    device: torch.device,
    merge_observe: str | None = None,  # "alert" | "silent" | None
) -> dict:
    """Run ``model`` over ``loader`` and compute per-class and policy metrics.

    Args:
        model: Policy model to evaluate; it is switched to eval mode here.
        loader: Yields dict batches. A batch containing ``"beliefs"`` is fed to
            ``model.forward_cached`` (with ``"tta_means"`` / ``"tta_vars"``);
            any other batch is fed to ``model(images, metadata)``. Every batch
            must also provide ``"action_labels"``, ``"categories"`` and
            ``"tta_raws"``.
        device: Device the cached-belief tensors are moved to before the
            forward pass.
        merge_observe: ``"alert"`` folds class 1 (OBSERVE) into class 2
            (ALERT), ``"silent"`` folds it into class 0 (SILENT), for both
            labels and predictions. ``None`` keeps all three classes.

    Returns:
        A dict with the sklearn classification report (``output_dict`` form),
        the policy score and its components, the per-TTA-bucket OBSERVE rate
        on ``ego_positive`` samples, the sample count, and ``label_dist``.

    Raises:
        ValueError: If ``merge_observe`` is not ``"alert"``, ``"silent"`` or
            ``None``. Previously such values were silently treated as "no
            merge", which produced metrics that did not match the request.
    """
    if merge_observe not in (None, "alert", "silent"):
        raise ValueError(
            f"merge_observe must be 'alert', 'silent' or None, got {merge_observe!r}"
        )
    # Class id that OBSERVE (1) collapses into; None means "do not merge".
    merge_target = {"alert": 2, "silent": 0}.get(merge_observe)

    model.eval()
    all_true, all_pred = [], []
    categories, tta_raws = [], []
    for batch in loader:
        if "beliefs" in batch:
            logits = model.forward_cached(
                batch["beliefs"].to(device),
                batch["tta_means"].to(device),
                batch["tta_vars"].to(device),
            )
        else:
            logits = model(batch["images"], batch["metadata"])
        # Only the argmax is needed downstream; class probabilities were never
        # consumed, so no softmax / extra host copy is performed.
        preds = logits.argmax(dim=-1).cpu().numpy()
        trues = batch["action_labels"].numpy()
        # Optional: merge OBSERVE at eval time too
        if merge_target is not None:
            trues = np.where(trues == 1, merge_target, trues)
            preds = np.where(preds == 1, merge_target, preds)
        all_true.extend(trues.tolist())
        all_pred.extend(preds.tolist())
        categories.extend(batch["categories"])
        tta_raws.extend(batch["tta_raws"].tolist())
    all_true = np.array(all_true)
    all_pred = np.array(all_pred)

    # ── sklearn classification report ─────────────────────────────────────────
    # Report only the classes that actually occur (in labels or predictions),
    # so merged-away classes do not show up as all-zero rows.
    present = sorted(set(all_true.tolist()) | set(all_pred.tolist()))
    names = [ACTION_NAMES.get(i, str(i)) for i in present]
    report = classification_report(all_true, all_pred,
                                   labels=present, target_names=names,
                                   output_dict=True, zero_division=0)

    # ── policy score ──────────────────────────────────────────────────────────
    cats = np.array(categories)

    def _ratio(num, den):
        # Empty groups yield 0.0 instead of a division error.
        return float(num / den) if den > 0 else 0.0

    # ego_positive: ALERT recall (label==2, pred==2)
    ego_mask = cats == "ego_positive"
    ego_preds = all_pred[ego_mask]
    ego_trues = all_true[ego_mask]
    alert_true = ego_preds[ego_trues == 2]
    ego_alert_recall = _ratio((alert_true == 2).sum(), len(alert_true))

    # non_ego: fraction predicted NOT ALERT (reported, not part of the score call)
    ne_mask = cats == "non_ego"
    ne_preds = all_pred[ne_mask]
    non_ego_noalert_rate = _ratio((ne_preds != 2).sum(), len(ne_preds))

    # safe_neg: fraction predicted SILENT, and fraction leaking into ALERT
    sn_mask = cats == "safe_neg"
    sn_preds = all_pred[sn_mask]
    safe_neg_silent_rate = _ratio((sn_preds == 0).sum(), len(sn_preds))
    safe_neg_alert_leak = _ratio((sn_preds == 2).sum(), len(sn_preds))

    # The weighting of these three terms is defined by compute_policy_score.
    policy_score = compute_policy_score(
        ego_alert_recall=ego_alert_recall,
        safe_neg_silent_rate=safe_neg_silent_rate,
        safe_neg_alert_rate=safe_neg_alert_leak,
    )

    # ── OBSERVE lead time ─────────────────────────────────────────────────────
    # Share of ego_positive samples predicted OBSERVE in each unit-wide TTA
    # bucket [lo, lo+1) for lo = 0..9; empty buckets are omitted.
    # NOTE(review): the "s" suffix suggests tta_raws is in seconds — confirm.
    tta_arr = np.array(tta_raws)
    observe_by_tta = {}
    for lo in range(0, 10):
        mask = ego_mask & (tta_arr >= lo) & (tta_arr < lo + 1)
        if mask.sum() > 0:
            observe_by_tta[f"{lo}-{lo+1}s"] = float(np.mean(all_pred[mask] == 1))

    return {
        "classification_report": report,
        "policy_score": policy_score,
        "ego_alert_recall": ego_alert_recall,
        "non_ego_noalert": non_ego_noalert_rate,
        "safe_neg_silent": safe_neg_silent_rate,
        "safe_neg_alert_leak": safe_neg_alert_leak,
        "observe_by_tta": observe_by_tta,
        "n_samples": len(all_true),
        # Despite the key name, these are counts of *predicted* actions
        # (after any OBSERVE merge), not of ground-truth labels.
        "label_dist": {ACTION_NAMES[k]: int((all_pred == k).sum())
                       for k in range(3)},
    }
def format_latex_table1(results: Dict[str, dict]) -> str:
    """Build the LaTeX table of per-class precision / recall / F1.

    ``results`` maps a model name to its metrics dict; only the
    ``"classification_report"`` entry is read. One row is emitted per model,
    with three columns for each of SILENT, OBSERVE and ALERT. A class missing
    from a model's report is rendered as three dashes.
    """
    preamble = [
        r"\begin{table}[h]",
        r"\centering",
        r"\caption{Per-class Classification Performance (val set)}",
        r"\begin{tabular}{l|ccc|ccc|ccc}",
        r"\hline",
        r"Model & \multicolumn{3}{c|}{SILENT} & \multicolumn{3}{c|}{OBSERVE} & \multicolumn{3}{c}{ALERT} \\",
        r" & P & R & F1 & P & R & F1 & P & R & F1 \\",
        r"\hline",
    ]
    closing = [r"\hline", r"\end{tabular}", r"\end{table}"]
    metric_keys = ("precision", "recall", "f1-score")

    body = []
    for model_name, metrics in results.items():
        report = metrics["classification_report"]
        # Underscores must be escaped to survive LaTeX text mode.
        cells = [model_name.replace("_", r"\_")]
        for action in ("SILENT", "OBSERVE", "ALERT"):
            if action not in report:
                cells.extend(["—"] * len(metric_keys))
                continue
            per_class = report[action]
            cells.extend(f"{per_class[key]:.3f}" for key in metric_keys)
        body.append(" & ".join(cells) + r" \\")

    return "\n".join(preamble + body + closing)
def format_latex_table2(results: Dict[str, dict]) -> str:
    """Build the LaTeX table that decomposes the policy score.

    ``results`` maps a model name to its metrics dict. For each model one row
    is emitted with ``ego_alert_recall``, ``non_ego_noalert``,
    ``safe_neg_silent``, ``safe_neg_alert_leak`` and the bold-faced
    ``policy_score``, all formatted to three decimals.

    The direction markers in the header use ``$\\uparrow$`` / ``$\\downarrow$``
    rather than raw Unicode arrows, so the table compiles with pdfLaTeX as
    well as Unicode-aware engines and the output stays pure ASCII.
    """
    lines = [
        r"\begin{table}[h]",
        r"\centering",
        r"\caption{Policy Score Decomposition}",
        r"\begin{tabular}{l|cccc|c}",
        r"\hline",
        r"Model & Ego-Alert$\uparrow$ & Non-Ego-NoAlert$\uparrow$ & Safe-Silent$\uparrow$ "
        r"& False-Alarm$\downarrow$ & Policy Score$\uparrow$ \\",
        r"\hline",
    ]
    for name, m in results.items():
        # Underscores must be escaped to survive LaTeX text mode.
        safe_name = name.replace("_", r"\_")
        lines.append(
            f"{safe_name} & "
            f'{m["ego_alert_recall"]:.3f} & '
            f'{m["non_ego_noalert"]:.3f} & '
            f'{m["safe_neg_silent"]:.3f} & '
            f'{m["safe_neg_alert_leak"]:.3f} & '
            f'\\textbf{{{m["policy_score"]:.3f}}} \\\\'
        )
    lines += [r"\hline", r"\end{tabular}", r"\end{table}"]
    return "\n".join(lines)
def main():
    """Evaluate every requested checkpoint on one split and write the tables.

    Parses the command line, builds a single shared data loader, evaluates
    each ``<name> <checkpoint_dir>`` pair with ``evaluate_model``, prints
    ASCII versions of tables 1 and 2, and writes ``all_results.json``,
    ``table_per_class.tex`` and ``table_policy_score.tex`` to the output
    directory.

    Raises:
        ValueError: If ``--models`` is not a list of name/checkpoint pairs, or
            if a model name is repeated (results are keyed by name, so a
            repeat would silently overwrite an earlier evaluation).
    """
    parser = argparse.ArgumentParser("paper_eval")
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default=None)
    parser.add_argument("--split", default="val")
    parser.add_argument("--models", nargs="+", required=True,
                        help="Pairs of: <name> <checkpoint_dir> ...")
    parser.add_argument("--output_dir", default="eval_results/paper_tables")
    args = parser.parse_args()

    if len(args.models) % 2 != 0:
        raise ValueError("--models must be pairs of <name> <checkpoint_dir>")
    model_pairs = list(zip(args.models[::2], args.models[1::2]))
    model_names = [name for name, _ in model_pairs]
    if len(set(model_names)) != len(model_names):
        raise ValueError("--models contains duplicate model names")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ── shared data loader ────────────────────────────────────────────────────
    cache_path = None
    if args.belief_cache_dir:
        candidate = Path(args.belief_cache_dir) / f"{args.split}.pt"
        if candidate.exists():
            cache_path = candidate
        else:
            # Make the fallback visible instead of silently ignoring the flag.
            logger.warning("Belief cache %s not found; continuing without a belief cache",
                           candidate)
    ds = PolicyDataset(
        manifests=[Path(args.label_dir) / f"{args.split}.json"],
        split=args.split,
        belief_cache_path=cache_path,
    )
    loader = DataLoader(ds, batch_size=512, shuffle=False,
                        num_workers=4, collate_fn=policy_collate_fn)
    logger.info(f"Val set: {len(ds)} samples")

    # ── evaluate each model ───────────────────────────────────────────────────
    all_results: Dict[str, dict] = {}
    for model_name, ckpt_dir in model_pairs:
        logger.info(f"\n{'='*50}")
        logger.info(f"Evaluating: {model_name}")
        logger.info(f" Checkpoint: {ckpt_dir}")
        # NOTE(review): the model is never explicitly moved to `device` here;
        # presumably PolicyModel handles its own placement — confirm.
        model = PolicyModel(args.sft_checkpoint, use_bf16=True)
        model.load_policy_checkpoint(ckpt_dir)
        # Models whose name contains "binary" are scored with OBSERVE folded into ALERT.
        merge = "alert" if "binary" in model_name else None
        m = evaluate_model(model, loader, device, merge_observe=merge)
        all_results[model_name] = m
        logger.info(f" policy_score={m['policy_score']:.4f} "
                    f"recall={m['ego_alert_recall']:.4f} "
                    f"FA={m['safe_neg_alert_leak']:.4f}")
        # Free GPU memory
        del model
        torch.cuda.empty_cache()

    # ── print tables ──────────────────────────────────────────────────────────
    print("\n" + "="*60)
    print("TABLE 1: Per-class classification metrics")
    print("="*60)
    # ASCII version
    header = f"{'Model':<30} {'':>3} {'SILENT':>8} {'':>3} {'OBSERVE':>9} {'':>3} {'ALERT':>7}"
    print(header)
    print(f"{'':>30} P R F1 P R F1 P R F1")
    for name, m in all_results.items():
        rpt = m["classification_report"]
        row = f"{name:<30}"
        for cls in ["SILENT", "OBSERVE", "ALERT"]:
            if cls in rpt:
                row += f" {rpt[cls]['precision']:.2f} {rpt[cls]['recall']:.2f} {rpt[cls]['f1-score']:.2f}"
            else:
                row += " — — — "
        print(row)

    print("\n" + "="*60)
    print("TABLE 2: Policy score decomposition")
    print("="*60)
    print(f"{'Model':<30} {'Ego↑':>6} {'NonEgo↑':>8} {'Silent↑':>8} {'FA↓':>6} {'Score↑':>7}")
    for name, m in all_results.items():
        print(f"{name:<30} {m['ego_alert_recall']:>6.3f} {m['non_ego_noalert']:>8.3f} "
              f"{m['safe_neg_silent']:>8.3f} {m['safe_neg_alert_leak']:>6.3f} "
              f"{m['policy_score']:>7.4f}")

    # ── save JSON and LaTeX ───────────────────────────────────────────────────
    # Files are written as UTF-8 explicitly: the LaTeX output can contain
    # non-ASCII characters, which the locale default encoding may not support.
    with open(out_dir / "all_results.json", "w", encoding="utf-8") as f:
        # The classification report is left out of the JSON dump.
        json.dump({k: {kk: vv for kk, vv in v.items()
                       if kk != "classification_report"}
                   for k, v in all_results.items()}, f, indent=2)
    latex1 = format_latex_table1(all_results)
    latex2 = format_latex_table2(all_results)
    (out_dir / "table_per_class.tex").write_text(latex1, encoding="utf-8")
    (out_dir / "table_policy_score.tex").write_text(latex2, encoding="utf-8")
    logger.info(f"\n✅ Results saved → {out_dir}/")
    logger.info(f" JSON : all_results.json")
    logger.info(f" LaTeX: table_per_class.tex, table_policy_score.tex")


if __name__ == "__main__":
    main()