VLAlert / training /DPO /trainer.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
23.4 kB
#!/usr/bin/env python3
"""
DPO Trainer — aligns HazardHead alert timing via Direct Preference Optimization.

Architecture
------------
Base:      SFTModel (VLM + LoRA + BeliefAggregator + HazardHead + TTAHead)
           loaded from SFT best checkpoint; VLM / TTAHead / BeliefAggregator FROZEN.
Trainable: HazardHead only (~2 k params)
Reference: frozen copy of the initial SFT HazardHead (for DPO implicit reward)

Loss
----
L     = L_DPO + lambda_reg * L_reg
L_DPO = -log σ(β · [(log P_θ(alert|chosen)   - log P_ref(alert|chosen))
                  - (log P_θ(alert|rejected) - log P_ref(alert|rejected))])
L_reg = 0.5 * (BCE(logit_chosen, 1)      # keep detecting hazards in chosen windows
             + BCE(logit_rejected, 0))   # keep suppressing hazards in rejected windows

Checkpoint selection: val DPO accuracy
    = fraction of pairs where P_θ(alert|chosen) > P_θ(alert|rejected)

Usage
-----
    python -m training.DPO.trainer \
        --sft_checkpoint checkpoints/SFT/sft_v2/best \
        --pair_dir data/dpo_pairs \
        --output_dir checkpoints/DPO \
        --experiment_name dpo_v1
"""
from __future__ import annotations
import argparse
import copy
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
from .dataset import DPODataset, dpo_collate_fn
# Import SFT infrastructure
import sys
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
from training.SFT.trainer import SFTModel, load_sft_heads, _is_sft_ckpt_dir
# Configure the root logger at import time (INFO level, timestamped lines) and
# create the module-level logger used throughout this file.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DPO.trainer")
# System-role message placed first in every chat prompt built by
# DPOModel._build_inputs.
SYSTEM = "You are a driving safety AI analyzing dashcam footage for collision risk."
# ─────────────────────────────────────────────────────────────────────────────
# Prompt builder (identical to SFT evaluate.py)
# ─────────────────────────────────────────────────────────────────────────────
def _build_prompt(metadata: dict) -> str:
parts = []
if metadata.get("weather"): parts.append(f"Weather: {metadata['weather']}")
if metadata.get("road_type"): parts.append(f"Road: {metadata['road_type']}")
if metadata.get("time_of_day"): parts.append(f"Time: {metadata['time_of_day']}")
ctx = ", ".join(parts) or "Urban driving"
return (
f"Analyze this driving sequence.\n"
f"Context: {ctx}\n"
f"Estimate the time to potential collision. Output a single number in seconds."
)
# ─────────────────────────────────────────────────────────────────────────────
# DPO loss
# ─────────────────────────────────────────────────────────────────────────────
def compute_dpo_loss(
    logit_chosen: torch.Tensor,        # [B] policy logit for chosen window
    logit_rejected: torch.Tensor,      # [B] policy logit for rejected window
    ref_logit_chosen: torch.Tensor,    # [B] reference logit (frozen)
    ref_logit_rejected: torch.Tensor,  # [B] reference logit (frozen)
    beta: float = 0.1,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Standard DPO loss for a binary alert policy.

    The log-probability of the "alert" action is ``log σ(logit)``, evaluated
    in float32 as ``-softplus(-logit)``.  The implicit reward of a window is
    the policy log-probability minus the reference log-probability, and the
    loss is ``-log σ(β · (reward_chosen - reward_rejected))`` averaged over
    the batch.

    Args:
        logit_chosen: Policy logits for the chosen windows.
        logit_rejected: Policy logits for the rejected windows.
        ref_logit_chosen: Reference logits for the chosen windows.
        ref_logit_rejected: Reference logits for the rejected windows.
        beta: Scale applied to the reward difference before the log-sigmoid.

    Returns:
        ``(loss, metrics)`` where ``loss`` is a scalar tensor and ``metrics``
        holds the Python floats ``dpo_loss``, ``dpo_acc`` (share of pairs with
        chosen logit strictly above rejected logit), ``prob_margin``,
        ``prob_chosen`` and ``prob_rejected``.
    """

    def _log_p_alert(logits: torch.Tensor) -> torch.Tensor:
        # log σ(x) written as -softplus(-x), computed in float32.
        return -F.softplus(-logits.float())

    # log π_θ(alert | ·) — gradients flow through these.
    policy_chosen = _log_p_alert(logit_chosen)
    policy_rejected = _log_p_alert(logit_rejected)

    # log π_ref(alert | ·) — never differentiated.
    with torch.no_grad():
        ref_chosen = _log_p_alert(ref_logit_chosen)
        ref_rejected = _log_p_alert(ref_logit_rejected)

    # Difference of implicit rewards between the chosen and rejected windows.
    preference_gap = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    loss = -F.logsigmoid(beta * preference_gap).mean()

    # ── metrics ──────────────────────────────────────────────────────────────
    with torch.no_grad():
        p_chosen = torch.sigmoid(logit_chosen)
        p_rejected = torch.sigmoid(logit_rejected)
        stats = {
            "dpo_loss": float(loss.detach()),
            "dpo_acc": float((logit_chosen > logit_rejected).float().mean().item()),
            "prob_margin": float((p_chosen - p_rejected).mean().item()),
            "prob_chosen": float(p_chosen.mean().detach()),
            "prob_rejected": float(p_rejected.mean().detach()),
        }
    return loss, stats
# ─────────────────────────────────────────────────────────────────────────────
# DPO Model wrapper
# ─────────────────────────────────────────────────────────────────────────────
class DPOModel(nn.Module):
    """
    Wraps SFTModel for DPO training.

    The VLM, BeliefAggregator and TTAHead loaded from the SFT checkpoint are
    frozen; the HazardHead is left trainable.  A deep copy of the HazardHead,
    taken right after loading, is kept as the frozen DPO reference policy and
    is pinned in eval mode so that its outputs stay deterministic even when
    the wrapper is switched to train mode.
    """

    def __init__(
        self,
        sft_checkpoint_dir: str,
        use_bf16: bool = True,
    ):
        """
        Load the SFT model and heads, freeze everything but the HazardHead.

        Args:
            sft_checkpoint_dir: Directory accepted by ``_is_sft_ckpt_dir``;
                must contain ``config.json`` and a ``vlm_lora`` sub-directory.
            use_bf16: Forwarded to ``SFTModel``.

        Raises:
            RuntimeError: If ``sft_checkpoint_dir`` is not a valid SFT checkpoint.
        """
        super().__init__()
        ckpt = Path(sft_checkpoint_dir)
        if not _is_sft_ckpt_dir(ckpt):
            raise RuntimeError(f"Not a valid SFT checkpoint: {ckpt}")
        with open(ckpt / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)
        model_name = cfg["model_name"]
        logger.info(f"Loading SFTModel from {ckpt} ...")
        self.sft = SFTModel(
            model_name=model_name,
            pretrained_lora_path=str(ckpt / "vlm_lora"),
            belief_strategy=cfg.get("belief_strategy", "mean_pool"),
            tta_intermediate_dim=cfg.get("tta_intermediate_dim", 512),
            use_lora=True,
            use_bf16=use_bf16,
            device="auto",
        )
        load_sft_heads(self.sft, ckpt)

        # ── freeze everything except HazardHead ──────────────────────────────
        for frozen_module in (
            self.sft.vlm,
            self.sft.belief_aggregator,
            self.sft.tta_head,
        ):
            for param in frozen_module.parameters():
                param.requires_grad = False
        # HazardHead remains trainable

        # ── frozen reference copy of HazardHead ──────────────────────────────
        self.ref_hazard_head = copy.deepcopy(self.sft.hazard_head)
        for param in self.ref_hazard_head.parameters():
            param.requires_grad = False
        self.ref_hazard_head.to(self.sft.device)
        # The reference defines the DPO baseline; it must never run in train mode.
        self.ref_hazard_head.eval()

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        logger.info(f"Trainable params: {trainable:,} / Total: {total:,}")

        self.processor = self.sft.processor
        self.hidden_dim = self.sft.hidden_dim
        self._sft_ckpt_dir = ckpt  # kept for save_checkpoint

    @property
    def device(self):
        """Device reported by the wrapped SFT model."""
        return self.sft.device

    def train(self, mode: bool = True):
        """
        Set train/eval mode while keeping the reference head in eval mode.

        ``nn.Module.train`` propagates the mode to every child, which would
        also switch the frozen reference HazardHead to train mode.  Any
        mode-dependent layer inside it (e.g. dropout) would then make the
        reference log-probabilities, and hence the DPO loss, noisy.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        ref_head = getattr(self, "ref_hazard_head", None)
        if ref_head is not None:
            ref_head.eval()
        return self

    def _build_inputs(
        self,
        images: List[List],    # [B, n_frames]
        metadata: List[dict],
    ) -> dict:
        """
        Render chat-template prompts and run the processor on a batch.

        Each sample gets one image placeholder per frame followed by the text
        from ``_build_prompt``; the system turn is ``SYSTEM``.

        Args:
            images: Per-sample lists of frames.
            metadata: Per-sample metadata dicts, aligned with ``images``.

        Returns:
            Whatever the processor returns for ``return_tensors="pt"`` with
            padding and truncation enabled.
        """
        proc = self.processor
        # Some processors expose the chat template only through their tokenizer.
        apply_chat = (
            proc.apply_chat_template
            if hasattr(proc, "apply_chat_template")
            else proc.tokenizer.apply_chat_template
        )
        texts = []
        for idx, frames in enumerate(images):
            content = [{"type": "image"} for _ in frames]
            content.append({"type": "text", "text": _build_prompt(metadata[idx])})
            msgs = [
                {"role": "system", "content": SYSTEM},
                {"role": "user", "content": content},
            ]
            texts.append(apply_chat(msgs, tokenize=False, add_generation_prompt=False))
        return proc(text=texts, images=images,
                    return_tensors="pt", padding=True, truncation=True)

    def forward_pair(
        self,
        chosen_images: List[List],
        chosen_metadata: List[dict],
        rejected_images: List[List],
        rejected_metadata: List[dict],
        amp_dtype: torch.dtype = torch.bfloat16,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Score a batch of (chosen, rejected) windows with policy and reference heads.

        Args:
            chosen_images: Frames of the chosen windows.
            chosen_metadata: Metadata of the chosen windows.
            rejected_images: Frames of the rejected windows.
            rejected_metadata: Metadata of the rejected windows.
            amp_dtype: Autocast dtype used for every forward pass.

        Returns:
            logit_chosen, logit_rejected,
            ref_logit_chosen, ref_logit_rejected  (all [B])
        """
        inputs_c = self._build_inputs(chosen_images, chosen_metadata)
        inputs_r = self._build_inputs(rejected_images, rejected_metadata)

        # VLM is frozen → run in no_grad to save peak memory
        with torch.no_grad(), autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
            belief_c = self.sft.encode_observation(inputs_c)
            belief_r = self.sft.encode_observation(inputs_r)

        # HazardHead forward (trainable)
        with autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
            logit_c = self.sft.hazard_head(belief_c)
            logit_r = self.sft.hazard_head(belief_r)

        # Reference head (frozen)
        with torch.no_grad(), autocast(device_type="cuda", dtype=amp_dtype, enabled=True):
            ref_c = self.ref_hazard_head(belief_c.detach())
            ref_r = self.ref_hazard_head(belief_r.detach())

        return logit_c, logit_r, ref_c, ref_r

    def save_checkpoint(self, save_dir: str, epoch: int = 0, step: int = 0):
        """
        Write a complete, SFT-layout checkpoint to ``save_dir``.

        Saves the current HazardHead, the LoRA adapter via ``save_pretrained``,
        the BeliefAggregator and TTAHead state dicts, and a copy of the source
        SFT ``config.json`` with ``epoch`` and ``step`` overwritten.

        Args:
            save_dir: Target directory; created if missing.
            epoch: Epoch index recorded in ``config.json``.
            step: Step counter recorded in ``config.json``.
        """
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save updated HazardHead
        torch.save(self.sft.hazard_head.state_dict(), save_dir / "hazard_head.pt")

        # Also save LoRA (unchanged) and other SFT heads for a complete loadable checkpoint
        lora_dir = save_dir / "vlm_lora"
        self.sft.vlm.save_pretrained(lora_dir)
        torch.save(self.sft.belief_aggregator.state_dict(), save_dir / "belief_aggregator.pt")
        torch.save(self.sft.tta_head.state_dict(), save_dir / "tta_head.pt")

        # Copy SFT config + update epoch/step
        with open(self._sft_ckpt_dir / "config.json", encoding="utf-8") as f:
            cfg = json.load(f)
        cfg["epoch"] = epoch
        cfg["step"] = step
        with open(save_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)
        logger.info(f"✅ Checkpoint saved to {save_dir}")
# ─────────────────────────────────────────────────────────────────────────────
# DPO Trainer
# ─────────────────────────────────────────────────────────────────────────────
class DPOTrainer:
    """
    Training loop for :class:`DPOModel`: DPO loss plus BCE regularisation.

    ``global_step`` counts micro-batches (calls to ``train_step``).  Gradient
    accumulation windows are aligned to each epoch and the final, possibly
    shorter, window of an epoch is flushed so that no gradients are dropped.
    The checkpoint with the highest validation DPO accuracy is kept in
    ``<output_dir>/<experiment_name>/best``.
    """

    def __init__(
        self,
        model: DPOModel,
        train_loader: DataLoader,
        val_loader: DataLoader,
        output_dir: str,
        experiment_name: str = "dpo_v1",
        num_epochs: int = 5,
        learning_rate: float = 5e-5,
        beta: float = 0.1,
        lambda_reg: float = 0.5,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        val_every_n_steps: int = 500,
        use_wandb: bool = False,
    ):
        """
        Args:
            model: Model exposing ``forward_pair`` and ``save_checkpoint``.
            train_loader: Loader yielding dicts with ``chosen_images``,
                ``chosen_metadata``, ``rejected_images``, ``rejected_metadata``.
            val_loader: Loader with the same batch layout, used by ``validate``.
            output_dir: Root directory; the experiment folder is created inside.
            experiment_name: Sub-directory name and wandb run name.
            num_epochs: Number of passes over ``train_loader``.
            learning_rate: AdamW learning rate.
            beta: DPO scale passed to ``compute_dpo_loss``.
            lambda_reg: Weight of the BCE regularisation term.
            gradient_accumulation_steps: Micro-batches per optimiser update.
            max_grad_norm: Gradient-norm clipping threshold.
            val_every_n_steps: Validate every this many micro-batches; a value
                ``<= 0`` disables periodic validation (epoch-end validation
                still runs).
            use_wandb: Log to Weights & Biases when the package is importable.

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError(
                "gradient_accumulation_steps must be >= 1, "
                f"got {gradient_accumulation_steps}"
            )
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.output_dir = Path(output_dir)
        self.experiment_name = experiment_name
        self.num_epochs = num_epochs
        self.beta = beta
        self.lambda_reg = lambda_reg
        self.grad_accum = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.val_every = val_every_n_steps
        self.use_wandb = use_wandb and HAS_WANDB
        self.exp_dir = self.output_dir / experiment_name
        self.exp_dir.mkdir(parents=True, exist_ok=True)

        # Only optimise HazardHead (the parameters left with requires_grad=True)
        self.optimizer = AdamW(
            [p for p in model.parameters() if p.requires_grad],
            lr=learning_rate,
            weight_decay=0.01,
        )
        self.global_step = 0
        self.best_val_acc = float("-inf")

        if self.use_wandb:
            wandb.init(project="lkalert-dpo", name=experiment_name,
                       config={"beta": beta, "lambda_reg": lambda_reg,
                               "lr": learning_rate, "epochs": num_epochs})
        logger.info(f"✅ DPOTrainer ready exp={experiment_name} "
                    f"steps/epoch≈{len(train_loader)}")

    # ── single training step ──────────────────────────────────────────────────
    def train_step(self, batch: dict) -> dict:
        """
        Run forward and backward on one micro-batch (no optimiser update).

        The loss is ``L_DPO + lambda_reg * L_reg`` divided by the accumulation
        factor, where ``L_reg`` is the mean of BCE(chosen, 1) and
        BCE(rejected, 0).

        Returns:
            The metrics from ``compute_dpo_loss`` plus ``reg_loss`` and the
            unscaled ``total_loss``.
        """
        self.model.train()
        amp_dtype = torch.bfloat16
        logit_c, logit_r, ref_c, ref_r = self.model.forward_pair(
            batch["chosen_images"], batch["chosen_metadata"],
            batch["rejected_images"], batch["rejected_metadata"],
            amp_dtype=amp_dtype,
        )
        # DPO loss
        l_dpo, dpo_metrics = compute_dpo_loss(
            logit_c, logit_r, ref_c, ref_r, beta=self.beta
        )
        # Regularisation: BCE on chosen (should be 1) and rejected (should be 0)
        chosen_logits = logit_c.float()
        rejected_logits = logit_r.float()
        l_reg = 0.5 * (
            F.binary_cross_entropy_with_logits(
                chosen_logits, torch.ones_like(chosen_logits))
            + F.binary_cross_entropy_with_logits(
                rejected_logits, torch.zeros_like(rejected_logits))
        )
        total = l_dpo + self.lambda_reg * l_reg
        # Scale so that gradients summed over the window average the micro-batches.
        (total / self.grad_accum).backward()
        return {**dpo_metrics,
                "reg_loss": float(l_reg.detach()),
                "total_loss": float(total.detach())}

    # ── validation loop ───────────────────────────────────────────────────────
    @torch.no_grad()
    def validate(self) -> dict:
        """
        Evaluate on ``val_loader``.

        Per-batch metrics are weighted by the number of pairs in the batch, so
        the result is a per-pair average even when the last batch is shorter.
        If the loader yields no pairs, a warning is logged and every metric is
        reported as 0.0.

        Returns:
            Dict with ``val_dpo_acc``, ``val_prob_margin``, ``val_prob_chosen``
            and ``val_prob_rejected``.
        """
        self.model.eval()
        amp_dtype = torch.bfloat16
        sums = {"dpo_acc": 0.0, "prob_margin": 0.0,
                "prob_chosen": 0.0, "prob_rejected": 0.0}
        n_pairs = 0
        for batch in tqdm(self.val_loader, desc=" Val", ncols=70, leave=False):
            logit_c, logit_r, ref_c, ref_r = self.model.forward_pair(
                batch["chosen_images"], batch["chosen_metadata"],
                batch["rejected_images"], batch["rejected_metadata"],
                amp_dtype=amp_dtype,
            )
            _, m = compute_dpo_loss(logit_c, logit_r, ref_c, ref_r, beta=self.beta)
            batch_pairs = logit_c.numel()
            n_pairs += batch_pairs
            for key in sums:
                sums[key] += m[key] * batch_pairs

        if n_pairs == 0:
            logger.warning("Validation loader yielded no pairs; reporting zeros.")
        denom = max(n_pairs, 1)
        return {
            "val_dpo_acc": float(sums["dpo_acc"] / denom),
            "val_prob_margin": float(sums["prob_margin"] / denom),
            "val_prob_chosen": float(sums["prob_chosen"] / denom),
            "val_prob_rejected": float(sums["prob_rejected"] / denom),
        }

    def _optimizer_step(self) -> None:
        """Clip gradients of the trainable parameters, update, and reset grads."""
        nn.utils.clip_grad_norm_(
            [p for p in self.model.parameters() if p.requires_grad],
            self.max_grad_norm,
        )
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _maybe_save_best(self, val: dict, epoch: int) -> bool:
        """Save the ``best`` checkpoint if ``val`` improves on the best accuracy."""
        if not val["val_dpo_acc"] > self.best_val_acc:
            return False
        self.best_val_acc = val["val_dpo_acc"]
        self.model.save_checkpoint(
            str(self.exp_dir / "best"),
            epoch=epoch, step=self.global_step,
        )
        logger.info(f" ✅ New best val_acc={self.best_val_acc:.4f}")
        return True

    # ── main training loop ────────────────────────────────────────────────────
    def train(self):
        """
        Run the full training schedule.

        Each epoch: iterate micro-batches, update the optimiser every
        ``grad_accum`` micro-batches (and once more on the last batch of the
        epoch if a partial window is pending), validate periodically, then
        validate at epoch end and save an ``epoch_<n>`` checkpoint.
        """
        logger.info("=" * 60)
        logger.info(f"Starting DPO training: {self.experiment_name}")
        logger.info("=" * 60)
        for epoch in range(self.num_epochs):
            self.optimizer.zero_grad()
            accum_metrics: Dict[str, List[float]] = {}
            n_batches = len(self.train_loader)
            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}",
                        ncols=80)
            for step_in_epoch, batch in enumerate(pbar):
                metrics = self.train_step(batch)
                self.global_step += 1
                for k, v in metrics.items():
                    accum_metrics.setdefault(k, []).append(v)

                # Optimiser update: windows are counted within the epoch, and the
                # last batch always flushes so leftover gradients are not discarded
                # by the zero_grad() at the start of the next epoch.
                micro_step = step_in_epoch + 1
                if micro_step % self.grad_accum == 0 or micro_step == n_batches:
                    self._optimizer_step()

                pbar.set_postfix({
                    "dpo": f"{metrics.get('dpo_loss', 0):.3f}",
                    "acc": f"{metrics.get('dpo_acc', 0):.3f}",
                })

                # Periodic validation (disabled when val_every <= 0)
                if self.val_every > 0 and self.global_step % self.val_every == 0:
                    val = self.validate()
                    avg = {k: float(np.mean(v)) for k, v in accum_metrics.items()}
                    logger.info(
                        f"Step {self.global_step:6d} | "
                        f"dpo_loss={avg.get('dpo_loss', 0):.3f} "
                        f"train_acc={avg.get('dpo_acc', 0):.3f} "
                        f"val_acc={val['val_dpo_acc']:.3f} "
                        f"margin={val['val_prob_margin']:.3f}"
                    )
                    if self.use_wandb:
                        wandb.log({**avg, **val, "step": self.global_step})
                    self._maybe_save_best(val, epoch)
                    accum_metrics = {}

            # Epoch-end validation
            val = self.validate()
            logger.info(
                f"Epoch {epoch+1} end | "
                f"val_acc={val['val_dpo_acc']:.3f} "
                f"margin={val['val_prob_margin']:.3f} "
                f"P(chosen)={val['val_prob_chosen']:.3f} "
                f"P(rejected)={val['val_prob_rejected']:.3f}"
            )
            # Save epoch checkpoint
            self.model.save_checkpoint(
                str(self.exp_dir / f"epoch_{epoch+1}"),
                epoch=epoch, step=self.global_step,
            )
            self._maybe_save_best(val, epoch)
        logger.info(f"Training complete. Best val_dpo_acc={self.best_val_acc:.4f}")
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """
    Command-line entry point: parse arguments, build data loaders, and train.

    Train and validation pairs are read from ``nexar_{split}.json`` and
    ``dada_{split}.json`` inside ``--pair_dir``.  In ``--debug`` mode the
    validation set is limited to a quarter of ``--debug_samples``, but never
    fewer than one sample so that validation always has data to score.
    """
    parser = argparse.ArgumentParser("DPO trainer")
    parser.add_argument("--sft_checkpoint", required=True,
                        help="Path to SFT best checkpoint dir")
    parser.add_argument("--pair_dir", default="data/dpo_pairs")
    parser.add_argument("--output_dir", default="checkpoints/DPO")
    parser.add_argument("--experiment_name", default="dpo_v1")
    parser.add_argument("--num_epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--beta", type=float, default=0.1,
                        help="DPO temperature β")
    parser.add_argument("--lambda_reg", type=float, default=0.5,
                        help="SFT regularisation weight")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--val_every_n_steps", type=int, default=500)
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug_samples", type=int, default=64)
    args = parser.parse_args()

    pair_dir = Path(args.pair_dir)
    train_manifests = [
        pair_dir / "nexar_train.json",
        pair_dir / "dada_train.json",
    ]
    val_manifests = [
        pair_dir / "nexar_val.json",
        pair_dir / "dada_val.json",
    ]

    # Integer division would give 0 validation samples for debug_samples < 4.
    val_debug_samples = max(1, args.debug_samples // 4)

    train_ds = DPODataset(train_manifests, split="train",
                          debug=args.debug, debug_samples=args.debug_samples)
    val_ds = DPODataset(val_manifests, split="val",
                        debug=args.debug, debug_samples=val_debug_samples)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            collate_fn=dpo_collate_fn, num_workers=4, pin_memory=True)

    model = DPOModel(sft_checkpoint_dir=args.sft_checkpoint, use_bf16=True)
    trainer = DPOTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        output_dir=args.output_dir,
        experiment_name=args.experiment_name,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        beta=args.beta,
        lambda_reg=args.lambda_reg,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_grad_norm=args.max_grad_norm,
        val_every_n_steps=args.val_every_n_steps,
        use_wandb=args.use_wandb,
    )
    trainer.train()


if __name__ == "__main__":
    main()