| |
| """ |
| Standalone evaluation script for SFT checkpoints. |
| |
| Usage |
| ----- |
| python -m training.SFT.evaluate \ |
| --checkpoint /path/to/checkpoints/SFT/sft_v2/best \ |
| --manifest_dir /path/to/data/sft_manifests \ |
| [--split val] [--batch_size 4] [--output_json results.json] |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import numpy as np |
| import torch |
| from torch.amp import autocast |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from .dataset import SFTDataset, sft_collate_fn |
| from .trainer import SFTModel, compute_sft_loss, load_sft_heads, _is_sft_ckpt_dir |
|
|
# Configure the root logger when this module is imported: INFO level with a
# "timestamp level message" line format. All messages emitted by this script
# go through the named logger below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("SFT.evaluate")
|
|
|
|
def _build_prompt(metadata: dict) -> str:
    """Compose the text part of the user prompt for one driving clip.

    The ``weather``, ``road_type`` and ``time_of_day`` entries of *metadata*
    are included (in that order) only when present and truthy. When none of
    them is available the context line falls back to ``"Urban driving"``.

    Args:
        metadata: Per-sample metadata mapping; only the three keys above are read.

    Returns:
        A three-line prompt: instruction, context line, and the TTA question.
    """
    weather = metadata.get("weather")
    road = metadata.get("road_type")
    time_of_day = metadata.get("time_of_day")

    # (value, rendered fragment) pairs, kept in the order they appear in the prompt.
    candidates = (
        (weather, f"Weather: {weather}"),
        (road, f"Road: {road}"),
        (time_of_day, f"Time: {time_of_day}"),
    )
    described = [fragment for value, fragment in candidates if value]
    context = ", ".join(described) if described else "Urban driving"

    header = "Analyze this driving sequence.\n"
    context_line = f"Context: {context}\n"
    question = "Estimate the time to potential collision. Output a single number in seconds."
    return header + context_line + question
|
|
|
|
def _concat_batches(chunks: List[np.ndarray], dtype=np.float32) -> np.ndarray:
    """Join per-batch arrays into one array of *dtype* (empty array if no batches)."""
    if not chunks:
        return np.array([], dtype=dtype)
    return np.concatenate(chunks).astype(dtype)


def _summarize_predictions(
    hazard_prob: np.ndarray,
    hazard_label: np.ndarray,
    is_ego_pos: np.ndarray,
    is_non_ego: np.ndarray,
    tta_pred: np.ndarray,
    tta_label: np.ndarray,
    tta_std: np.ndarray,
    is_censored: np.ndarray,
    mean_loss: float,
    alert_threshold: float = 0.5,
    low_tta_cutoff: float = 3.0,
) -> Dict[str, float]:
    """Reduce aligned per-sample arrays to the evaluation metric dictionary.

    All array arguments are 1-D and index the same samples; the three mask
    arguments (``is_ego_pos``, ``is_non_ego``, ``is_censored``) are boolean.

    Args:
        hazard_prob: Predicted hazard probabilities.
        hazard_label: Hazard targets; compared against exactly 1 and 0.
        is_ego_pos: Mask of ego-positive samples.
        is_non_ego: Mask of non-ego samples.
        tta_pred: Predicted TTA mean.
        tta_label: TTA targets.
        tta_std: Predicted TTA standard deviation.
        is_censored: Mask of censored samples.
        mean_loss: Loss value reported under the ``"loss"`` key.
        alert_threshold: Probability above which a sample counts as an alert.
        low_tta_cutoff: Upper bound (inclusive) on the label for ``low_tta_mae``.

    Returns:
        Metric name -> value. Counts are ``int``, everything else ``float``.
    """
    # --- Hazard classification -------------------------------------------
    alert = hazard_prob > alert_threshold
    tp = float((alert & (hazard_label == 1)).sum())
    fp = float((alert & (hazard_label == 0)).sum())
    fn = float((~alert & (hazard_label == 1)).sum())
    prec = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    f1 = 2 * prec * recall / max(1e-9, prec + recall)

    # False-alert rates on the two kinds of negatives: non-ego samples, and
    # samples that are neither ego-positive nor non-ego ("safe negatives").
    safe_neg_mask = (~is_ego_pos) & (~is_non_ego)
    ne_far = float(alert[is_non_ego].mean()) if is_non_ego.any() else 0.0
    sneg_fa = float(alert[safe_neg_mask].mean()) if safe_neg_mask.any() else 0.0

    # --- TTA regression on ego positives that are not censored -------------
    obs_mask = is_ego_pos & (~is_censored)
    if obs_mask.any():
        pos_preds = tta_pred[obs_mask]
        pos_labels = tta_label[obs_mask]
        residual = pos_preds - pos_labels
        pos_mae = float(np.abs(residual).mean())
        pos_rmse = float(np.sqrt((residual ** 2).mean()))
        low_mask = pos_labels <= low_tta_cutoff
        low_mae = float(np.abs(residual[low_mask]).mean()) if low_mask.any() else 0.0
        ss_tot = float(((pos_labels - pos_labels.mean()) ** 2).sum())
        ss_res = float((residual ** 2).sum())
        # R^2 is undefined when the labels have no variance (single sample or
        # constant targets); report 0.0 rather than dividing by ~0, which
        # would yield an astronomically large negative value.
        pos_r2 = 1.0 - ss_res / ss_tot if ss_tot > 1e-12 else 0.0
    else:
        # No observed positives: use the fallback error of 10.0 so ckpt_score
        # is penalised instead of rewarded for an empty regression set.
        pos_mae = pos_rmse = low_mae = 10.0
        pos_r2 = 0.0

    ckpt_score = 0.6 * f1 - 0.4 * (pos_mae / 10.0)

    return {
        "loss": mean_loss,
        "hazard_f1": f1,
        "hazard_precision": prec,
        "hazard_recall": recall,
        "hazard_tp": int(tp),
        "hazard_fp": int(fp),
        "hazard_fn": int(fn),
        "pos_tta_mae": pos_mae,
        "pos_tta_rmse": pos_rmse,
        "pos_tta_r2": pos_r2,
        "low_tta_mae": low_mae,
        "non_ego_false_alert": ne_far,
        "safe_neg_false_alert": sneg_fa,
        "uncertainty_mean": float(tta_std.mean()) if tta_std.size else 0.0,
        "ckpt_score": ckpt_score,
        "n_total": int(hazard_prob.size),
        "n_ego_pos": int(is_ego_pos.sum()),
        "n_non_ego": int(is_non_ego.sum()),
        "n_safe_neg": int(safe_neg_mask.sum()),
        "n_obs": int(obs_mask.sum()),
        "n_censored": int((is_censored & is_ego_pos).sum()),
    }


def evaluate_checkpoint(
    model: SFTModel,
    loader: DataLoader,
    amp_dtype: torch.dtype = torch.bfloat16,
    nll_weight: float = 0.5,
) -> Dict[str, float]:
    """Run *model* over *loader* without gradients and compute evaluation metrics.

    For every batch a chat-formatted prompt is built per sample (one image
    placeholder per frame plus the text from ``_build_prompt``), passed through
    ``model.processor`` and then through the model. The loss comes from
    ``compute_sft_loss``; predictions and targets are accumulated on the CPU
    and reduced to metrics after the loop.

    Batch keys read: ``images``, ``video_ids``, ``metadata``, ``tta_labels``,
    ``hazard_labels``, ``hazard_weights``, ``is_ego_positive``, ``is_censored``
    and, optionally, ``is_non_ego`` (treated as all-False when absent).
    Model output keys read: ``hazard_logit``, ``hazard_prob``, ``tta_mean``,
    ``tta_logvar``.

    Args:
        model: Model exposing ``processor``, ``device``, ``eval()`` and ``__call__``.
        loader: Yields collated batches with the keys listed above.
        amp_dtype: Autocast dtype; autocast is only enabled on a CUDA device.
        nll_weight: Forwarded unchanged to ``compute_sft_loss``.

    Returns:
        Metric name -> value (floats, with ``int`` counts for the ``n_*`` and
        ``hazard_tp/fp/fn`` entries). ``"loss"`` is the mean of per-batch losses.
    """
    model.eval()
    system_prompt = "You are a driving safety AI analyzing dashcam footage for collision risk."

    proc = model.processor
    apply_chat = (
        proc.apply_chat_template
        if hasattr(proc, "apply_chat_template")
        else proc.tokenizer.apply_chat_template
    )

    # The device is loop-invariant; resolve it once. Autocast is tied to the
    # device actually in use: enabled for CUDA, and an explicitly disabled CPU
    # context otherwise (avoids requesting CUDA autocast where there is none).
    dev = model.device
    on_cuda = torch.device(dev).type == "cuda"
    autocast_device = "cuda" if on_cuda else "cpu"

    target_keys = ("tta_labels", "hazard_labels", "hazard_weights", "is_ego_positive", "is_censored")

    total_loss = 0.0
    n_batches = 0
    hazard_prob_chunks: List[np.ndarray] = []
    hazard_label_chunks: List[np.ndarray] = []
    ego_pos_chunks: List[np.ndarray] = []
    non_ego_chunks: List[np.ndarray] = []
    tta_pred_chunks: List[np.ndarray] = []
    tta_label_chunks: List[np.ndarray] = []
    tta_std_chunks: List[np.ndarray] = []
    censored_chunks: List[np.ndarray] = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", ncols=70):
            images = batch["images"]
            batch_size = len(batch["video_ids"])

            texts = []
            for frames, meta in zip(images, batch["metadata"]):
                # One image placeholder per frame, followed by the text prompt.
                content = [{"type": "image"} for _ in range(len(frames))]
                content.append({"type": "text", "text": _build_prompt(meta)})
                msgs = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": content},
                ]
                texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False))

            # NOTE(review): the processor output is handed to the model as-is;
            # presumably the model moves it to its device — confirm in SFTModel.
            inputs = proc(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)

            t = {key: batch[key].to(dev) for key in target_keys}
            is_noneego = batch.get("is_non_ego", torch.zeros(batch_size, dtype=torch.bool))

            with autocast(device_type=autocast_device, dtype=amp_dtype, enabled=on_cuda):
                out = model(inputs)
                loss, _ = compute_sft_loss(
                    hazard_logit=out["hazard_logit"],
                    tta_mean=out["tta_mean"],
                    tta_logvar=out["tta_logvar"],
                    hazard_label=t["hazard_labels"],
                    hazard_weight=t["hazard_weights"],
                    is_ego_positive=t["is_ego_positive"],
                    is_censored=t["is_censored"],
                    tta_label=t["tta_labels"],
                    nll_weight=nll_weight,
                )

            total_loss += float(loss.item())
            n_batches += 1

            # Everything is flattened to 1-D per batch so that predictions and
            # labels stay index-aligned even if a head returns a trailing
            # singleton dimension or a 0-d tensor for a batch of one.
            hazard_prob_chunks.append(out["hazard_prob"].detach().float().cpu().numpy().reshape(-1))
            hazard_label_chunks.append(t["hazard_labels"].detach().float().cpu().numpy().reshape(-1))
            ego_pos_chunks.append(t["is_ego_positive"].cpu().numpy().reshape(-1))
            non_ego_chunks.append(is_noneego.cpu().numpy().reshape(-1))
            tta_pred_chunks.append(out["tta_mean"].detach().float().cpu().numpy().reshape(-1))
            tta_label_chunks.append(t["tta_labels"].detach().float().cpu().numpy().reshape(-1))
            # Standard deviation from the predicted log-variance: exp(0.5 * logvar).
            tta_std_chunks.append(
                torch.exp(0.5 * out["tta_logvar"].detach().float()).cpu().numpy().reshape(-1)
            )
            censored_chunks.append(t["is_censored"].cpu().numpy().reshape(-1))

    metrics = _summarize_predictions(
        hazard_prob=_concat_batches(hazard_prob_chunks),
        hazard_label=_concat_batches(hazard_label_chunks),
        is_ego_pos=_concat_batches(ego_pos_chunks, bool),
        is_non_ego=_concat_batches(non_ego_chunks, bool),
        tta_pred=_concat_batches(tta_pred_chunks),
        tta_label=_concat_batches(tta_label_chunks),
        tta_std=_concat_batches(tta_std_chunks),
        is_censored=_concat_batches(censored_chunks, bool),
        mean_loss=total_loss / max(1, n_batches),
    )

    logger.info(
        f" hazard_f1={metrics['hazard_f1']:.3f} prec={metrics['hazard_precision']:.3f} recall={metrics['hazard_recall']:.3f}\n"
        f" pos_tta_mae={metrics['pos_tta_mae']:.3f} low_tta_mae={metrics['low_tta_mae']:.3f} pos_r2={metrics['pos_tta_r2']:.3f}\n"
        f" non_ego_fa={metrics['non_ego_false_alert']:.3f} safe_neg_fa={metrics['safe_neg_false_alert']:.3f}\n"
        f" ckpt_score={metrics['ckpt_score']:.4f} loss={metrics['loss']:.4f}"
    )
    return metrics
|
|
|
|
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: evaluate one SFT checkpoint on a data split.

    Steps: parse arguments, validate the checkpoint directory, collect the
    manifest files for the requested split, build the dataset and loader,
    construct the model from the checkpoint's ``config.json`` and ``vlm_lora``
    directory, load the SFT heads, run ``evaluate_checkpoint``, print the
    metrics and optionally write them to a JSON file.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so calling ``main()`` behaves as before.

    Raises:
        RuntimeError: If the checkpoint directory is rejected by
            ``_is_sft_ckpt_dir`` or no manifest for the split exists.
        SystemExit: Raised by argparse on ``--help`` or invalid arguments.
    """
    parser = argparse.ArgumentParser("SFT checkpoint evaluation")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to SFT checkpoint dir")
    parser.add_argument("--manifest_dir", type=str, default="PROJECT_ROOT/data/sft_manifests")
    parser.add_argument("--split", type=str, default="val", choices=["val", "train", "test_public"])
    parser.add_argument("--model_name", type=str, default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes")
    parser.add_argument("--output_json", type=str, default=None)
    args = parser.parse_args(argv)

    ckpt_dir = Path(args.checkpoint)
    if not _is_sft_ckpt_dir(ckpt_dir):
        raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt_dir}")

    # Manifest file names expected for each split.
    # NOTE(review): "val" has no dada_neg counterpart to train's
    # dada_neg_train.json — confirm this is intentional.
    split_files = {
        "val": ("nexar_val.json", "dada_pos_val.json", "dada_noneego_val.json"),
        "train": (
            "nexar_train.json",
            "dada_pos_train.json",
            "dada_noneego_train.json",
            "dada_neg_train.json",
        ),
        "test_public": ("nexar_test_public.json",),
    }
    manifest_dir = Path(args.manifest_dir)
    expected = [manifest_dir / name for name in split_files[args.split]]

    manifests = [m for m in expected if m.exists()]
    missing = [m.name for m in expected if not m.exists()]
    if not manifests:
        raise RuntimeError(f"No manifests found for split '{args.split}' in {manifest_dir}")
    if missing:
        # Evaluation still proceeds on what exists, but a partial split changes
        # the meaning of every metric, so make the omission visible.
        logger.warning(f"Missing manifests skipped for split '{args.split}': {missing}")

    logger.info(f"Manifests: {[m.name for m in manifests]}")
    dataset = SFTDataset(
        manifests=manifests,
        # The dataset is built in "train" mode only for the train split; every
        # other split (val, test_public) uses "val" mode.
        split="train" if args.split == "train" else "val",
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=sft_collate_fn,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    with open(ckpt_dir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    # Values stored with the checkpoint take precedence over CLI defaults.
    model_name = cfg.get("model_name", args.model_name)
    logger.info(f"Loading model: {model_name}")
    model = SFTModel(
        model_name=model_name,
        pretrained_lora_path=str(ckpt_dir / "vlm_lora"),
        belief_strategy=cfg.get("belief_strategy", "mean_pool"),
        tta_intermediate_dim=cfg.get("tta_intermediate_dim", 512),
        use_lora=True,
        use_bf16=True,
        device="auto",
    )
    load_sft_heads(model, ckpt_dir)

    logger.info(f"Evaluating {ckpt_dir.name} split={args.split} n={len(dataset)}")
    metrics = evaluate_checkpoint(model, loader)

    print("\n=== Evaluation Results ===")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:30s} {v:.4f}")
        else:
            print(f" {k:30s} {v}")

    if args.output_json:
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump({"checkpoint": str(ckpt_dir), "split": args.split, "metrics": metrics}, f, indent=2)
        logger.info(f"Results written to {out_path}")
|
|
|
|
# Run the CLI only when this file is executed as a script (including via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()
|
|