"""
DPO Trainer — aligns HazardHead alert timing via Direct Preference Optimization.

Architecture
------------
  Base:      SFTModel (VLM + LoRA + BeliefAggregator + HazardHead + TTAHead)
             loaded from SFT best checkpoint; VLM / TTAHead / BeliefAggregator FROZEN.

  Trainable: HazardHead only (~2 k params)

  Reference: frozen copy of the initial SFT HazardHead (for DPO implicit reward)

Loss
----
  L = L_DPO + lambda_reg * L_reg

  L_DPO = -log σ(β · [(log P_θ(alert|chosen)   - log P_ref(alert|chosen))
                    - (log P_θ(alert|rejected) - log P_ref(alert|rejected))])

  L_reg = 0.5 * [ BCE(logit_chosen, 1)      # keep detecting hazards in chosen windows
                + BCE(logit_rejected, 0) ]  # keep suppressing hazards in rejected windows

Checkpoint selection: val DPO accuracy
  = fraction of pairs where P_θ(alert|chosen) > P_θ(alert|rejected)

Usage
-----
  python -m training.DPO.trainer \
      --sft_checkpoint checkpoints/SFT/sft_v2/best \
      --pair_dir data/dpo_pairs \
      --output_dir checkpoints/DPO \
      --experiment_name dpo_v1
"""
|
|
from __future__ import annotations

import argparse
import copy
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

# wandb is optional: HAS_WANDB records whether the import succeeded so that
# experiment tracking can be switched off when the package is absent.
try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False

from .dataset import DPODataset, dpo_collate_fn

# Prepend the third ancestor of this file's path (presumably the repository
# root) to sys.path so the absolute `training.SFT` import below resolves.
# The order matters: the import must come after the sys.path insert.
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from training.SFT.trainer import SFTModel, load_sft_heads, _is_sft_ckpt_dir

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DPO.trainer")

