| |
| """ |
| Paper Evaluation Script — 生成论文所需的全部指标表格 |
| |
| 对比以下模型(所有都在同一 val set 上评估): |
| 1. LKAlert-Binary (obs→alert) ← Ablation baseline |
| 2. LKAlert-v2 (focal 0.1/0.3/0.6, 43% FA) |
| 3. LKAlert-v3 (focal 0.2/0.3/0.5, fixed FA) ← 主模型 |
| |
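输出 (LaTeX + JSON, written to --output_dir):
    Table 1: Per-class metrics (precision, recall, F1)  -> table_per_class.tex
    Table 2: Policy score decomposition                 -> table_policy_score.tex
    Table 3: OBSERVE lead-time advantage                -> JSON only (`observe_by_tta` in all_results.json; no LaTeX yet)
    Table 4: False alarm vs recall tradeoff             -> JSON only (`ego_alert_recall` / `safe_neg_alert_leak`; no LaTeX yet)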
| |
| 使用方法: |
| python -m training.Policy.paper_eval \ |
| --sft_checkpoint checkpoints/SFT/sft_v2/best \ |
| --label_dir data/policy_labels \ |
| --belief_cache_dir data/belief_cache \ |
| --models policy_warmstart_v2 checkpoints/Policy/policy_warmstart_v2/best \ |
| policy_warmstart_v3 checkpoints/Policy/policy_warmstart_v3/best \ |
| binary_obs2alert checkpoints/Policy/policy_binary_obs2alert/best \ |
| --output_dir eval_results/paper_tables |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from sklearn.metrics import classification_report, confusion_matrix |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| import sys |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) |
|
|
| from training.Policy.policy_model import PolicyModel |
| from training.Policy.policy_dataset import PolicyDataset, policy_collate_fn |
| from training.Policy.warm_start_trainer import compute_policy_score |
|
|
# Root logging configuration for this script: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("Policy.paper_eval")

# Integer action id -> display name used in reports and tables (0, 1, 2 in order).
ACTION_NAMES = dict(enumerate(("SILENT", "OBSERVE", "ALERT")))
|
|
|
|
def _fold_observe(labels: np.ndarray, merge_observe: str | None) -> np.ndarray:
    """Map OBSERVE (1) onto ALERT (2) or SILENT (0) per *merge_observe*; None returns *labels* unchanged."""
    if merge_observe == "alert":
        return np.where(labels == 1, 2, labels)
    if merge_observe == "silent":
        return np.where(labels == 1, 0, labels)
    return labels


def _ratio(num, den) -> float:
    """Return ``num / den`` as a float, or 0.0 when ``den`` is 0 (empty subset)."""
    return float(num / den) if den > 0 else 0.0


@torch.no_grad()
def evaluate_model(
    model: PolicyModel,
    loader: DataLoader,
    device: torch.device,
    merge_observe: str | None = None,
) -> dict:
    """Run one policy model over *loader* and compute every paper metric.

    Args:
        model: policy model. ``model.forward_cached(beliefs, tta_means, tta_vars)``
            is used when a batch carries cached beliefs, otherwise
            ``model(images, metadata)``.
        loader: yields dict batches with ``action_labels``, ``categories`` and
            ``tta_raws``, plus either ``beliefs``/``tta_means``/``tta_vars`` or
            ``images``/``metadata``.
        device: device the cached-belief tensors are moved to.
        merge_observe: ``"alert"`` folds OBSERVE into ALERT, ``"silent"`` folds it
            into SILENT (applied to labels and predictions alike); ``None`` keeps
            all three classes.

    Returns:
        dict with:
          - ``classification_report``: sklearn ``output_dict`` report restricted to
            classes that occur in labels or predictions.
          - ``policy_score``: ``compute_policy_score`` of the three rates below.
          - ``ego_alert_recall``: fraction of ``ego_positive`` samples labelled
            ALERT that are predicted ALERT.
          - ``non_ego_noalert``: fraction of ``non_ego`` samples not predicted ALERT.
          - ``safe_neg_silent`` / ``safe_neg_alert_leak``: fraction of ``safe_neg``
            samples predicted SILENT / ALERT.
          - ``observe_by_tta``: OBSERVE prediction rate of ``ego_positive`` samples
            per 1-second ``tta_raws`` bin over [0, 10).
          - ``n_samples``: number of evaluated samples.
          - ``label_dist``: counts of *predicted* classes (not ground truth).

    Raises:
        ValueError: if *merge_observe* is not None/"alert"/"silent", or the
            loader yields no samples.
    """
    if merge_observe not in (None, "alert", "silent"):
        raise ValueError(
            f"merge_observe must be None, 'alert' or 'silent', got {merge_observe!r}"
        )
    model.eval()

    all_true: List[int] = []
    all_pred: List[int] = []
    categories: List[str] = []
    tta_raws: List[float] = []

    for batch in loader:
        if "beliefs" in batch:
            logits = model.forward_cached(
                batch["beliefs"].to(device),
                batch["tta_means"].to(device),
                batch["tta_vars"].to(device),
            )
        else:
            # NOTE(review): images/metadata are passed as-is (not moved to
            # `device`); presumably PolicyModel handles placement — confirm.
            logits = model(batch["images"], batch["metadata"])

        # Only the predicted class is needed. No softmax/probabilities are
        # materialised: they were never used, and `.numpy()` raises on
        # bfloat16 tensors (numpy has no bf16 dtype).
        preds = logits.argmax(dim=-1).cpu().numpy()
        trues = batch["action_labels"].cpu().numpy()

        # Fold labels and predictions identically so a merged (binary) model
        # is scored on the same 2-class task it was trained for.
        all_true.extend(_fold_observe(trues, merge_observe).tolist())
        all_pred.extend(_fold_observe(preds, merge_observe).tolist())
        categories.extend(batch["categories"])
        tta_raws.extend(batch["tta_raws"].tolist())

    if not all_true:
        raise ValueError("evaluate_model: loader yielded no samples")

    true_arr = np.asarray(all_true)
    pred_arr = np.asarray(all_pred)
    cats = np.asarray(categories)
    tta_arr = np.asarray(tta_raws)

    # Report only classes that actually occur (a merged run has no OBSERVE).
    present = sorted(set(all_true) | set(all_pred))
    names = [ACTION_NAMES.get(i, str(i)) for i in present]
    report = classification_report(true_arr, pred_arr,
                                   labels=present, target_names=names,
                                   output_dict=True, zero_division=0)

    # ego_positive: ALERT recall among samples whose label is ALERT.
    ego_mask = cats == "ego_positive"
    ego_pred_on_alert = pred_arr[ego_mask & (true_arr == 2)]
    ego_alert_recall = _ratio((ego_pred_on_alert == 2).sum(), len(ego_pred_on_alert))

    # non_ego: any prediction other than ALERT counts as correct.
    ne_preds = pred_arr[cats == "non_ego"]
    non_ego_noalert_rate = _ratio((ne_preds != 2).sum(), len(ne_preds))

    # safe_neg: SILENT is the target; ALERT is the false-alarm leak.
    sn_preds = pred_arr[cats == "safe_neg"]
    safe_neg_silent_rate = _ratio((sn_preds == 0).sum(), len(sn_preds))
    safe_neg_alert_leak = _ratio((sn_preds == 2).sum(), len(sn_preds))

    policy_score = compute_policy_score(
        ego_alert_recall=ego_alert_recall,
        safe_neg_silent_rate=safe_neg_silent_rate,
        safe_neg_alert_rate=safe_neg_alert_leak,
    )

    # OBSERVE rate of ego_positive samples, bucketed into 1 s tta_raws bins.
    observe_by_tta = {}
    for lo in range(10):
        mask = ego_mask & (tta_arr >= lo) & (tta_arr < lo + 1)
        if mask.any():
            observe_by_tta[f"{lo}-{lo+1}s"] = float(np.mean(pred_arr[mask] == 1))

    return {
        "classification_report": report,
        "policy_score": policy_score,
        "ego_alert_recall": ego_alert_recall,
        "non_ego_noalert": non_ego_noalert_rate,
        "safe_neg_silent": safe_neg_silent_rate,
        "safe_neg_alert_leak": safe_neg_alert_leak,
        "observe_by_tta": observe_by_tta,
        "n_samples": len(true_arr),
        # Historical key name: these are counts of *predicted* classes.
        "label_dist": {ACTION_NAMES[k]: int((pred_arr == k).sum())
                       for k in range(3)},
    }
|
|
|
|
def format_latex_table1(results: Dict[str, dict]) -> str:
    """Render Table 1: per-class precision / recall / F1 for every model.

    Args:
        results: model name -> metrics dict. Only ``m["classification_report"]``
            is read — an sklearn ``output_dict`` report keyed by class name.

    Returns:
        A complete LaTeX ``table`` environment as one newline-joined string.
        A class missing from a model's report is rendered as three "—" cells.
    """
    preamble = [
        r"\begin{table}[h]",
        r"\centering",
        r"\caption{Per-class Classification Performance (val set)}",
        r"\begin{tabular}{l|ccc|ccc|ccc}",
        r"\hline",
        r"Model & \multicolumn{3}{c|}{SILENT} & \multicolumn{3}{c|}{OBSERVE} & \multicolumn{3}{c}{ALERT} \\",
        r" & P & R & F1 & P & R & F1 & P & R & F1 \\",
        r"\hline",
    ]
    closing = [r"\hline", r"\end{tabular}", r"\end{table}"]

    def class_cells(report: dict, cls: str) -> List[str]:
        # Placeholders when the class never appeared in labels or predictions.
        if cls not in report:
            return ["—"] * 3
        stats = report[cls]
        return [format(stats[key], ".3f") for key in ("precision", "recall", "f1-score")]

    body = []
    for model_name, metrics in results.items():
        report = metrics["classification_report"]
        cells = [model_name.replace("_", "\\_")]
        for cls in ("SILENT", "OBSERVE", "ALERT"):
            cells.extend(class_cells(report, cls))
        body.append(" & ".join(cells) + r" \\")

    return "\n".join(preamble + body + closing)
|
|
|
|
def format_latex_table2(results: Dict[str, dict]) -> str:
    r"""Render Table 2: policy-score decomposition for every model.

    Args:
        results: model name -> metrics dict; reads ``ego_alert_recall``,
            ``non_ego_noalert``, ``safe_neg_silent``, ``safe_neg_alert_leak`` and
            ``policy_score``.

    Returns:
        A complete LaTeX ``table`` environment as one newline-joined string, with
        the policy score in bold.

    Direction markers are emitted as math-mode ``$\uparrow$`` / ``$\downarrow$``
    rather than raw Unicode arrows, which pdfLaTeX rejects ("Unicode character
    ... not set up for use with LaTeX").
    """
    lines = [
        r"\begin{table}[h]",
        r"\centering",
        r"\caption{Policy Score Decomposition}",
        r"\begin{tabular}{l|cccc|c}",
        r"\hline",
        r"Model & Ego-Alert$\uparrow$ & Non-Ego-NoAlert$\uparrow$ & Safe-Silent$\uparrow$ & False-Alarm$\downarrow$ & Policy Score$\uparrow$ \\",
        r"\hline",
    ]
    for name, m in results.items():
        cells = [
            name.replace("_", r"\_"),  # "_" is a math-subscript char in LaTeX
            f'{m["ego_alert_recall"]:.3f}',
            f'{m["non_ego_noalert"]:.3f}',
            f'{m["safe_neg_silent"]:.3f}',
            f'{m["safe_neg_alert_leak"]:.3f}',
            rf'\textbf{{{m["policy_score"]:.3f}}}',
        ]
        lines.append(" & ".join(cells) + r" \\")
    lines += [r"\hline", r"\end{tabular}", r"\end{table}"]
    return "\n".join(lines)
|
|
|
|
def _resolve_belief_cache(belief_cache_dir: str | None, split: str) -> Path | None:
    """Return ``<belief_cache_dir>/<split>.pt`` if it exists, else None.

    Logs a warning when a cache directory was requested but the split file is
    missing, so falling back to uncached evaluation is never silent.
    """
    if not belief_cache_dir:
        return None
    cache_path = Path(belief_cache_dir) / f"{split}.pt"
    if cache_path.exists():
        return cache_path
    logger.warning(f"Belief cache {cache_path} not found; evaluating without it")
    return None


def _print_console_tables(all_results: Dict[str, dict]) -> None:
    """Print plain-text versions of Table 1 (per-class) and Table 2 (policy score)."""
    print("\n" + "="*60)
    print("TABLE 1: Per-class classification metrics")
    print("="*60)
    header = f"{'Model':<30} {'':>3} {'SILENT':>8} {'':>3} {'OBSERVE':>9} {'':>3} {'ALERT':>7}"
    print(header)
    print(f"{'':>30} P R F1 P R F1 P R F1")
    for name, m in all_results.items():
        rpt = m["classification_report"]
        row = f"{name:<30}"
        for cls in ["SILENT", "OBSERVE", "ALERT"]:
            if cls in rpt:
                row += f" {rpt[cls]['precision']:.2f} {rpt[cls]['recall']:.2f} {rpt[cls]['f1-score']:.2f}"
            else:
                row += " — — — "
        print(row)

    print("\n" + "="*60)
    print("TABLE 2: Policy score decomposition")
    print("="*60)
    print(f"{'Model':<30} {'Ego↑':>6} {'NonEgo↑':>8} {'Silent↑':>8} {'FA↓':>6} {'Score↑':>7}")
    for name, m in all_results.items():
        print(f"{name:<30} {m['ego_alert_recall']:>6.3f} {m['non_ego_noalert']:>8.3f} "
              f"{m['safe_neg_silent']:>8.3f} {m['safe_neg_alert_leak']:>6.3f} "
              f"{m['policy_score']:>7.4f}")


def _save_outputs(all_results: Dict[str, dict], out_dir: Path) -> None:
    """Write all_results.json and both LaTeX tables into *out_dir* (UTF-8).

    The JSON omits each model's ``classification_report``; per-class numbers
    are only exported through table_per_class.tex.
    """
    summary = {name: {k: v for k, v in m.items() if k != "classification_report"}
               for name, m in all_results.items()}
    # Explicit UTF-8: the tables contain non-ASCII characters (e.g. "—"), and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(out_dir / "all_results.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    (out_dir / "table_per_class.tex").write_text(
        format_latex_table1(all_results), encoding="utf-8")
    (out_dir / "table_policy_score.tex").write_text(
        format_latex_table2(all_results), encoding="utf-8")

    logger.info(f"\n✅ Results saved → {out_dir}/")
    logger.info(" JSON : all_results.json")
    logger.info(" LaTeX: table_per_class.tex, table_policy_score.tex")


def main():
    """CLI entry point: evaluate each ``<name> <checkpoint_dir>`` pair on one split.

    Prints Tables 1-2 to stdout and writes JSON + LaTeX outputs to --output_dir.

    Raises:
        ValueError: if --models does not contain an even number of entries.
    """
    parser = argparse.ArgumentParser("paper_eval")
    parser.add_argument("--sft_checkpoint", required=True)
    parser.add_argument("--label_dir", default="data/policy_labels")
    parser.add_argument("--belief_cache_dir", default=None)
    parser.add_argument("--split", default="val")
    parser.add_argument("--models", nargs="+", required=True,
                        help="Pairs of: <name> <checkpoint_dir> ...")
    parser.add_argument("--output_dir", default="eval_results/paper_tables")
    args = parser.parse_args()

    if len(args.models) % 2 != 0:
        raise ValueError("--models must be pairs of <name> <checkpoint_dir>")
    model_pairs = list(zip(args.models[0::2], args.models[1::2]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    ds = PolicyDataset(
        manifests=[Path(args.label_dir) / f"{args.split}.json"],
        split=args.split,
        belief_cache_path=_resolve_belief_cache(args.belief_cache_dir, args.split),
    )
    loader = DataLoader(ds, batch_size=512, shuffle=False,
                        num_workers=4, collate_fn=policy_collate_fn)
    logger.info(f"Val set: {len(ds)} samples")

    all_results: Dict[str, dict] = {}
    for model_name, ckpt_dir in model_pairs:
        logger.info(f"\n{'='*50}")
        logger.info(f"Evaluating: {model_name}")
        logger.info(f" Checkpoint: {ckpt_dir}")

        # NOTE(review): the model is never moved to `device` here, while the
        # cached-belief inputs are; presumably PolicyModel places itself on
        # load — confirm, otherwise forward_cached sees mixed devices.
        model = PolicyModel(args.sft_checkpoint, use_bf16=True)
        model.load_policy_checkpoint(ckpt_dir)

        # Models whose name contains "binary" never predict OBSERVE, so score
        # them with OBSERVE folded into ALERT.
        merge = "alert" if "binary" in model_name else None
        m = evaluate_model(model, loader, device, merge_observe=merge)
        all_results[model_name] = m

        logger.info(f" policy_score={m['policy_score']:.4f} "
                    f"recall={m['ego_alert_recall']:.4f} "
                    f"FA={m['safe_neg_alert_leak']:.4f}")

        # Release this model before loading the next checkpoint.
        del model
        torch.cuda.empty_cache()

    _print_console_tables(all_results)
    _save_outputs(all_results, out_dir)
|
|
|
|
# Run the CLI only when executed as a script (or via `python -m`), not on import.
if __name__ == "__main__":
    main()
|
|