#!/usr/bin/env python3 """ Paper Evaluation Script — 生成论文所需的全部指标表格 对比以下模型(所有都在同一 val set 上评估): 1. LKAlert-Binary (obs→alert) ← Ablation baseline 2. LKAlert-v2 (focal 0.1/0.3/0.6, 43% FA) 3. LKAlert-v3 (focal 0.2/0.3/0.5, fixed FA) ← 主模型 输出(LaTeX + JSON): Table 1: Per-class metrics (precision, recall, F1) Table 2: Policy score decomposition Table 3: OBSERVE lead-time advantage Table 4: False alarm vs recall tradeoff 使用方法: python -m training.Policy.paper_eval \ --sft_checkpoint checkpoints/SFT/sft_v2/best \ --label_dir data/policy_labels \ --belief_cache_dir data/belief_cache \ --models policy_warmstart_v2 checkpoints/Policy/policy_warmstart_v2/best \ policy_warmstart_v3 checkpoints/Policy/policy_warmstart_v3/best \ binary_obs2alert checkpoints/Policy/policy_binary_obs2alert/best \ --output_dir eval_results/paper_tables """ from __future__ import annotations import argparse import json import logging from collections import defaultdict from pathlib import Path from typing import Dict, List import numpy as np import torch import torch.nn.functional as F from sklearn.metrics import classification_report, confusion_matrix from torch.utils.data import DataLoader from tqdm import tqdm import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) from training.Policy.policy_model import PolicyModel from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn from training.Policy.warm_start_trainer import compute_policy_score logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("Policy.paper_eval") ACTION_NAMES = {0: "SILENT", 1: "OBSERVE", 2: "ALERT"} @torch.no_grad() def evaluate_model( model: PolicyModel, loader: DataLoader, device: torch.device, merge_observe: str = None, # "alert" | "silent" | None ) -> dict: """Full evaluation: per-class metrics + policy score.""" model.eval() all_true, all_pred = [], [] all_probs = [] categories, tta_raws = [], [] for batch in loader: if "beliefs" in batch: logits = model.forward_cached( batch["beliefs"].to(device), batch["tta_means"].to(device), batch["tta_vars"].to(device), ) else: logits = model(batch["images"], batch["metadata"]) probs = F.softmax(logits, dim=-1).cpu().numpy() preds = logits.argmax(dim=-1).cpu().numpy() trues = batch["action_labels"].numpy() # Optional: merge OBSERVE at eval time too if merge_observe == "alert": trues = np.where(trues == 1, 2, trues) preds = np.where(preds == 1, 2, preds) elif merge_observe == "silent": trues = np.where(trues == 1, 0, trues) preds = np.where(preds == 1, 0, preds) all_true.extend(trues.tolist()) all_pred.extend(preds.tolist()) all_probs.extend(probs.tolist()) categories.extend(batch["categories"]) tta_raws.extend(batch["tta_raws"].tolist()) all_true = np.array(all_true) all_pred = np.array(all_pred) # ── sklearn classification report ───────────────────────────────────────── present = sorted(set(all_true.tolist()) | set(all_pred.tolist())) names = [ACTION_NAMES.get(i, str(i)) for i in present] report = classification_report(all_true, all_pred, labels=present, target_names=names, output_dict=True, zero_division=0) # ── policy score ────────────────────────────────────────────────────────── cats = np.array(categories) def _ratio(num, den): return float(num / den) if den > 0 else 0.0 # ego_positive: ALERT recall (label==2, pred==2) ego_mask = cats == "ego_positive" ego_preds = all_pred[ego_mask] ego_trues = all_true[ego_mask] alert_true = ego_preds[ego_trues == 2] ego_alert_recall = _ratio((alert_true == 2).sum(), len(alert_true)) # non_ego: fraction predicted NOT ALERT ne_mask = cats == "non_ego" ne_preds = all_pred[ne_mask] non_ego_noalert_rate = _ratio((ne_preds != 2).sum(), len(ne_preds)) # safe_neg: fraction predicted SILENT sn_mask = cats == "safe_neg" sn_preds = all_pred[sn_mask] safe_neg_silent_rate = _ratio((sn_preds == 0).sum(), len(sn_preds)) safe_neg_alert_leak = _ratio((sn_preds == 2).sum(), len(sn_preds)) # PolicyScore v3 (safety-first): 0.65/0.25/0.15 on ego_recall / safe_silent / -safe_alert policy_score = compute_policy_score( ego_alert_recall = ego_alert_recall, safe_neg_silent_rate = safe_neg_silent_rate, safe_neg_alert_rate = safe_neg_alert_leak, ) # ── OBSERVE lead time ───────────────────────────────────────────────────── tta_arr = np.array(tta_raws) observe_by_tta = {} for lo in range(0, 10): mask = ego_mask & (tta_arr >= lo) & (tta_arr < lo + 1) if mask.sum() > 0: observe_by_tta[f"{lo}-{lo+1}s"] = float(np.mean(all_pred[mask] == 1)) return { "classification_report": report, "policy_score": policy_score, "ego_alert_recall": ego_alert_recall, "non_ego_noalert": non_ego_noalert_rate, "safe_neg_silent": safe_neg_silent_rate, "safe_neg_alert_leak": safe_neg_alert_leak, "observe_by_tta": observe_by_tta, "n_samples": len(all_true), "label_dist": {ACTION_NAMES[k]: int((all_pred == k).sum()) for k in range(3)}, } def format_latex_table1(results: Dict[str, dict]) -> str: """Per-class precision/recall/F1 comparison table.""" lines = [ r"\begin{table}[h]", r"\centering", r"\caption{Per-class Classification Performance (val set)}", r"\begin{tabular}{l|ccc|ccc|ccc}", r"\hline", r"Model & \multicolumn{3}{c|}{SILENT} & \multicolumn{3}{c|}{OBSERVE} & \multicolumn{3}{c}{ALERT} \\", r" & P & R & F1 & P & R & F1 & P & R & F1 \\", r"\hline", ] for name, m in results.items(): rpt = m["classification_report"] row = [name.replace("_", r"\_")] for cls in ["SILENT", "OBSERVE", "ALERT"]: if cls in rpt: row += [f'{rpt[cls]["precision"]:.3f}', f'{rpt[cls]["recall"]:.3f}', f'{rpt[cls]["f1-score"]:.3f}'] else: row += ["—", "—", "—"] lines.append(" & ".join(row) + r" \\") lines += [r"\hline", r"\end{tabular}", r"\end{table}"] return "\n".join(lines) def format_latex_table2(results: Dict[str, dict]) -> str: """Policy score decomposition table.""" lines = [ r"\begin{table}[h]", r"\centering", r"\caption{Policy Score Decomposition}", r"\begin{tabular}{l|cccc|c}", r"\hline", r"Model & Ego-Alert↑ & Non-Ego-NoAlert↑ & Safe-Silent↑ & False-Alarm↓ & Policy Score↑ \\", r"\hline", ] for name, m in results.items(): lines.append( f"{name.replace('_', chr(92)+'_')} & " f'{m["ego_alert_recall"]:.3f} & ' f'{m["non_ego_noalert"]:.3f} & ' f'{m["safe_neg_silent"]:.3f} & ' f'{m["safe_neg_alert_leak"]:.3f} & ' f'\\textbf{{{m["policy_score"]:.3f}}} \\\\' ) lines += [r"\hline", r"\end{tabular}", r"\end{table}"] return "\n".join(lines) def main(): parser = argparse.ArgumentParser("paper_eval") parser.add_argument("--sft_checkpoint", required=True) parser.add_argument("--label_dir", default="data/policy_labels") parser.add_argument("--belief_cache_dir", default=None) parser.add_argument("--split", default="val") parser.add_argument("--models", nargs="+", required=True, help="Pairs of: ...") parser.add_argument("--output_dir", default="eval_results/paper_tables") args = parser.parse_args() if len(args.models) % 2 != 0: raise ValueError("--models must be pairs of ") model_pairs = [(args.models[i], args.models[i+1]) for i in range(0, len(args.models), 2)] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) # ── shared data loader ──────────────────────────────────────────────────── cache_path = None if args.belief_cache_dir: p = Path(args.belief_cache_dir) / f"{args.split}.pt" if p.exists(): cache_path = p ds = PolicyDataset( manifests=[Path(args.label_dir) / f"{args.split}.json"], split=args.split, belief_cache_path=cache_path, ) loader = DataLoader(ds, batch_size=512, shuffle=False, num_workers=4, collate_fn=policy_collate_fn) logger.info(f"Val set: {len(ds)} samples") # ── evaluate each model ─────────────────────────────────────────────────── all_results: Dict[str, dict] = {} for model_name, ckpt_dir in model_pairs: logger.info(f"\n{'='*50}") logger.info(f"Evaluating: {model_name}") logger.info(f" Checkpoint: {ckpt_dir}") model = PolicyModel(args.sft_checkpoint, use_bf16=True) model.load_policy_checkpoint(ckpt_dir) merge = "alert" if "binary" in model_name else None m = evaluate_model(model, loader, device, merge_observe=merge) all_results[model_name] = m logger.info(f" policy_score={m['policy_score']:.4f} " f"recall={m['ego_alert_recall']:.4f} " f"FA={m['safe_neg_alert_leak']:.4f}") # Free GPU memory del model torch.cuda.empty_cache() # ── print tables ────────────────────────────────────────────────────────── print("\n" + "="*60) print("TABLE 1: Per-class classification metrics") print("="*60) # ASCII version header = f"{'Model':<30} {'':>3} {'SILENT':>8} {'':>3} {'OBSERVE':>9} {'':>3} {'ALERT':>7}" print(header) print(f"{'':>30} P R F1 P R F1 P R F1") for name, m in all_results.items(): rpt = m["classification_report"] row = f"{name:<30}" for cls in ["SILENT", "OBSERVE", "ALERT"]: if cls in rpt: row += f" {rpt[cls]['precision']:.2f} {rpt[cls]['recall']:.2f} {rpt[cls]['f1-score']:.2f}" else: row += " — — — " print(row) print("\n" + "="*60) print("TABLE 2: Policy score decomposition") print("="*60) print(f"{'Model':<30} {'Ego↑':>6} {'NonEgo↑':>8} {'Silent↑':>8} {'FA↓':>6} {'Score↑':>7}") for name, m in all_results.items(): print(f"{name:<30} {m['ego_alert_recall']:>6.3f} {m['non_ego_noalert']:>8.3f} " f"{m['safe_neg_silent']:>8.3f} {m['safe_neg_alert_leak']:>6.3f} " f"{m['policy_score']:>7.4f}") # ── save JSON and LaTeX ─────────────────────────────────────────────────── with open(out_dir / "all_results.json", "w") as f: json.dump({k: {kk: vv for kk, vv in v.items() if kk != "classification_report"} for k, v in all_results.items()}, f, indent=2) latex1 = format_latex_table1(all_results) latex2 = format_latex_table2(all_results) (out_dir / "table_per_class.tex").write_text(latex1) (out_dir / "table_policy_score.tex").write_text(latex2) logger.info(f"\n✅ Results saved → {out_dir}/") logger.info(f" JSON : all_results.json") logger.info(f" LaTeX: table_per_class.tex, table_policy_score.tex") if __name__ == "__main__": main()