# System message placed first in every chat conversation built for the VLM.
SYSTEM = "You are a driving safety AI analyzing dashcam footage for collision risk."
|
|
|
|
| |
| |
| |
|
|
| def _build_prompt(metadata: dict) -> str: |
| parts = [] |
| if metadata.get("weather"): parts.append(f"Weather: {metadata['weather']}") |
| if metadata.get("road_type"): parts.append(f"Road: {metadata['road_type']}") |
| if metadata.get("time_of_day"): parts.append(f"Time: {metadata['time_of_day']}") |
| ctx = ", ".join(parts) or "Urban driving" |
| return ( |
| f"Analyze this driving sequence.\n" |
| f"Context: {ctx}\n" |
| f"Estimate the time to potential collision. Output a single number in seconds." |
| ) |
|
|
|
|
| |
| |
| |
|
|
def compute_dpo_loss(
    logit_chosen: torch.Tensor,
    logit_rejected: torch.Tensor,
    ref_logit_chosen: torch.Tensor,
    ref_logit_rejected: torch.Tensor,
    beta: float = 0.1,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Standard DPO loss for a binary alert policy.

    The policy's log-probability of the "alert" action is
    ``log P(alert | x) = log sigmoid(logit)``. The implicit reward of a
    window is the policy log-probability minus the reference
    log-probability, and the loss is
    ``-log sigmoid(beta * (reward_chosen - reward_rejected))`` averaged
    over the batch.

    Args:
        logit_chosen: Policy logits for the chosen windows.
        logit_rejected: Policy logits for the rejected windows.
        ref_logit_chosen: Reference-model logits for the chosen windows.
        ref_logit_rejected: Reference-model logits for the rejected windows.
        beta: DPO temperature scaling the reward difference.

    Returns:
        ``(loss, metrics)`` where ``loss`` is a scalar tensor that carries
        gradient w.r.t. the policy logits only, and ``metrics`` is a dict of
        Python floats with keys ``dpo_loss``, ``dpo_acc``, ``prob_margin``,
        ``prob_chosen`` and ``prob_rejected``.
    """
    # Upcast once: the logits may arrive in reduced precision (e.g. from an
    # autocast region), and both loss and metrics should be float32.
    policy_chosen = logit_chosen.float()
    policy_rejected = logit_rejected.float()

    # log sigmoid(z) == -softplus(-z), the numerically stable form.
    log_pi_chosen = -F.softplus(-policy_chosen)
    log_pi_rejected = -F.softplus(-policy_rejected)

    # The reference terms are constants: never back-propagate through them.
    with torch.no_grad():
        log_ref_chosen = -F.softplus(-ref_logit_chosen.float())
        log_ref_rejected = -F.softplus(-ref_logit_rejected.float())

    reward_chosen = log_pi_chosen - log_ref_chosen
    reward_rejected = log_pi_rejected - log_ref_rejected

    loss = -F.logsigmoid(beta * (reward_chosen - reward_rejected)).mean()

    # Monitoring metrics: computed once, in float32, without building a graph.
    with torch.no_grad():
        prob_chosen = torch.sigmoid(policy_chosen)
        prob_rejected = torch.sigmoid(policy_rejected)
        acc = float((policy_chosen > policy_rejected).float().mean().item())
        margin = float((prob_chosen - prob_rejected).mean().item())
        mean_prob_chosen = float(prob_chosen.mean().item())
        mean_prob_rejected = float(prob_rejected.mean().item())

    return loss, {
        "dpo_loss": float(loss.detach()),
        "dpo_acc": acc,
        "prob_margin": margin,
        "prob_chosen": mean_prob_chosen,
        "prob_rejected": mean_prob_rejected,
    }
|
|
|
|
| |
| |
| |
|
|
class DPOModel(nn.Module):
    """
    Wraps SFTModel for DPO training.

    Only the hazard head is trainable; the VLM, belief aggregator and TTA
    head have ``requires_grad`` switched off. A frozen deep copy of the
    initial hazard head is kept as the DPO reference policy.
    """

    def __init__(
        self,
        sft_checkpoint_dir: str,
        use_bf16: bool = True,
    ):
        """
        Load the SFT model and heads from ``sft_checkpoint_dir``.

        Args:
            sft_checkpoint_dir: Directory accepted by ``_is_sft_ckpt_dir``;
                must contain ``config.json`` and a ``vlm_lora`` sub-directory.
            use_bf16: Forwarded to ``SFTModel``.

        Raises:
            RuntimeError: If the directory is not a valid SFT checkpoint.
        """
        super().__init__()
        ckpt = Path(sft_checkpoint_dir)
        if not _is_sft_ckpt_dir(ckpt):
            raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt}")

        with open(ckpt / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)

        model_name = cfg["model_name"]

        logger.info(f"Loading SFTModel from {ckpt} ...")
        self.sft = SFTModel(
            model_name=model_name,
            pretrained_lora_path=str(ckpt / "vlm_lora"),
            belief_strategy=cfg.get("belief_strategy", "mean_pool"),
            tta_intermediate_dim=cfg.get("tta_intermediate_dim", 512),
            use_lora=True,
            use_bf16=use_bf16,
            device="auto",
        )
        load_sft_heads(self.sft, ckpt)

        # Freeze everything except the hazard head.
        for frozen in (self.sft.vlm, self.sft.belief_aggregator, self.sft.tta_head):
            for param in frozen.parameters():
                param.requires_grad = False

        # Frozen snapshot of the SFT hazard head: the DPO reference policy.
        self.ref_hazard_head = copy.deepcopy(self.sft.hazard_head)
        for param in self.ref_hazard_head.parameters():
            param.requires_grad = False
        self.ref_hazard_head.to(self.sft.device)

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        logger.info(f"Trainable params: {trainable:,} / Total: {total:,}")

        self.processor = self.sft.processor
        self.hidden_dim = self.sft.hidden_dim
        # Remembered so save_checkpoint can copy the original config.json.
        self._sft_ckpt_dir = ckpt

    @property
    def device(self):
        """Device reported by the wrapped SFTModel."""
        return self.sft.device

    def train(self, mode: bool = True):
        """
        Set the training mode while keeping the frozen sub-modules in eval mode.

        ``nn.Module.train`` recurses into every child, which would also
        switch the frozen VLM, belief aggregator, TTA head and the reference
        hazard head to train mode. Any dropout or similar layers they may
        contain would then make the beliefs and the reference logits
        stochastic. Only the hazard head follows ``mode``.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        for frozen in (
            self.sft.vlm,
            self.sft.belief_aggregator,
            self.sft.tta_head,
            self.ref_hazard_head,
        ):
            frozen.eval()
        return self

    def _build_inputs(
        self,
        images: List[List],
        metadata: List[dict],
    ) -> dict:
        """
        Turn a batch of frame lists and metadata into processor inputs.

        One chat conversation is built per sample: a system turn with
        ``SYSTEM`` and a user turn holding one image placeholder per frame
        followed by the text from ``_build_prompt``.

        Args:
            images: One list of frames per sample.
            metadata: One metadata dict per sample, aligned with ``images``.

        Returns:
            Whatever the processor returns for the batch (called with
            ``return_tensors="pt"``, padding and truncation enabled).

        Raises:
            ValueError: If ``images`` and ``metadata`` differ in length.
        """
        if len(images) != len(metadata):
            raise ValueError(
                f"images and metadata must have the same length, "
                f"got {len(images)} and {len(metadata)}"
            )

        proc = self.processor
        # Some processors expose the chat template only on their tokenizer.
        apply_chat = (
            proc.apply_chat_template
            if hasattr(proc, "apply_chat_template")
            else proc.tokenizer.apply_chat_template
        )
        texts = []
        for frames, meta in zip(images, metadata):
            content = [{"type": "image"} for _ in range(len(frames))]
            content.append({"type": "text", "text": _build_prompt(meta)})
            msgs = [
                {"role": "system", "content": SYSTEM},
                {"role": "user", "content": content},
            ]
            texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False))
        return proc(text=texts, images=images,
                    return_tensors="pt", padding=True, truncation=True)

    def forward_pair(
        self,
        chosen_images: List[List],
        chosen_metadata: List[dict],
        rejected_images: List[List],
        rejected_metadata: List[dict],
        amp_dtype: torch.dtype = torch.bfloat16,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Score a batch of (chosen, rejected) windows with policy and reference.

        Args:
            chosen_images / chosen_metadata: Frames and metadata of the
                chosen windows.
            rejected_images / rejected_metadata: Frames and metadata of the
                rejected windows.
            amp_dtype: dtype used for the CUDA autocast regions.

        Returns:
            logit_chosen, logit_rejected,
            ref_logit_chosen, ref_logit_rejected  (all [B])
        """
        inputs_c = self._build_inputs(chosen_images, chosen_metadata)
        inputs_r = self._build_inputs(rejected_images, rejected_metadata)

        # The encoder is frozen, so no graph is needed for the beliefs.
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
                belief_c = self.sft.encode_observation(inputs_c)
                belief_r = self.sft.encode_observation(inputs_r)

        # Trainable policy head: the only part run with gradients enabled.
        with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
            logit_c = self.sft.hazard_head(belief_c)
            logit_r = self.sft.hazard_head(belief_r)

        # Frozen reference head on the same beliefs.
        with torch.no_grad():
            with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
                ref_c = self.ref_hazard_head(belief_c.detach())
                ref_r = self.ref_hazard_head(belief_r.detach())

        return logit_c, logit_r, ref_c, ref_r

    def save_checkpoint(self, save_dir: str, epoch: int = 0, step: int = 0):
        """
        Write a checkpoint laid out like the SFT checkpoint it was loaded from.

        Saves ``hazard_head.pt``, ``belief_aggregator.pt``, ``tta_head.pt``,
        the VLM adapter under ``vlm_lora/`` and a ``config.json`` copied from
        the source SFT checkpoint with ``epoch`` and ``step`` added.

        Args:
            save_dir: Target directory (created if missing).
            epoch: Epoch index recorded in ``config.json``.
            step: Global step recorded in ``config.json``.
        """
        out_dir = Path(save_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        # The DPO-trained head.
        torch.save(self.sft.hazard_head.state_dict(), out_dir / "hazard_head.pt")

        # Frozen components are saved too, alongside the trained head.
        self.sft.vlm.save_pretrained(out_dir / "vlm_lora")
        torch.save(self.sft.belief_aggregator.state_dict(), out_dir / "belief_aggregator.pt")
        torch.save(self.sft.tta_head.state_dict(), out_dir / "tta_head.pt")

        # Start from the SFT config and stamp the training progress on it.
        with open(self._sft_ckpt_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)
        cfg["epoch"] = epoch
        cfg["step"] = step
        with open(out_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)

        logger.info(f"✅ Checkpoint saved to {out_dir}")
|
|
|
|
| |
| |
| |
|
|
class DPOTrainer:
    """
    Training loop for DPOModel.

    Optimises ``L_DPO + lambda_reg * L_reg`` over the model's trainable
    parameters with AdamW, validates every ``val_every_n_steps`` micro-steps
    and at the end of each epoch, and keeps the checkpoint with the highest
    validation DPO accuracy under ``<output_dir>/<experiment_name>/best``.
    """

    def __init__(
        self,
        model: "DPOModel",
        train_loader: DataLoader,
        val_loader: DataLoader,
        output_dir: str,
        experiment_name: str = "dpo_v1",
        num_epochs: int = 5,
        learning_rate: float = 5e-5,
        beta: float = 0.1,
        lambda_reg: float = 0.5,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        val_every_n_steps: int = 500,
        use_wandb: bool = False,
    ):
        """
        Args:
            model: The DPO model; must provide ``forward_pair`` and
                ``save_checkpoint``.
            train_loader / val_loader: Loaders yielding dict batches with the
                keys ``chosen_images``, ``chosen_metadata``,
                ``rejected_images`` and ``rejected_metadata``.
            output_dir: Root directory; checkpoints go to
                ``output_dir / experiment_name``.
            experiment_name: Run name (also used as the wandb run name).
            num_epochs: Number of passes over ``train_loader``.
            learning_rate: AdamW learning rate.
            beta: DPO temperature.
            lambda_reg: Weight of the BCE regulariser.
            gradient_accumulation_steps: Micro-batches per optimizer step.
            max_grad_norm: Gradient-norm clipping threshold.
            val_every_n_steps: Validate every this many micro-steps; a value
                <= 0 disables mid-epoch validation.
            use_wandb: Log to wandb (ignored when wandb is not installed).

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, "
                f"got {gradient_accumulation_steps}"
            )

        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.output_dir = Path(output_dir)
        self.experiment_name = experiment_name
        self.num_epochs = num_epochs
        self.beta = beta
        self.lambda_reg = lambda_reg
        self.grad_accum = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.val_every = val_every_n_steps
        self.use_wandb = use_wandb and HAS_WANDB

        self.exp_dir = self.output_dir / experiment_name
        self.exp_dir.mkdir(parents=True, exist_ok=True)

        # Only parameters that still require grad are handed to the optimizer.
        self.optimizer = AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=learning_rate,
            weight_decay=0.01,
        )
        self.global_step = 0
        self.best_val_acc = float("-inf")

        if self.use_wandb:
            wandb.init(project="lkalert-dpo", name=experiment_name,
                       config={"beta": beta, "lambda_reg": lambda_reg,
                               "lr": learning_rate, "epochs": num_epochs})
        logger.info(f"✅ DPOTrainer ready exp={experiment_name} "
                    f"steps/epoch≈{len(train_loader)}")

    def train_step(self, batch: dict) -> dict:
        """
        Run forward and backward for one micro-batch (no optimizer step).

        The loss is divided by the accumulation factor before ``backward`` so
        that gradients summed over a full window match one large batch.

        Returns:
            The metrics from ``compute_dpo_loss`` plus ``reg_loss`` and the
            unscaled ``total_loss``, all as Python floats.
        """
        self.model.train()
        amp_dtype = torch.bfloat16

        logit_c, logit_r, ref_c, ref_r = self.model.forward_pair(
            batch["chosen_images"], batch["chosen_metadata"],
            batch["rejected_images"], batch["rejected_metadata"],
            amp_dtype=amp_dtype,
        )

        l_dpo, dpo_metrics = compute_dpo_loss(
            logit_c, logit_r, ref_c, ref_r, beta=self.beta
        )

        # Regulariser: push chosen logits towards 1 and rejected towards 0,
        # averaged over the two terms.
        chosen = logit_c.float()
        rejected = logit_r.float()
        l_reg = 0.5 * (
            F.binary_cross_entropy_with_logits(chosen, torch.ones_like(chosen))
            + F.binary_cross_entropy_with_logits(rejected, torch.zeros_like(rejected))
        )

        total = l_dpo + self.lambda_reg * l_reg
        (total / self.grad_accum).backward()

        return {**dpo_metrics,
                "reg_loss": float(l_reg.detach()),
                "total_loss": float(total.detach())}

    def _optimizer_step(self) -> None:
        """Clip the accumulated gradients, apply them and reset them."""
        nn.utils.clip_grad_norm_(
            [p for p in self.model.parameters() if p.requires_grad],
            self.max_grad_norm,
        )
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _save_if_best(self, val: dict, epoch: int) -> bool:
        """Save to ``best/`` when ``val_dpo_acc`` beats the best so far."""
        if not val["val_dpo_acc"] > self.best_val_acc:
            return False
        self.best_val_acc = val["val_dpo_acc"]
        self.model.save_checkpoint(
            str(self.exp_dir / "best"),
            epoch=epoch, step=self.global_step,
        )
        return True

    @torch.no_grad()
    def validate(self) -> dict:
        """
        Evaluate on the validation loader.

        Per-batch metrics are weighted by the number of pairs in the batch,
        so a short final batch does not skew the result and ``val_dpo_acc``
        is the fraction of all pairs ranked correctly.

        Returns:
            Dict with ``val_dpo_acc``, ``val_prob_margin``,
            ``val_prob_chosen`` and ``val_prob_rejected``. All values are NaN
            when the loader yields no pairs.
        """
        self.model.eval()
        amp_dtype = torch.bfloat16

        # output key -> key in the per-batch metrics from compute_dpo_loss
        metric_keys = {
            "val_dpo_acc": "dpo_acc",
            "val_prob_margin": "prob_margin",
            "val_prob_chosen": "prob_chosen",
            "val_prob_rejected": "prob_rejected",
        }
        sums = dict.fromkeys(metric_keys, 0.0)
        n_pairs = 0

        for batch in tqdm(self.val_loader, desc=" Val", ncols=70, leave=False):
            logit_c, logit_r, ref_c, ref_r = self.model.forward_pair(
                batch["chosen_images"], batch["chosen_metadata"],
                batch["rejected_images"], batch["rejected_metadata"],
                amp_dtype=amp_dtype,
            )
            _, m = compute_dpo_loss(logit_c, logit_r, ref_c, ref_r, beta=self.beta)
            batch_pairs = logit_c.numel()
            n_pairs += batch_pairs
            for out_key, src_key in metric_keys.items():
                sums[out_key] += m[src_key] * batch_pairs

        if n_pairs == 0:
            logger.warning("Validation loader produced no pairs; metrics are NaN")
            return {key: float("nan") for key in metric_keys}

        return {key: total / n_pairs for key, total in sums.items()}

    def train(self):
        """
        Run the full training loop.

        Gradient accumulation is counted per epoch, and any partial window
        left at the end of an epoch is applied rather than discarded. A
        checkpoint is written after every epoch, and ``best/`` is refreshed
        whenever validation accuracy improves.
        """
        logger.info("=" * 60)
        logger.info(f"Starting DPO training: {self.experiment_name}")
        logger.info("=" * 60)

        for epoch in range(self.num_epochs):
            self.optimizer.zero_grad()
            accum_metrics: Dict[str, List[float]] = {}
            # Micro-batches back-propagated since the last optimizer step.
            pending = 0

            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}",
                        ncols=80)

            for batch in pbar:
                metrics = self.train_step(batch)
                self.global_step += 1
                pending += 1

                for k, v in metrics.items():
                    accum_metrics.setdefault(k, []).append(v)

                if pending == self.grad_accum:
                    self._optimizer_step()
                    pending = 0

                pbar.set_postfix({
                    "dpo": f"{metrics.get('dpo_loss', 0):.3f}",
                    "acc": f"{metrics.get('dpo_acc', 0):.3f}",
                })

                # Periodic validation; validate() runs under no_grad, so any
                # partially accumulated gradients are left untouched.
                if self.val_every > 0 and self.global_step % self.val_every == 0:
                    val = self.validate()
                    avg = {k: float(np.mean(v)) for k, v in accum_metrics.items()}
                    logger.info(
                        f"Step {self.global_step:6d} | "
                        f"dpo_loss={avg.get('dpo_loss', 0):.3f} "
                        f"train_acc={avg.get('dpo_acc', 0):.3f} "
                        f"val_acc={val['val_dpo_acc']:.3f} "
                        f"margin={val['val_prob_margin']:.3f}"
                    )
                    if self.use_wandb:
                        wandb.log({**avg, **val, "step": self.global_step})

                    if self._save_if_best(val, epoch):
                        logger.info(f" ★ New best val_acc={self.best_val_acc:.4f}")

                    # Start a fresh averaging window for the next report.
                    accum_metrics = {}

            # Apply gradients from an incomplete accumulation window instead
            # of letting the next epoch's zero_grad() throw them away.
            if pending:
                self._optimizer_step()

            # End-of-epoch validation.
            val = self.validate()
            logger.info(
                f"Epoch {epoch+1} end | "
                f"val_acc={val['val_dpo_acc']:.3f} "
                f"margin={val['val_prob_margin']:.3f} "
                f"P(chosen)={val['val_prob_chosen']:.3f} "
                f"P(rejected)={val['val_prob_rejected']:.3f}"
            )

            # Always keep a per-epoch checkpoint.
            self.model.save_checkpoint(
                str(self.exp_dir / f"epoch_{epoch+1}"),
                epoch=epoch, step=self.global_step,
            )

            self._save_if_best(val, epoch)

        logger.info(f"Training complete. Best val_dpo_acc={self.best_val_acc:.4f}")
