# Codeseys's picture
# Wave 15: 4-angle multi-model self-critique caught 2 math BLOCKERs in primary loss kernels; fixed against upstream byte-for-byte + GSM8K example + ergonomics
# e5add15
"""compose_loss.py — free 3-channel loss composer for verification smokes.
This is a verification-harness mirror of `ComposerReplicationTrainer._compute_loss`
that does NOT depend on TRL's GRPOTrainer parent. The GRPO channel is replaced
with standard LM next-token-prediction cross-entropy, which is the limit GRPO
converges to under deterministic rewards.
Use it for:
- CPU smokes on real HF models (Spike 006)
- Unit tests of loss composition without spinning up TRL
- Anywhere we want to verify gradient flow through the 3-channel sum
without paying TRL's full machinery cost
Do NOT use it as the production training loss. Production = ComposerReplicationTrainer
(a real GRPOTrainer subclass) which uses TRL's reward + advantage estimation.
Total loss:
total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo
Channels:
- lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub)
- sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits
- trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs
ADR-007 extensions
------------------
Three pluggable distillation losses can swap the default DPO/SDPO channels:
- ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with
margin) instead of standard DPO. Reference logprobs are no longer required.
- ``sdpo_wrapper="taid"`` — channel 2 replaces SDPO with TAID (Temporally
Adaptive Interpolated Distillation, SakanaAI port). Requires ``taid_t``
(the current interpolation coefficient in ``[0, 1]``). The schedule that
produces ``taid_t`` is the trainer's responsibility — typically a
:class:`composer_replication.distillation.taid.TAIDScheduler` instance
driven by the per-step distillation loss.
- ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a
per-token gated forward/reverse KL.
All three default to off; calling with the new kwargs at their defaults is
bit-for-bit equivalent to the legacy 3-channel composition.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
import torch
import torch.nn.functional as F
from composer_replication.opsd import generalized_jsd_loss
@dataclass
class LossComponents:
    """Container for each loss channel plus their weighted sum.

    Tensors are kept attached to the autograd graph. Use :meth:`detached` to
    get plain Python floats for logging or ablation tables.
    """

    lm_ce: torch.Tensor  # channel 1: response-token cross-entropy
    sdpo_jsd: torch.Tensor  # channel 2: SDPO / TAID / Entropy-Aware OPD term
    trace_replay_dpo: torch.Tensor  # channel 3: DPO or SimPO term
    total: torch.Tensor  # weighted sum of the three channels

    def detached(self) -> dict[str, float]:
        """Return ``{channel_name: float}`` with gradients stripped.

        Keys are emitted in declaration order: lm_ce, sdpo_jsd,
        trace_replay_dpo, total.
        """
        channel_names = ("lm_ce", "sdpo_jsd", "trace_replay_dpo", "total")
        return {
            name: float(getattr(self, name).detach())
            for name in channel_names
        }
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    # ADR-007 extensions ------------------------------------------------
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    # SimPO knobs (only used when dpo_variant="simpo") ------------------
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    # Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
    entropy_opd_h_max: float | None = None,
) -> LossComponents:
    """Compute ``total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo``.

    Every channel uses the same ``model``. It is called as
    ``model(input_ids=...)`` and its return value must expose ``.logits``.
    A channel that is switched off contributes a differentiable scalar zero
    created on the device of the model's first parameter.

    Required keys in ``inputs``:
        - input_ids: (B, T_s) student rollout
        - response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere

    Optional keys and when each channel runs:
        - Channel 2 runs only when ``alpha_sdpo > 0`` and
          ``ctx_teacher_input_ids`` is present and non-empty.
        - Channel 3 runs only when ``beta_replay > 0`` and
          ``dpo_chosen_input_ids`` is present and non-empty.
        - A weight <= 0 disables its channel.
        - Once channel 3 is active, the other keys listed for the selected
          variant are indexed directly, so a missing one raises ``KeyError``.

        SDPO / TAID / Entropy-Aware OPD (channel 2):
            - ctx_teacher_input_ids: (B, T_t) hint-conditioned context. The
              channel stays zero unless the teacher logits have exactly the
              same shape as the student logits, so in practice T_t == T_s.
            - sdpo_loss_mask (optional): 1 at error-turn tokens.
        DPO (dpo_variant="dpo"):
            - dpo_chosen_input_ids, dpo_chosen_response_mask
            - dpo_rejected_input_ids, dpo_rejected_response_mask
            - dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs
              (precomputed). They are subtracted from per-sequence summed
              policy logprobs of shape (B,), so presumably also (B,) and on
              the same device. TODO confirm with the data collator.
        SimPO (dpo_variant="simpo"):
            - dpo_chosen_input_ids, dpo_chosen_response_mask
            - dpo_rejected_input_ids, dpo_rejected_response_mask
            (reference logprobs are never read)

    Args:
        model: Causal LM whose output has a ``.logits`` attribute.
        inputs: Batch dict; see the key lists above.
        alpha_sdpo: Weight of channel 2 in ``total``.
        beta_replay: Weight of channel 3 in ``total``.
        sdpo_jsd_beta: Forwarded as ``beta`` to ``generalized_jsd_loss``.
        sdpo_temperature: Forwarded as ``temperature`` to
            ``generalized_jsd_loss`` and to ``entropy_aware_opd_loss``.
        sdpo_token_clip: Forwarded as ``token_clip`` to
            ``generalized_jsd_loss``.
        replay_dpo_beta: Scale applied to the DPO log-ratio margin.
        lm_ce_label_smoothing: ``label_smoothing`` for the channel-1
            cross-entropy.
        dpo_variant: ``"dpo"`` (reference-based) or ``"simpo"``
            (reference-free, length-averaged logprobs).
        sdpo_wrapper: ``"none"`` (generalized JSD), ``"taid"``, or
            ``"entropy_opd"``.
        taid_t: Interpolation coefficient in [0, 1]; required when
            ``sdpo_wrapper="taid"``. The trainer owns the schedule, e.g. a
            TAIDScheduler from composer_replication.distillation.taid or
            any custom schedule.
        simpo_beta: Forwarded as ``beta`` to ``simpo_loss``.
        simpo_gamma: Forwarded as ``gamma`` to ``simpo_loss``.
        entropy_opd_h_max: Forwarded as ``h_max`` to
            ``entropy_aware_opd_loss``. The meaning of ``None`` is defined
            by that function.

    Returns:
        LossComponents holding the three channel losses (zero for disabled
        channels) and their weighted total.

    Raises:
        ValueError: Unknown ``dpo_variant`` or ``sdpo_wrapper``; or
            ``sdpo_wrapper="taid"`` with ``taid_t`` missing or outside
            [0, 1]. This check runs even when channel 2 is disabled.
        KeyError: A required key in ``inputs`` is missing.
    """
    # --- argument validation (runs before any forward pass) -------------
    if dpo_variant not in ("dpo", "simpo"):
        raise ValueError(
            f"dpo_variant must be 'dpo' or 'simpo', got {dpo_variant!r}"
        )
    if sdpo_wrapper not in ("none", "taid", "entropy_opd"):
        raise ValueError(
            f"sdpo_wrapper must be 'none', 'taid', or 'entropy_opd', "
            f"got {sdpo_wrapper!r}"
        )
    if sdpo_wrapper == "taid":
        if taid_t is None:
            raise ValueError(
                "sdpo_wrapper='taid' requires taid_t (float in [0, 1]). "
                "Drive it from a TAIDScheduler or pass a fixed value."
            )
        # NaN fails both comparisons, so it is rejected here too.
        if not (0.0 <= float(taid_t) <= 1.0):
            raise ValueError(
                f"taid_t must be in [0, 1], got {taid_t}"
            )
    # Device on which placeholder zeros for disabled channels are created.
    device = _device_of(model)
    # ------------------------------------------------------------------
    # Channel 1 (GRPO stub): LM cross-entropy on response tokens
    # ------------------------------------------------------------------
    lm_ce = _lm_response_ce(
        model,
        inputs["input_ids"],
        inputs["response_mask"],
        label_smoothing=lm_ce_label_smoothing,
    )
    # ------------------------------------------------------------------
    # Channel 2 (SDPO): generalized JSD on hint-conditioned forward
    # Optionally wrapped by TAID or replaced by Entropy-Aware OPD.
    # ------------------------------------------------------------------
    sdpo_jsd = _zero(device)
    if (
        alpha_sdpo > 0.0
        and "ctx_teacher_input_ids" in inputs
        and inputs["ctx_teacher_input_ids"].numel() > 0
    ):
        # Separate forward from channel 1; gradients flow through these logits.
        student_logits = model(input_ids=inputs["input_ids"]).logits
        # Teacher = same model on the hint-conditioned context, with no grad.
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
        if student_logits.shape == teacher_logits.shape:
            # NOTE(review): sdpo_loss_mask is passed as ``labels`` to the JSD
            # and OPD kernels but as a float ``mask`` to TAID. If
            # generalized_jsd_loss follows TRL's GKD convention, only
            # labels == -100 are ignored, so a 0/1 mask would mask nothing.
            # Confirm against composer_replication.opsd.
            if sdpo_wrapper == "none":
                sdpo_jsd = generalized_jsd_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    labels=inputs.get("sdpo_loss_mask"),
                    beta=sdpo_jsd_beta,
                    temperature=sdpo_temperature,
                    token_clip=sdpo_token_clip,
                    reduction="batchmean",
                )
            elif sdpo_wrapper == "taid":
                from composer_replication.distillation import taid_loss
                # taid_t validated non-None and in-range above.
                # (The assert only narrows the Optional type for checkers.)
                assert taid_t is not None
                # Reuse the SDPO loss-mask if provided so we only score the
                # error-turn tokens; otherwise score all tokens.
                taid_mask_bt = inputs.get("sdpo_loss_mask")
                if taid_mask_bt is not None:
                    taid_mask_bt = taid_mask_bt.to(student_logits.device).float()
                sdpo_jsd = taid_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    mask=taid_mask_bt,
                    t=float(taid_t),
                )
            elif sdpo_wrapper == "entropy_opd":
                from composer_replication.distillation import (
                    entropy_aware_opd_loss,
                )
                sdpo_jsd = entropy_aware_opd_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    labels=inputs.get("sdpo_loss_mask"),
                    h_max=entropy_opd_h_max,
                    temperature=sdpo_temperature,
                    reduction="batchmean",
                )
        # else: silently zero — the data collator is responsible for shape
        # alignment in production. For the smoke we accept misalignment and
        # exercise the fallback path.
    # ------------------------------------------------------------------
    # Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement
    # pairs. With dpo_variant="simpo", swap to SimPO (reference-free).
    # ------------------------------------------------------------------
    trace_replay_dpo = _zero(device)
    if (
        beta_replay > 0.0
        and "dpo_chosen_input_ids" in inputs
        and inputs["dpo_chosen_input_ids"].numel() > 0
    ):
        if dpo_variant == "dpo":
            # Summed response-token logprobs under the current policy, shape (B,).
            chosen_lp = _sequence_logprobs(
                model,
                inputs["dpo_chosen_input_ids"],
                inputs["dpo_chosen_response_mask"],
            )
            rejected_lp = _sequence_logprobs(
                model,
                inputs["dpo_rejected_input_ids"],
                inputs["dpo_rejected_response_mask"],
            )
            # Reference logprobs are used as-is; no .to(device) is applied here.
            ref_chosen = inputs["dpo_chosen_ref_logprobs"]
            ref_rejected = inputs["dpo_rejected_ref_logprobs"]
            # DPO margin: beta * (chosen log-ratio - rejected log-ratio).
            dpo_logits = replay_dpo_beta * (
                (chosen_lp - ref_chosen) - (rejected_lp - ref_rejected)
            )
            # -log sigmoid(margin), averaged over the pairs in the batch.
            trace_replay_dpo = -F.logsigmoid(dpo_logits).mean()
        else:  # dpo_variant == "simpo"
            from composer_replication.distillation import simpo_loss
            # SimPO uses length-normalised (mean per response token) logprobs.
            chosen_avg_lp = _avg_sequence_logprobs(
                model,
                inputs["dpo_chosen_input_ids"],
                inputs["dpo_chosen_response_mask"],
            )
            rejected_avg_lp = _avg_sequence_logprobs(
                model,
                inputs["dpo_rejected_input_ids"],
                inputs["dpo_rejected_response_mask"],
            )
            trace_replay_dpo = simpo_loss(
                chosen_avg_logprobs=chosen_avg_lp,
                rejected_avg_logprobs=rejected_avg_lp,
                beta=simpo_beta,
                gamma=simpo_gamma,
            )
    total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
    return LossComponents(
        lm_ce=lm_ce,
        sdpo_jsd=sdpo_jsd,
        trace_replay_dpo=trace_replay_dpo,
        total=total,
    )
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def _zero(device: torch.device) -> torch.Tensor:
"""Differentiable zero — safe to add into a sum without breaking backward."""
return torch.zeros(1, device=device, requires_grad=True).squeeze()
def _device_of(model: torch.nn.Module) -> torch.device:
return next(model.parameters()).device
def _lm_response_ce(
model: torch.nn.Module,
input_ids: torch.Tensor,
response_mask: torch.Tensor,
*,
label_smoothing: float = 0.0,
) -> torch.Tensor:
"""Standard next-token-prediction cross-entropy on response tokens only.
Mirrors what GRPO converges to under deterministic rewards (the policy
gradient devolves to behavior cloning of high-reward rollouts).
"""
outputs = model(input_ids=input_ids)
# Shift: logits[t] predicts input_ids[t+1]
logits = outputs.logits[:, :-1, :]
targets = input_ids[:, 1:]
mask = response_mask[:, 1:].float()
loss_per_token = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
targets.reshape(-1),
reduction="none",
label_smoothing=label_smoothing,
).view_as(targets)
masked = loss_per_token * mask
n_tokens = mask.sum().clamp_min(1.0)
return masked.sum() / n_tokens
def _sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
response_mask: torch.Tensor,
) -> torch.Tensor:
"""Sum of next-token logprobs over response tokens (standard DPO accounting)."""
outputs = model(input_ids=input_ids)
logits = outputs.logits[:, :-1, :]
targets = input_ids[:, 1:]
log_probs = F.log_softmax(logits, dim=-1)
token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
masked = token_lp * response_mask[:, 1:].float()
return masked.sum(dim=-1)
def _avg_sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
response_mask: torch.Tensor,
) -> torch.Tensor:
"""Per-sequence AVERAGE next-token logprob over response tokens.
SimPO accounting: divide the sum by the number of response tokens so
long sequences aren't penalized for length.
"""
outputs = model(input_ids=input_ids)
logits = outputs.logits[:, :-1, :]
targets = input_ids[:, 1:]
log_probs = F.log_softmax(logits, dim=-1)
token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
mask = response_mask[:, 1:].float()
masked = token_lp * mask
n_tokens = mask.sum(dim=-1).clamp_min(1.0)
return masked.sum(dim=-1) / n_tokens
# Public surface of this module: the loss entry point and its result type.
__all__ = [
    "compose_loss",
    "LossComponents",
]