# Source: VLAlert/training/SFT/evaluate.py (commit 1e05592, "Add VLAlert code").
#!/usr/bin/env python3
"""
Standalone evaluation script for SFT checkpoints.
Usage
-----
python -m training.SFT.evaluate \
--checkpoint /path/to/checkpoints/SFT/sft_v2/best \
--manifest_dir /path/to/data/sft_manifests \
[--split val] [--batch_size 4] [--output_json results.json]
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from typing import Dict, List
import numpy as np
import torch
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from .dataset import SFTDataset, sft_collate_fn
from .trainer import SFTModel, compute_sft_loss, load_sft_heads, _is_sft_ckpt_dir
# Module-level side effect: configures the root logger (INFO level,
# "time level message" record format) as soon as this module is imported.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Named logger shared by the evaluation routines in this file.
logger = logging.getLogger("SFT.evaluate")
def _build_prompt(metadata: dict) -> str:
parts = []
if metadata.get("weather"): parts.append(f"Weather: {metadata['weather']}")
if metadata.get("road_type"): parts.append(f"Road: {metadata['road_type']}")
if metadata.get("time_of_day"): parts.append(f"Time: {metadata['time_of_day']}")
ctx = ", ".join(parts) or "Urban driving"
return (
f"Analyze this driving sequence.\n"
f"Context: {ctx}\n"
f"Estimate the time to potential collision. Output a single number in seconds."
)
def _to_flat_numpy(tensor: torch.Tensor, as_float: bool = True) -> np.ndarray:
    """Detach ``tensor``, move it to the host and return it as a 1-D array.

    Flattening keeps the per-batch arrays safe to concatenate and to compare
    element-wise even when a value arrives as a scalar or carries an extra
    singleton dimension.
    """
    tensor = tensor.detach()
    if as_float:
        # Widen to float32 first; reduced-precision autocast outputs such as
        # bfloat16 cannot be converted to NumPy directly.
        tensor = tensor.float()
    return tensor.cpu().numpy().reshape(-1)


def _compute_eval_metrics(
    mean_loss: float,
    hp_all: np.ndarray,
    hl_all: np.ndarray,
    ep_all: np.ndarray,
    ne_all: np.ndarray,
    pred_all: np.ndarray,
    lbl_all: np.ndarray,
    std_all: np.ndarray,
    cen_all: np.ndarray,
) -> Dict[str, float]:
    """Reduce the concatenated per-sample arrays to the metrics dictionary.

    ``hp_all``/``hl_all`` are hazard probabilities and labels, ``ep_all``,
    ``ne_all`` and ``cen_all`` are boolean masks (ego-positive, non-ego,
    censored), and ``pred_all``/``lbl_all``/``std_all`` are the TTA
    predictions, labels and predicted standard deviations.
    """
    # ── hazard metrics ──────────────────────────────────────────────────────
    hp_bin = (hp_all > 0.5).astype(np.float32)
    tp = float(((hp_bin == 1) & (hl_all == 1)).sum())
    fp = float(((hp_bin == 1) & (hl_all == 0)).sum())
    fn = float(((hp_bin == 0) & (hl_all == 1)).sum())
    # The max(...) guards keep the ratios defined when a class is absent.
    prec = tp / max(1, tp + fp)
    recall = tp / max(1, tp + fn)
    f1 = 2 * prec * recall / max(1e-9, prec + recall)

    # False-alert rates on the two kinds of non-positive samples.
    safe_neg_mask = (~ep_all) & (~ne_all)
    ne_far = float((hp_bin[ne_all] == 1).mean()) if ne_all.any() else 0.0
    sneg_fa = float((hp_bin[safe_neg_mask] == 1).mean()) if safe_neg_mask.any() else 0.0

    # ── TTA metrics (positive-observed only) ────────────────────────────────
    obs_mask = ep_all & (~cen_all)
    if obs_mask.any():
        pos_preds = pred_all[obs_mask]
        pos_labels = lbl_all[obs_mask]
        abs_err = np.abs(pos_preds - pos_labels)
        sq_err = (pos_preds - pos_labels) ** 2
        pos_mae = float(abs_err.mean())
        pos_rmse = float(np.sqrt(sq_err.mean()))
        low_mask = pos_labels <= 3.0
        low_mae = float(abs_err[low_mask].mean()) if low_mask.any() else 0.0
        # Epsilon avoids division by zero when all labels are identical.
        denom = float(((pos_labels - pos_labels.mean()) ** 2).sum()) + 1e-12
        pos_r2 = float(1.0 - sq_err.sum() / denom)
    else:
        # No observed positive sample: fall back to the fixed 10.0 error so
        # that ckpt_score below takes the full MAE penalty.
        pos_mae = pos_rmse = low_mae = 10.0
        pos_r2 = 0.0

    ckpt_score = 0.6 * f1 - 0.4 * (pos_mae / 10.0)

    return {
        "loss": mean_loss,
        "hazard_f1": f1,
        "hazard_precision": prec,
        "hazard_recall": recall,
        "hazard_tp": int(tp),
        "hazard_fp": int(fp),
        "hazard_fn": int(fn),
        "pos_tta_mae": pos_mae,
        "pos_tta_rmse": pos_rmse,
        "pos_tta_r2": pos_r2,
        "low_tta_mae": low_mae,
        "non_ego_false_alert": ne_far,
        "safe_neg_false_alert": sneg_fa,
        "uncertainty_mean": float(std_all.mean()) if std_all.size else 0.0,
        "ckpt_score": ckpt_score,
        "n_total": int(hp_all.size),
        "n_ego_pos": int(ep_all.sum()),
        "n_non_ego": int(ne_all.sum()),
        "n_safe_neg": int(safe_neg_mask.sum()),
        "n_obs": int(obs_mask.sum()),
        "n_censored": int(cen_all[ep_all].sum()) if ep_all.any() else 0,
    }


def evaluate_checkpoint(
    model: SFTModel,
    loader: DataLoader,
    amp_dtype: torch.dtype = torch.bfloat16,
    nll_weight: float = 0.5,
) -> Dict[str, float]:
    """Run ``model`` over ``loader`` and compute hazard and TTA metrics.

    Args:
        model: Model in use. It must expose ``processor`` and ``device`` and,
            when called on the processor output, return a mapping with
            ``hazard_logit``, ``hazard_prob``, ``tta_mean`` and
            ``tta_logvar``.
        loader: Yields dict batches with the keys ``images``, ``video_ids``,
            ``metadata``, ``tta_labels``, ``hazard_labels``,
            ``hazard_weights``, ``is_ego_positive`` and ``is_censored``; the
            key ``is_non_ego`` is optional and defaults to all-False.
        amp_dtype: dtype used for autocast during the forward pass.
        nll_weight: Forwarded unchanged to ``compute_sft_loss``.

    Returns:
        Dictionary of scalar metrics. ``loss`` is the mean of the per-batch
        losses (batches are not weighted by size); the remaining entries are
        computed over all samples seen.
    """
    model.eval()
    SYSTEM = "You are a driving safety AI analyzing dashcam footage for collision risk."

    proc = model.processor
    # Use the processor's own chat template when it has one, otherwise the
    # tokenizer's.
    apply_chat = (
        proc.apply_chat_template
        if hasattr(proc, "apply_chat_template")
        else proc.tokenizer.apply_chat_template
    )

    dev = model.device
    # Autocast must target the device the model actually lives on; a fixed
    # "cuda" would only warn and disable itself on other devices.
    device_type = torch.device(dev).type

    total_loss = 0.0
    n = 0
    all_hazard_prob: List[np.ndarray] = []
    all_hazard_label: List[np.ndarray] = []
    all_is_ego_pos: List[np.ndarray] = []
    all_is_noneego: List[np.ndarray] = []
    all_tta_pred: List[np.ndarray] = []
    all_tta_label: List[np.ndarray] = []
    all_tta_std: List[np.ndarray] = []
    all_is_censored: List[np.ndarray] = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", ncols=70):
            images = batch["images"]
            batch_size = len(batch["video_ids"])

            # One chat-formatted prompt per sample: an image placeholder for
            # every frame followed by the text instruction.
            texts = []
            for i in range(batch_size):
                content = [{"type": "image"} for _ in images[i]]
                content.append({"type": "text", "text": _build_prompt(batch["metadata"][i])})
                msgs = [
                    {"role": "system", "content": SYSTEM},
                    {"role": "user", "content": content},
                ]
                texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False))

            inputs = proc(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)

            t = {
                key: batch[key].to(dev)
                for key in (
                    "tta_labels",
                    "hazard_labels",
                    "hazard_weights",
                    "is_ego_positive",
                    "is_censored",
                )
            }
            is_noneego = batch.get("is_non_ego", torch.zeros(batch_size, dtype=torch.bool))

            with autocast(device_type=device_type, dtype=amp_dtype, enabled=True):
                out = model(inputs)
                loss, _ = compute_sft_loss(
                    hazard_logit=out["hazard_logit"],
                    tta_mean=out["tta_mean"],
                    tta_logvar=out["tta_logvar"],
                    hazard_label=t["hazard_labels"],
                    hazard_weight=t["hazard_weights"],
                    is_ego_positive=t["is_ego_positive"],
                    is_censored=t["is_censored"],
                    tta_label=t["tta_labels"],
                    nll_weight=nll_weight,
                )

            total_loss += float(loss.item())
            n += 1

            all_hazard_prob.append(_to_flat_numpy(out["hazard_prob"]))
            all_hazard_label.append(_to_flat_numpy(t["hazard_labels"]))
            all_is_ego_pos.append(_to_flat_numpy(t["is_ego_positive"], as_float=False))
            all_is_noneego.append(_to_flat_numpy(is_noneego, as_float=False))
            all_tta_pred.append(_to_flat_numpy(out["tta_mean"]))
            all_tta_label.append(_to_flat_numpy(t["tta_labels"]))
            # Predicted std = exp(0.5 * log-variance).
            all_tta_std.append(_to_flat_numpy(torch.exp(0.5 * out["tta_logvar"].detach().float())))
            all_is_censored.append(_to_flat_numpy(t["is_censored"], as_float=False))

    def cat(lst, dtype=np.float32):
        # An empty loader yields empty arrays rather than a concatenate error.
        return np.concatenate(lst).astype(dtype) if lst else np.array([], dtype=dtype)

    metrics = _compute_eval_metrics(
        mean_loss=total_loss / max(1, n),
        hp_all=cat(all_hazard_prob),
        hl_all=cat(all_hazard_label),
        ep_all=cat(all_is_ego_pos, bool),
        ne_all=cat(all_is_noneego, bool),
        pred_all=cat(all_tta_pred),
        lbl_all=cat(all_tta_label),
        std_all=cat(all_tta_std),
        cen_all=cat(all_is_censored, bool),
    )

    f1 = metrics["hazard_f1"]
    prec = metrics["hazard_precision"]
    recall = metrics["hazard_recall"]
    pos_mae = metrics["pos_tta_mae"]
    low_mae = metrics["low_tta_mae"]
    pos_r2 = metrics["pos_tta_r2"]
    ne_far = metrics["non_ego_false_alert"]
    sneg_fa = metrics["safe_neg_false_alert"]
    ckpt_score = metrics["ckpt_score"]
    logger.info(
        f" hazard_f1={f1:.3f} prec={prec:.3f} recall={recall:.3f}\n"
        f" pos_tta_mae={pos_mae:.3f} low_tta_mae={low_mae:.3f} pos_r2={pos_r2:.3f}\n"
        f" non_ego_fa={ne_far:.3f} safe_neg_fa={sneg_fa:.3f}\n"
        f" ckpt_score={ckpt_score:.4f} loss={metrics['loss']:.4f}"
    )
    return metrics
def main(argv: List[str] | None = None) -> None:
    """Command-line entry point: evaluate one SFT checkpoint on one split.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Raises:
        RuntimeError: If ``--checkpoint`` is rejected by ``_is_sft_ckpt_dir``
            or none of the manifests for the chosen split exists.
    """
    # The text must go in ``description``; the first positional parameter of
    # ArgumentParser is ``prog`` and would replace the program name in usage.
    parser = argparse.ArgumentParser(description="SFT checkpoint evaluation")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to SFT checkpoint dir")
    parser.add_argument("--manifest_dir", type=str, default="PROJECT_ROOT/data/sft_manifests")
    parser.add_argument("--split", type=str, default="val", choices=["val", "train", "test_public"])
    parser.add_argument("--model_name", type=str, default="PROJECT_ROOT/models/Qwen2.5-VL-3B-Instruct")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--output_json", type=str, default=None)
    args = parser.parse_args(argv)

    ckpt_dir = Path(args.checkpoint)
    if not _is_sft_ckpt_dir(ckpt_dir):
        raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt_dir}")

    # Manifest file names expected for each split.
    # NOTE(review): "val" lists no dada_neg manifest although "train" does —
    # confirm this is intentional.
    split_files = {
        "val": (
            "nexar_val.json",
            "dada_pos_val.json",
            "dada_noneego_val.json",
        ),
        "train": (
            "nexar_train.json",
            "dada_pos_train.json",
            "dada_noneego_train.json",
            "dada_neg_train.json",
        ),
        "test_public": ("nexar_test_public.json",),
    }
    manifest_dir = Path(args.manifest_dir)
    manifests = []
    missing = []
    for file_name in split_files[args.split]:
        candidate = manifest_dir / file_name
        if candidate.exists():
            manifests.append(candidate)
        else:
            missing.append(file_name)
    if not manifests:
        raise RuntimeError(f"No manifests found for split '{args.split}' in {manifest_dir}")
    if missing:
        # Evaluation proceeds on the manifests that exist, but say so: the
        # metrics then cover only part of the split.
        logger.warning("Skipping missing manifests for split '%s': %s", args.split, missing)
    logger.info(f"Manifests: {[m.name for m in manifests]}")

    dataset = SFTDataset(
        manifests=manifests,
        # Every split other than "train" is loaded in "val" mode.
        split="val" if args.split != "train" else "train",
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False,
        collate_fn=sft_collate_fn, num_workers=4, pin_memory=True,
    )

    with open(ckpt_dir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    # Values stored with the checkpoint take precedence over CLI defaults.
    model_name = cfg.get("model_name", args.model_name)
    logger.info(f"Loading model: {model_name}")
    model = SFTModel(
        model_name=model_name,
        pretrained_lora_path=str(ckpt_dir / "vlm_lora"),
        belief_strategy=cfg.get("belief_strategy", "mean_pool"),
        tta_intermediate_dim=cfg.get("tta_intermediate_dim", 512),
        use_lora=True,
        use_bf16=True,
        device="auto",
    )
    load_sft_heads(model, ckpt_dir)

    logger.info(f"Evaluating {ckpt_dir.name} split={args.split} n={len(dataset)}")
    metrics = evaluate_checkpoint(model, loader)

    print("\n=== Evaluation Results ===")
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:30s} {v:.4f}")
        else:
            print(f" {k:30s} {v}")

    if args.output_json:
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump({"checkpoint": str(ckpt_dir), "split": args.split, "metrics": metrics}, f, indent=2)
        logger.info(f"Results written to {out_path}")
# Run the evaluation CLI only when this file is executed as a script or via
# ``python -m``; importing the module does not start an evaluation.
if __name__ == "__main__":
    main()