|
|
|
|
| |
| |
| |
|
|
def main():
    """
    Command-line entry point: parse arguments, build data and model, train.

    Reads the pair manifests ``nexar_{train,val}.json`` and
    ``dada_{train,val}.json`` from ``--pair_dir``, loads the SFT checkpoint
    given by ``--sft_checkpoint`` and runs ``DPOTrainer.train``.
    """
    parser = argparse.ArgumentParser("DPO trainer")
    parser.add_argument("--sft_checkpoint", required=True,
                        help="Path to SFT best checkpoint dir")
    parser.add_argument("--pair_dir", default="data/dpo_pairs")
    parser.add_argument("--output_dir", default="checkpoints/DPO")
    parser.add_argument("--experiment_name", default="dpo_v1")
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--beta", type=float, default=0.1,
                        help="DPO temperature β")
    parser.add_argument("--lambda_reg", type=float, default=0.5,
                        help="SFT regularisation weight")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--val_every_n_steps", type=int, default=500)
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug_samples", type=int, default=64)
    args = parser.parse_args()

    pair_dir = Path(args.pair_dir)

    train_manifests = [
        pair_dir / "nexar_train.json",
        pair_dir / "dada_train.json",
    ]
    val_manifests = [
        pair_dir / "nexar_val.json",
        pair_dir / "dada_val.json",
    ]

    # In debug mode the validation set gets a quarter of the training budget.
    train_ds = DPODataset(train_manifests, split="train",
                          debug=args.debug, debug_samples=args.debug_samples)
    val_ds = DPODataset(val_manifests, split="val",
                        debug=args.debug, debug_samples=args.debug_samples // 4)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)

    model = DPOModel(sft_checkpoint_dir=args.sft_checkpoint, use_bf16=True)

    trainer = DPOTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        beta=args.beta,
        lambda_reg=args.lambda_reg,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        val_every_n_steps=args.val_every_n_steps,
        use_wandb=args.use_wandb,
    )
    trainer.train()
|
|
|
|
# Run the CLI only when this module is executed as a script (e.g. `python -m`).
if __name__ == "__main__":
    main()
|
|