#!/usr/bin/env python3 """ DPO Trainer — aligns HazardHead alert timing via Direct Preference Optimization. Architecture ------------ Base: SFTModel (VLM + LoRA + BeliefAggregator + HazardHead + TTAHead) loaded from SFT best checkpoint; VLM / TTAHead / BeliefAggregator FROZEN. Trainable: HazardHead only (~2 k params) Reference: frozen copy of the initial SFT HazardHead (for DPO implicit reward) Loss ---- L = L_DPO + lambda_reg * L_reg L_DPO = -log σ(β · [(log P_θ(alert|chosen) - log P_ref(alert|chosen)) - (log P_θ(alert|rejected) - log P_ref(alert|rejected))]) L_reg = BCE(logit_chosen, 1) # keep detecting hazards in chosen windows + BCE(logit_rejected, 0) # keep suppressing hazards in rejected windows Checkpoint selection: val DPO accuracy = fraction of pairs where P_θ(alert|chosen) > P_θ(alert|rejected) Usage ----- python -m training.DPO.trainer \ --sft_checkpoint checkpoints/SFT/sft_v2/best \ --pair_dir data/dpo_pairs \ --output_dir checkpoints/DPO \ --experiment_name dpo_v1 """ from __future__ import annotations import argparse import copy import json import logging from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.amp import autocast from torch.optim import AdamW from torch.utils.data import DataLoader from tqdm import tqdm try: import wandb HAS_WANDB = True except ImportError: HAS_WANDB = False from .dataset import DPODataset, dpo_collate_fn # Import SFT infrastructure import sys sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) from training.SFT.trainer import SFTModel, load_sft_heads, _is_sft_ckpt_dir logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger("DPO.trainer") SYSTEM = "You are a driving safety AI analyzing dashcam footage for collision risk." # ───────────────────────────────────────────────────────────────────────────── # Prompt builder (identical to SFT evaluate.py) # ───────────────────────────────────────────────────────────────────────────── def _build_prompt(metadata: dict) -> str: parts = [] if metadata.get("weather"): parts.append(f"Weather: {metadata['weather']}") if metadata.get("road_type"): parts.append(f"Road: {metadata['road_type']}") if metadata.get("time_of_day"): parts.append(f"Time: {metadata['time_of_day']}") ctx = ", ".join(parts) or "Urban driving" return ( f"Analyze this driving sequence.\n" f"Context: {ctx}\n" f"Estimate the time to potential collision. Output a single number in seconds." ) # ───────────────────────────────────────────────────────────────────────────── # DPO loss # ───────────────────────────────────────────────────────────────────────────── def compute_dpo_loss( logit_chosen: torch.Tensor, # [B] policy logit for chosen window logit_rejected: torch.Tensor, # [B] policy logit for rejected window ref_logit_chosen: torch.Tensor, # [B] reference logit (frozen) ref_logit_rejected: torch.Tensor, # [B] reference logit (frozen) beta: float = 0.1, ) -> Tuple[torch.Tensor, Dict[str, float]]: """ Standard DPO loss for binary alert policy. log P(alert | x) = log σ(logit) [binary action] """ # log π_θ(alert | ·) log_pi_chosen = -F.softplus(-logit_chosen.float()) log_pi_rejected = -F.softplus(-logit_rejected.float()) # log π_ref(alert | ·) with torch.no_grad(): log_ref_chosen = -F.softplus(-ref_logit_chosen.float()) log_ref_rejected = -F.softplus(-ref_logit_rejected.float()) reward_chosen = log_pi_chosen - log_ref_chosen # implicit reward margin reward_rejected = log_pi_rejected - log_ref_rejected loss = -F.logsigmoid(beta * (reward_chosen - reward_rejected)).mean() # ── metrics ────────────────────────────────────────────────────────────── with torch.no_grad(): acc = float(((logit_chosen > logit_rejected).float()).mean().item()) margin = float((torch.sigmoid(logit_chosen) - torch.sigmoid(logit_rejected)).mean().item()) return loss, { "dpo_loss": float(loss.detach()), "dpo_acc": acc, "prob_margin": margin, "prob_chosen": float(torch.sigmoid(logit_chosen).mean().detach()), "prob_rejected": float(torch.sigmoid(logit_rejected).mean().detach()), } # ───────────────────────────────────────────────────────────────────────────── # DPO Model wrapper # ───────────────────────────────────────────────────────────────────────────── class DPOModel(nn.Module): """ Wraps SFTModel for DPO training. Only HazardHead is trainable; everything else is frozen. Keeps a frozen reference copy of the initial SFT HazardHead. """ def __init__( self, sft_checkpoint_dir: str, use_bf16: bool = True, ): super().__init__() ckpt = Path(sft_checkpoint_dir) if not _is_sft_ckpt_dir(ckpt): raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt}") with open(ckpt / "config.json") as f: cfg = json.load(f) model_name = cfg["model_name"] logger.info(f"Loading SFTModel from {ckpt} ...") self.sft = SFTModel( model_name = model_name, pretrained_lora_path = str(ckpt / "vlm_lora"), belief_strategy = cfg.get("belief_strategy", "mean_pool"), tta_intermediate_dim = cfg.get("tta_intermediate_dim", 512), use_lora = True, use_bf16 = use_bf16, device = "auto", ) load_sft_heads(self.sft, ckpt) # ── freeze everything except HazardHead ────────────────────────────── for param in self.sft.vlm.parameters(): param.requires_grad = False for param in self.sft.belief_aggregator.parameters(): param.requires_grad = False for param in self.sft.tta_head.parameters(): param.requires_grad = False # HazardHead remains trainable # ── frozen reference copy of HazardHead ────────────────────────────── self.ref_hazard_head = copy.deepcopy(self.sft.hazard_head) for param in self.ref_hazard_head.parameters(): param.requires_grad = False self.ref_hazard_head.to(self.sft.device) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) total = sum(p.numel() for p in self.parameters()) logger.info(f"Trainable params: {trainable:,} / Total: {total:,}") self.processor = self.sft.processor self.hidden_dim = self.sft.hidden_dim self._sft_ckpt_dir = ckpt # kept for save_checkpoint @property def device(self): return self.sft.device def _build_inputs( self, images: List[List], # [B, n_frames] metadata: List[dict], ) -> dict: proc = self.processor apply_chat = ( proc.apply_chat_template if hasattr(proc, "apply_chat_template") else proc.tokenizer.apply_chat_template ) texts = [] for i in range(len(images)): frames = images[i] content = [{"type": "image"} for _ in range(len(frames))] content.append({"type": "text", "text": _build_prompt(metadata[i])}) msgs = [ {"role": "system", "content": SYSTEM}, {"role": "user", "content": content}, ] texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False)) return proc(text=texts, images=images, return_tensors="pt", padding=True, truncation=True) def forward_pair( self, chosen_images: List[List], chosen_metadata: List[dict], rejected_images: List[List], rejected_metadata:List[dict], amp_dtype: torch.dtype = torch.bfloat16, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: logit_chosen, logit_rejected, ref_logit_chosen, ref_logit_rejected (all [B]) """ inputs_c = self._build_inputs(chosen_images, chosen_metadata) inputs_r = self._build_inputs(rejected_images, rejected_metadata) # VLM is frozen → run in no_grad to save peak memory with torch.no_grad(): with autocast(device_type="cuda", dtype=amp_dtype, enabled=True): belief_c = self.sft.encode_observation(inputs_c) belief_r = self.sft.encode_observation(inputs_r) # HazardHead forward (trainable) with autocast(device_type="cuda", dtype=amp_dtype, enabled=True): logit_c = self.sft.hazard_head(belief_c) logit_r = self.sft.hazard_head(belief_r) # Reference head (frozen) with torch.no_grad(): with autocast(device_type="cuda", dtype=amp_dtype, enabled=True): ref_c = self.ref_hazard_head(belief_c.detach()) ref_r = self.ref_hazard_head(belief_r.detach()) return logit_c, logit_r, ref_c, ref_r def save_checkpoint(self, save_dir: str, epoch: int = 0, step: int = 0): save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) # Save updated HazardHead torch.save(self.sft.hazard_head.state_dict(), save_dir / "hazard_head.pt") # Also save LoRA (unchanged) and other SFT heads for a complete loadable checkpoint lora_dir = save_dir / "vlm_lora" self.sft.vlm.save_pretrained(lora_dir) torch.save(self.sft.belief_aggregator.state_dict(), save_dir / "belief_aggregator.pt") torch.save(self.sft.tta_head.state_dict(), save_dir / "tta_head.pt") # Copy SFT config + update epoch/step with open(self._sft_ckpt_dir / "config.json") as f: cfg = json.load(f) cfg["epoch"] = epoch cfg["step"] = step with open(save_dir / "config.json", "w") as f: json.dump(cfg, f, indent=2) logger.info(f"✅ Checkpoint saved to {save_dir}") # ───────────────────────────────────────────────────────────────────────────── # DPO Trainer # ───────────────────────────────────────────────────────────────────────────── class DPOTrainer: def __init__( self, model: DPOModel, train_loader: DataLoader, val_loader: DataLoader, output_dir: str, experiment_name: str = "dpo_v1", num_epochs: int = 5, learning_rate: float = 5e-5, beta: float = 0.1, lambda_reg: float = 0.5, gradient_accumulation_steps: int = 1, max_grad_norm: float = 1.0, val_every_n_steps: int = 500, use_wandb: bool = False, ): self.model = model self.train_loader = train_loader self.val_loader = val_loader self.output_dir = Path(output_dir) self.experiment_name = experiment_name self.num_epochs = num_epochs self.beta = beta self.lambda_reg = lambda_reg self.grad_accum = gradient_accumulation_steps self.max_grad_norm = max_grad_norm self.val_every = val_every_n_steps self.use_wandb = use_wandb and HAS_WANDB self.exp_dir = self.output_dir / experiment_name self.exp_dir.mkdir(parents=True, exist_ok=True) # Only optimise HazardHead self.optimizer = AdamW( [p for p in model.parameters() if p.requires_grad], lr=learning_rate, weight_decay=0.01, ) self.global_step = 0 self.best_val_acc = float("-inf") if self.use_wandb: wandb.init(project="lkalert-dpo", name=experiment_name, config={"beta": beta, "lambda_reg": lambda_reg, "lr": learning_rate, "epochs": num_epochs}) logger.info(f"✅ DPOTrainer ready exp={experiment_name} " f"steps/epoch≈{len(train_loader)}") # ── single training step ────────────────────────────────────────────────── def train_step(self, batch: dict) -> dict: self.model.train() amp_dtype = torch.bfloat16 logit_c, logit_r, ref_c, ref_r = self.model.forward_pair( batch["chosen_images"], batch["chosen_metadata"], batch["rejected_images"], batch["rejected_metadata"], amp_dtype=amp_dtype, ) # DPO loss l_dpo, dpo_metrics = compute_dpo_loss( logit_c, logit_r, ref_c, ref_r, beta=self.beta ) # Regularisation: BCE on chosen (should be 1) and rejected (should be 0) ones = torch.ones_like(logit_c.float()) zeros = torch.zeros_like(logit_r.float()) l_reg = 0.5 * (F.binary_cross_entropy_with_logits(logit_c.float(), ones) + F.binary_cross_entropy_with_logits(logit_r.float(), zeros)) loss = l_dpo + self.lambda_reg * l_reg loss = loss / self.grad_accum loss.backward() return {**dpo_metrics, "reg_loss": float(l_reg.detach()), "total_loss": float((l_dpo + self.lambda_reg * l_reg).detach())} # ── validation loop ─────────────────────────────────────────────────────── @torch.no_grad() def validate(self) -> dict: self.model.eval() amp_dtype = torch.bfloat16 accs, margins = [], [] prob_c_list, prob_r_list = [], [] for batch in tqdm(self.val_loader, desc=" Val", ncols=70, leave=False): logit_c, logit_r, ref_c, ref_r = self.model.forward_pair( batch["chosen_images"], batch["chosen_metadata"], batch["rejected_images"], batch["rejected_metadata"], amp_dtype=amp_dtype, ) _, m = compute_dpo_loss(logit_c, logit_r, ref_c, ref_r, beta=self.beta) accs.append(m["dpo_acc"]) margins.append(m["prob_margin"]) prob_c_list.append(m["prob_chosen"]) prob_r_list.append(m["prob_rejected"]) return { "val_dpo_acc": float(np.mean(accs)), "val_prob_margin": float(np.mean(margins)), "val_prob_chosen": float(np.mean(prob_c_list)), "val_prob_rejected": float(np.mean(prob_r_list)), } # ── main training loop ──────────────────────────────────────────────────── def train(self): logger.info("=" * 60) logger.info(f"Starting DPO training: {self.experiment_name}") logger.info("=" * 60) for epoch in range(self.num_epochs): self.optimizer.zero_grad() accum_metrics: Dict[str, List[float]] = {} pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}", ncols=80) for step_in_epoch, batch in enumerate(pbar): metrics = self.train_step(batch) self.global_step += 1 for k, v in metrics.items(): accum_metrics.setdefault(k, []).append(v) # Optimiser update if self.global_step % self.grad_accum == 0: nn.utils.clip_grad_norm_( [p for p in self.model.parameters() if p.requires_grad], self.max_grad_norm, ) self.optimizer.step() self.optimizer.zero_grad() pbar.set_postfix({ "dpo": f"{metrics.get('dpo_loss', 0):.3f}", "acc": f"{metrics.get('dpo_acc', 0):.3f}", }) # Periodic validation if self.global_step % self.val_every == 0: val = self.validate() avg = {k: float(np.mean(v)) for k, v in accum_metrics.items()} logger.info( f"Step {self.global_step:6d} | " f"dpo_loss={avg.get('dpo_loss', 0):.3f} " f"train_acc={avg.get('dpo_acc', 0):.3f} " f"val_acc={val['val_dpo_acc']:.3f} " f"margin={val['val_prob_margin']:.3f}" ) if self.use_wandb: wandb.log({**avg, **val, "step": self.global_step}) if val["val_dpo_acc"] > self.best_val_acc: self.best_val_acc = val["val_dpo_acc"] self.model.save_checkpoint( str(self.exp_dir / "best"), epoch=epoch, step=self.global_step, ) logger.info(f" ✅ New best val_acc={self.best_val_acc:.4f}") accum_metrics = {} # Epoch-end validation val = self.validate() logger.info( f"Epoch {epoch+1} end | " f"val_acc={val['val_dpo_acc']:.3f} " f"margin={val['val_prob_margin']:.3f} " f"P(chosen)={val['val_prob_chosen']:.3f} " f"P(rejected)={val['val_prob_rejected']:.3f}" ) # Save epoch checkpoint self.model.save_checkpoint( str(self.exp_dir / f"epoch_{epoch+1}"), epoch=epoch, step=self.global_step, ) if val["val_dpo_acc"] > self.best_val_acc: self.best_val_acc = val["val_dpo_acc"] self.model.save_checkpoint( str(self.exp_dir / "best"), epoch=epoch, step=self.global_step, ) logger.info(f"Training complete. Best val_dpo_acc={self.best_val_acc:.4f}") # ───────────────────────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser("DPO trainer") parser.add_argument("--sft_checkpoint", required=True, help="Path to SFT best checkpoint dir") parser.add_argument("--pair_dir", default="data/dpo_pairs") parser.add_argument("--output_dir", default="checkpoints/DPO") parser.add_argument("--experiment_name", default="dpo_v1") parser.add_argument("--num_epochs", type=int, default=5) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=5e-5) parser.add_argument("--beta", type=float, default=0.1, help="DPO temperature β") parser.add_argument("--lambda_reg", type=float, default=0.5, help="SFT regularisation weight") parser.add_argument("--gradient_accumulation_steps", type=int, default=2) parser.add_argument("--max_grad_norm", type=float, default=1.0) parser.add_argument("--val_every_n_steps",type=int, default=500) parser.add_argument("--use_wandb", action="store_true") parser.add_argument("--debug", action="store_true") parser.add_argument("--debug_samples", type=int, default=64) args = parser.parse_args() pair_dir = Path(args.pair_dir) train_manifests = [ pair_dir / "nexar_train.json", pair_dir / "dada_train.json", ] val_manifests = [ pair_dir / "nexar_val.json", pair_dir / "dada_val.json", ] train_ds = DPODataset(train_manifests, split="train", debug=args.debug, debug_samples=args.debug_samples) val_ds = DPODataset(val_manifests, split="val", debug=args.debug, debug_samples=args.debug_samples // 4) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True) val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True) model = DPOModel(sft_checkpoint_dir=args.sft_checkpoint, use_bf16=True) trainer = DPOTrainer( model = model, train_loader = train_loader, val_loader = val_loader, output_dir = args.output_dir, experiment_name = args.experiment_name, num_epochs = args.num_epochs, learning_rate = args.learning_rate, beta = args.beta, lambda_reg = args.lambda_reg, gradient_accumulation_steps = args.gradient_accumulation_steps, max_grad_norm = args.max_grad_norm, val_every_n_steps= args.val_every_n_steps, use_wandb = args.use_wandb, ) trainer.train() if __name__ == "__main__": main()