#!/usr/bin/env python3 """ Standalone evaluation script for SFT checkpoints. Usage ----- python -m training.SFT.evaluate \ --checkpoint /path/to/checkpoints/SFT/sft_v2/best \ --manifest_dir /path/to/data/sft_manifests \ [--split val] [--batch_size 4] [--output_json results.json] """ from __future__ import annotations import argparse import json import logging from pathlib import Path from typing import Dict, List import numpy as np import torch from torch.amp import autocast from torch.utils.data import DataLoader from tqdm import tqdm from .dataset import SFTDataset, sft_collate_fn from .trainer import SFTModel, compute_sft_loss, load_sft_heads, _is_sft_ckpt_dir logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("SFT.evaluate") def _build_prompt(metadata: dict) -> str: parts = [] if metadata.get("weather"): parts.append(f"Weather: {metadata['weather']}") if metadata.get("road_type"): parts.append(f"Road: {metadata['road_type']}") if metadata.get("time_of_day"): parts.append(f"Time: {metadata['time_of_day']}") ctx = ", ".join(parts) or "Urban driving" return ( f"Analyze this driving sequence.\n" f"Context: {ctx}\n" f"Estimate the time to potential collision. Output a single number in seconds." ) def evaluate_checkpoint( model: SFTModel, loader: DataLoader, amp_dtype: torch.dtype = torch.bfloat16, nll_weight: float = 0.5, ) -> Dict[str, float]: model.eval() SYSTEM = "You are a driving safety AI analyzing dashcam footage for collision risk." total_loss = 0.0 n = 0 all_hazard_prob: List[np.ndarray] = [] all_hazard_label: List[np.ndarray] = [] all_is_ego_pos: List[np.ndarray] = [] all_is_noneego: List[np.ndarray] = [] all_tta_pred: List[np.ndarray] = [] all_tta_label: List[np.ndarray] = [] all_tta_std: List[np.ndarray] = [] all_is_censored: List[np.ndarray] = [] proc = model.processor apply_chat = ( proc.apply_chat_template if hasattr(proc, "apply_chat_template") else proc.tokenizer.apply_chat_template ) with torch.no_grad(): for batch in tqdm(loader, desc="Evaluating", ncols=70): images = batch["images"] texts = [] for i in range(len(batch["video_ids"])): frames = images[i] content = [{"type": "image"} for _ in range(len(frames))] content.append({"type": "text", "text": _build_prompt(batch["metadata"][i])}) msgs = [{"role": "system", "content": SYSTEM}, {"role": "user", "content": content}] texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False)) inputs = proc(text=texts, images=images, return_tensors="pt", padding=True, truncation=True) dev = model.device t = { "tta_labels": batch["tta_labels"].to(dev), "hazard_labels": batch["hazard_labels"].to(dev), "hazard_weights": batch["hazard_weights"].to(dev), "is_ego_positive": batch["is_ego_positive"].to(dev), "is_censored": batch["is_censored"].to(dev), } is_noneego = batch.get("is_non_ego", torch.zeros(len(batch["video_ids"]), dtype=torch.bool)) with autocast(device_type="cuda", dtype=amp_dtype, enabled=True): out = model(inputs) loss, _ = compute_sft_loss( hazard_logit=out["hazard_logit"], tta_mean=out["tta_mean"], tta_logvar=out["tta_logvar"], hazard_label=t["hazard_labels"], hazard_weight=t["hazard_weights"], is_ego_positive=t["is_ego_positive"], is_censored=t["is_censored"], tta_label=t["tta_labels"], nll_weight=nll_weight, ) total_loss += float(loss.item()) n += 1 all_hazard_prob.append(out["hazard_prob"].detach().float().cpu().numpy()) all_hazard_label.append(t["hazard_labels"].detach().float().cpu().numpy()) all_is_ego_pos.append(t["is_ego_positive"].cpu().numpy()) all_is_noneego.append(is_noneego.cpu().numpy()) all_tta_pred.append(out["tta_mean"].detach().float().cpu().numpy()) all_tta_label.append(t["tta_labels"].detach().float().cpu().numpy()) all_tta_std.append(torch.exp(0.5 * out["tta_logvar"].detach().float()).cpu().numpy()) all_is_censored.append(t["is_censored"].cpu().numpy()) def cat(lst, dtype=np.float32): return np.concatenate(lst).astype(dtype) if lst else np.array([], dtype=dtype) hp_all = cat(all_hazard_prob) hl_all = cat(all_hazard_label) ep_all = cat(all_is_ego_pos, bool) ne_all = cat(all_is_noneego, bool) pred_all = cat(all_tta_pred) lbl_all = cat(all_tta_label) std_all = cat(all_tta_std) cen_all = cat(all_is_censored, bool) # ── hazard metrics ────────────────────────────────────────────────────── hp_bin = (hp_all > 0.5).astype(np.float32) tp = float(((hp_bin == 1) & (hl_all == 1)).sum()) fp = float(((hp_bin == 1) & (hl_all == 0)).sum()) fn = float(((hp_bin == 0) & (hl_all == 1)).sum()) prec = tp / max(1, tp + fp) recall = tp / max(1, tp + fn) f1 = 2 * prec * recall / max(1e-9, prec + recall) ne_mask = ne_all.astype(bool) safe_neg_mask = (~ep_all) & (~ne_mask) ne_far = float((hp_bin[ne_mask] == 1).mean()) if ne_mask.any() else 0.0 sneg_fa = float((hp_bin[safe_neg_mask] == 1).mean()) if safe_neg_mask.any() else 0.0 # ── TTA metrics (positive-observed only) ──────────────────────────────── obs_mask = ep_all & (~cen_all) if obs_mask.any(): pos_preds = pred_all[obs_mask] pos_labels = lbl_all[obs_mask] pos_mae = float(np.abs(pos_preds - pos_labels).mean()) pos_rmse = float(np.sqrt(((pos_preds - pos_labels) ** 2).mean())) low_mask = pos_labels <= 3.0 low_mae = float(np.abs(pos_preds[low_mask] - pos_labels[low_mask]).mean()) if low_mask.any() else 0.0 denom = float(((pos_labels - pos_labels.mean()) ** 2).sum()) + 1e-12 pos_r2 = float(1.0 - ((pos_preds - pos_labels) ** 2).sum() / denom) else: pos_mae = pos_rmse = low_mae = 10.0 pos_r2 = 0.0 ckpt_score = 0.6 * f1 - 0.4 * (pos_mae / 10.0) metrics = { "loss": total_loss / max(1, n), "hazard_f1": f1, "hazard_precision": prec, "hazard_recall": recall, "hazard_tp": int(tp), "hazard_fp": int(fp), "hazard_fn": int(fn), "pos_tta_mae": pos_mae, "pos_tta_rmse": pos_rmse, "pos_tta_r2": pos_r2, "low_tta_mae": low_mae, "non_ego_false_alert": ne_far, "safe_neg_false_alert": sneg_fa, "uncertainty_mean": float(std_all.mean()) if std_all.size else 0.0, "ckpt_score": ckpt_score, "n_total": int(hp_all.size), "n_ego_pos": int(ep_all.sum()), "n_non_ego": int(ne_all.sum()), "n_safe_neg": int(safe_neg_mask.sum()), "n_obs": int(obs_mask.sum()), "n_censored": int(cen_all[ep_all].sum()) if ep_all.any() else 0, } logger.info( f" hazard_f1={f1:.3f} prec={prec:.3f} recall={recall:.3f}\n" f" pos_tta_mae={pos_mae:.3f} low_tta_mae={low_mae:.3f} pos_r2={pos_r2:.3f}\n" f" non_ego_fa={ne_far:.3f} safe_neg_fa={sneg_fa:.3f}\n" f" ckpt_score={ckpt_score:.4f} loss={metrics['loss']:.4f}" ) return metrics def main(): parser = argparse.ArgumentParser("SFT checkpoint evaluation") parser.add_argument("--checkpoint", type=str, required=True, help="Path to SFT checkpoint dir") parser.add_argument("--manifest_dir", type=str, default="PROJECT_ROOT/data/sft_manifests") parser.add_argument("--split", type=str, default="val", choices=["val", "train", "test_public"]) parser.add_argument("--model_name", type=str, default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--output_json", type=str, default=None) args = parser.parse_args() ckpt_dir = Path(args.checkpoint) if not _is_sft_ckpt_dir(ckpt_dir): raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt_dir}") manifest_dir = Path(args.manifest_dir) if args.split == "val": manifests = [ manifest_dir / "nexar_val.json", manifest_dir / "dada_pos_val.json", manifest_dir / "dada_noneego_val.json", ] elif args.split == "train": manifests = [ manifest_dir / "nexar_train.json", manifest_dir / "dada_pos_train.json", manifest_dir / "dada_noneego_train.json", manifest_dir / "dada_neg_train.json", ] else: manifests = [manifest_dir / "nexar_test_public.json"] manifests = [m for m in manifests if m.exists()] if not manifests: raise RuntimeError(f"No manifests found for split '{args.split}' in {manifest_dir}") logger.info(f"Manifests: {[m.name for m in manifests]}") dataset = SFTDataset( manifests=manifests, split="val" if args.split != "train" else "train", ) loader = DataLoader( dataset, batch_size=args.batch_size, shuffle=False, collate_fn=sft_collate_fn, num_workers=4, pin_memory=True, ) with open(ckpt_dir / "config.json") as f: cfg = json.load(f) model_name = cfg.get("model_name", args.model_name) logger.info(f"Loading model: {model_name}") model = SFTModel( model_name=model_name, pretrained_lora_path=str(ckpt_dir / "vlm_lora"), belief_strategy=cfg.get("belief_strategy", "mean_pool"), tta_intermediate_dim=cfg.get("tta_intermediate_dim", 512), use_lora=True, use_bf16=True, device="auto", ) load_sft_heads(model, ckpt_dir) logger.info(f"Evaluating {ckpt_dir.name} split={args.split} n={len(dataset)}") metrics = evaluate_checkpoint(model, loader) print("\n=== Evaluation Results ===") for k, v in metrics.items(): if isinstance(v, float): print(f" {k:30s} {v:.4f}") else: print(f" {k:30s} {v}") if args.output_json: out_path = Path(args.output_json) out_path.parent.mkdir(parents=True, exist_ok=True) with open(out_path, "w") as f: json.dump({"checkpoint": str(ckpt_dir), "split": args.split, "metrics": metrics}, f, indent=2) logger.info(f"Results written to {out_path}") if __name__ == "__main__": main()