File size: 13,432 Bytes
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
e5add15
 
 
 
 
 
d9dd3a5
 
 
 
 
ac05fbf
 
 
 
d9dd3a5
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
e5add15
d9dd3a5
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
d9dd3a5
ac05fbf
 
 
d9dd3a5
 
 
 
 
e5add15
 
 
 
 
ac05fbf
d9dd3a5
 
 
 
 
 
 
 
 
 
e5add15
d9dd3a5
e5add15
 
d9dd3a5
e5add15
d9dd3a5
e5add15
d9dd3a5
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
e5add15
 
 
 
 
 
 
d9dd3a5
 
 
e5add15
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
d9dd3a5
ac05fbf
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9dd3a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac05fbf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""compose_loss.py — free 3-channel loss composer for verification smokes.

This is a verification-harness mirror of `ComposerReplicationTrainer._compute_loss`
that does NOT depend on TRL's GRPOTrainer parent. The GRPO channel is replaced
with standard LM next-token-prediction cross-entropy, which is the limit GRPO
converges to under deterministic rewards.

Use it for:
- CPU smokes on real HF models (Spike 006)
- Unit tests of loss composition without spinning up TRL
- Anywhere we want to verify gradient flow through the 3-channel sum
  without paying TRL's full machinery cost

Do NOT use it as the production training loss. Production = ComposerReplicationTrainer
(a real GRPOTrainer subclass) which uses TRL's reward + advantage estimation.

Total loss:
    total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo

Channels:
- lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub)
- sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits
- trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs

ADR-007 extensions
------------------
Three pluggable distillation losses can replace the default DPO/SDPO channels:

- ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with
  margin) instead of standard DPO. Reference logprobs are no longer required.
- ``sdpo_wrapper="taid"`` — channel 2 replaces SDPO with TAID (Temporally
  Adaptive Interpolated Distillation, SakanaAI port). Requires ``taid_t``
  (the current interpolation coefficient in ``[0, 1]``). The schedule that
  produces ``taid_t`` is the trainer's responsibility — typically a
  :class:`composer_replication.distillation.taid.TAIDScheduler` instance
  driven by the per-step distillation loss.
- ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a
  per-token gated forward/reverse KL.

All three default to off; passing the new kwargs at their defaults is
bit-exact equivalent to the legacy 3-channel composition.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import torch
import torch.nn.functional as F

from composer_replication.opsd import generalized_jsd_loss


@dataclass
class LossComponents:
    """Bundle of the individual channel losses and their combined total.

    Each field holds a tensor; ``detached()`` converts all of them to plain
    Python floats, which is convenient for logging and ablation tables.
    """

    # Per-channel loss values.
    lm_ce: torch.Tensor
    sdpo_jsd: torch.Tensor
    trace_replay_dpo: torch.Tensor
    # Combined loss.
    total: torch.Tensor

    def detached(self) -> dict[str, float]:
        """Return every field as a Python float, keyed by its field name.

        Each tensor is detached before conversion, so the result carries no
        autograd state.
        """
        field_names = ("lm_ce", "sdpo_jsd", "trace_replay_dpo", "total")
        return {
            name: float(getattr(self, name).detach())
            for name in field_names
        }


def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    # ADR-007 extensions ------------------------------------------------
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    # SimPO knobs (only used when dpo_variant="simpo") ------------------
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    # Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
    entropy_opd_h_max: float | None = None,
) -> LossComponents:
    """Compute total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo.

    Required keys in `inputs`:
        - input_ids: (B, T_s) student rollout
        - response_mask: (B, T_s) 1 on assistant-response tokens, 0 elsewhere

    Optional keys (channel auto-disables if missing OR if its weight = 0):
        SDPO:
        - ctx_teacher_input_ids: (B, T_t) hint-conditioned context
        - sdpo_loss_mask: (B, T_t) 1 at error-turn tokens
        DPO (dpo_variant="dpo"):
        - dpo_chosen_input_ids, dpo_chosen_response_mask
        - dpo_rejected_input_ids, dpo_rejected_response_mask
        - dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed)
        SimPO (dpo_variant="simpo"):
        - dpo_chosen_input_ids, dpo_chosen_response_mask
        - dpo_rejected_input_ids, dpo_rejected_response_mask
        (reference logprobs not required and silently ignored)
        TAID (sdpo_wrapper="taid"):
        - taid_t kwarg: scalar float in [0, 1] giving the current
          interpolation coefficient. The trainer is responsible for the
          schedule (use TAIDScheduler from
          composer_replication.distillation.taid for the paper-default
          adaptive scheme, or any custom schedule of your choosing).
        Entropy-Aware OPD (sdpo_wrapper="entropy_opd"):
        - entropy_opd_h_max kwarg: forwarded unchanged as ``h_max`` to
          entropy_aware_opd_loss, which defines its meaning.

    Channel gating, as implemented:
        - Channel 2 runs only when alpha_sdpo > 0 and
          `ctx_teacher_input_ids` is present and non-empty. If the student
          and teacher logits then differ in shape, the channel silently
          stays at zero.
        - Channel 3 runs only when beta_replay > 0 and
          `dpo_chosen_input_ids` is present and non-empty. Only that one
          key is checked; the other dpo_* keys are indexed directly.

    Which knob reaches which loss:
        - lm_ce_label_smoothing: channel 1 cross-entropy.
        - sdpo_jsd_beta, sdpo_token_clip: sdpo_wrapper="none" only.
        - sdpo_temperature: sdpo_wrapper="none" and "entropy_opd"
          (it is not passed to the TAID loss).
        - replay_dpo_beta: dpo_variant="dpo" only.
        - simpo_beta, simpo_gamma: dpo_variant="simpo" only.

    Returns:
        LossComponents holding the three UNWEIGHTED channel losses plus
        the weighted `total`. A disabled channel is reported as a scalar
        zero.

    Raises:
        ValueError: unknown dpo_variant or sdpo_wrapper, or
            sdpo_wrapper="taid" with taid_t missing or outside [0, 1].
            These checks run before any forward pass and regardless of
            the channel weights.
        KeyError: `input_ids` / `response_mask` missing, or channel 3 is
            active while one of its other required keys is missing
            (assuming `inputs` is a plain dict).
    """
    # --- Argument validation (happens before any model call) -----------
    if dpo_variant not in ("dpo", "simpo"):
        raise ValueError(
            f"dpo_variant must be 'dpo' or 'simpo', got {dpo_variant!r}"
        )
    if sdpo_wrapper not in ("none", "taid", "entropy_opd"):
        raise ValueError(
            f"sdpo_wrapper must be 'none', 'taid', or 'entropy_opd', "
            f"got {sdpo_wrapper!r}"
        )
    if sdpo_wrapper == "taid":
        if taid_t is None:
            raise ValueError(
                "sdpo_wrapper='taid' requires taid_t (float in [0, 1]). "
                "Drive it from a TAIDScheduler or pass a fixed value."
            )
        # Written as `not (lo <= x <= hi)` so that NaN is rejected as well.
        if not (0.0 <= float(taid_t) <= 1.0):
            raise ValueError(
                f"taid_t must be in [0, 1], got {taid_t}"
            )

    # Device used for the zero placeholders of disabled channels.
    device = _device_of(model)

    # ------------------------------------------------------------------
    # Channel 1 (GRPO stub): LM cross-entropy on response tokens
    # ------------------------------------------------------------------
    lm_ce = _lm_response_ce(
        model,
        inputs["input_ids"],
        inputs["response_mask"],
        label_smoothing=lm_ce_label_smoothing,
    )

    # ------------------------------------------------------------------
    # Channel 2 (SDPO): generalized JSD on hint-conditioned forward
    # Optionally wrapped by TAID or replaced by Entropy-Aware OPD.
    # ------------------------------------------------------------------
    sdpo_jsd = _zero(device)
    if (
        alpha_sdpo > 0.0
        and "ctx_teacher_input_ids" in inputs
        and inputs["ctx_teacher_input_ids"].numel() > 0
    ):
        # NOTE(review): this is a second student forward over the same
        # input_ids; _lm_response_ce already ran one for channel 1.
        student_logits = model(input_ids=inputs["input_ids"]).logits
        # The "teacher" is the same model run on the hint-conditioned
        # context; no_grad keeps gradients flowing through the student only.
        with torch.no_grad():
            teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits

        if student_logits.shape == teacher_logits.shape:
            if sdpo_wrapper == "none":
                # NOTE(review): the 0/1 sdpo_loss_mask is handed over via
                # the `labels` argument — confirm generalized_jsd_loss
                # interprets it as a mask and not as -100-style labels.
                sdpo_jsd = generalized_jsd_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    labels=inputs.get("sdpo_loss_mask"),
                    beta=sdpo_jsd_beta,
                    temperature=sdpo_temperature,
                    token_clip=sdpo_token_clip,
                    reduction="batchmean",
                )
            elif sdpo_wrapper == "taid":
                # Imported lazily: only needed when this branch runs.
                from composer_replication.distillation import taid_loss

                # taid_t validated non-None and in-range above.
                assert taid_t is not None
                # Reuse the SDPO loss-mask if provided so we only score the
                # error-turn tokens; otherwise score all tokens.
                taid_mask_bt = inputs.get("sdpo_loss_mask")
                if taid_mask_bt is not None:
                    taid_mask_bt = taid_mask_bt.to(student_logits.device).float()
                sdpo_jsd = taid_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    mask=taid_mask_bt,
                    t=float(taid_t),
                )
            elif sdpo_wrapper == "entropy_opd":
                # Imported lazily: only needed when this branch runs.
                from composer_replication.distillation import (
                    entropy_aware_opd_loss,
                )

                # Unlike the TAID branch, the mask is passed through as-is
                # (no device move or float cast) under the `labels` name.
                sdpo_jsd = entropy_aware_opd_loss(
                    student_logits=student_logits,
                    teacher_logits=teacher_logits,
                    labels=inputs.get("sdpo_loss_mask"),
                    h_max=entropy_opd_h_max,
                    temperature=sdpo_temperature,
                    reduction="batchmean",
                )
        # else: silently zero — the data collator is responsible for shape
        # alignment in production. For the smoke we accept misalignment and
        # exercise the fallback path.

    # ------------------------------------------------------------------
    # Channel 3 (trace-replay DPO): standard DPO loss on teacher-disagreement
    # pairs. With dpo_variant="simpo", swap to SimPO (reference-free).
    # ------------------------------------------------------------------
    trace_replay_dpo = _zero(device)
    if (
        beta_replay > 0.0
        and "dpo_chosen_input_ids" in inputs
        and inputs["dpo_chosen_input_ids"].numel() > 0
    ):
        if dpo_variant == "dpo":
            # Policy log-probs: summed over response tokens, one value per
            # sequence.
            chosen_lp = _sequence_logprobs(
                model,
                inputs["dpo_chosen_input_ids"],
                inputs["dpo_chosen_response_mask"],
            )
            rejected_lp = _sequence_logprobs(
                model,
                inputs["dpo_rejected_input_ids"],
                inputs["dpo_rejected_response_mask"],
            )
            # Reference log-probs are used exactly as supplied (no device
            # or dtype conversion here).
            ref_chosen = inputs["dpo_chosen_ref_logprobs"]
            ref_rejected = inputs["dpo_rejected_ref_logprobs"]
            # beta * (chosen log-ratio minus rejected log-ratio), then
            # -log(sigmoid(.)) averaged over the pairs.
            dpo_logits = replay_dpo_beta * (
                (chosen_lp - ref_chosen) - (rejected_lp - ref_rejected)
            )
            trace_replay_dpo = -F.logsigmoid(dpo_logits).mean()
        else:  # dpo_variant == "simpo"
            # Imported lazily: only needed when this branch runs.
            from composer_replication.distillation import simpo_loss

            # SimPO consumes length-normalised (per-token average)
            # log-probs and needs no reference model.
            chosen_avg_lp = _avg_sequence_logprobs(
                model,
                inputs["dpo_chosen_input_ids"],
                inputs["dpo_chosen_response_mask"],
            )
            rejected_avg_lp = _avg_sequence_logprobs(
                model,
                inputs["dpo_rejected_input_ids"],
                inputs["dpo_rejected_response_mask"],
            )
            trace_replay_dpo = simpo_loss(
                chosen_avg_logprobs=chosen_avg_lp,
                rejected_avg_logprobs=rejected_avg_lp,
                beta=simpo_beta,
                gamma=simpo_gamma,
            )

    # Weights are applied only here; the per-channel values returned below
    # stay unweighted.
    total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo

    return LossComponents(
        lm_ce=lm_ce,
        sdpo_jsd=sdpo_jsd,
        trace_replay_dpo=trace_replay_dpo,
        total=total,
    )


# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------

def _zero(device: torch.device) -> torch.Tensor:
    """Return a 0-dim zero tensor on ``device`` that takes part in autograd.

    It stands in for a disabled channel: adding it to a loss sum leaves the
    value unchanged and does not break ``backward()``.
    """
    placeholder = torch.zeros(1, device=device, requires_grad=True)
    # Collapse the length-1 axis so the result is a true scalar.
    return placeholder.squeeze()


def _device_of(model: torch.nn.Module) -> torch.device:
    """Return the device that ``model``'s tensors live on.

    The first parameter decides. A module without parameters falls back to
    its first buffer, and a module with neither falls back to CPU, so the
    caller gets a usable device instead of a leaked ``StopIteration``.

    Args:
        model: Module to inspect.

    Returns:
        Device of the first parameter, else of the first buffer, else
        ``torch.device("cpu")``.
    """
    first_tensor = next(model.parameters(), None)
    if first_tensor is None:
        # No trainable weights (e.g. nn.Identity); buffers still carry a device.
        first_tensor = next(model.buffers(), None)
    if first_tensor is None:
        return torch.device("cpu")
    return first_tensor.device


def _lm_response_ce(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
    *,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Mean next-token cross-entropy, restricted to response tokens.

    This is the stand-in for the GRPO channel: plain behavior-cloning style
    cross-entropy on the tokens flagged by ``response_mask``.

    Args:
        model: Callable as ``model(input_ids=...)`` and returning an object
            with a ``.logits`` attribute.
        input_ids: Token ids; position ``t`` is the target for the logits
            at position ``t - 1``.
        response_mask: Same leading layout as ``input_ids``; non-zero marks
            the tokens whose loss is counted.
        label_smoothing: Passed straight to ``F.cross_entropy``.

    Returns:
        Scalar tensor: masked loss sum divided by the number of counted
        tokens (the divisor is floored at 1, so an all-zero mask yields 0).
    """
    # Logits at position t predict the token at t + 1, so drop the last
    # logit column and the first token/mask column to line them up.
    shifted_logits = model(input_ids=input_ids).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]
    keep = response_mask[:, 1:].float()

    vocab_size = shifted_logits.size(-1)
    flat_nll = F.cross_entropy(
        shifted_logits.reshape(-1, vocab_size),
        next_tokens.reshape(-1),
        reduction="none",
        label_smoothing=label_smoothing,
    )
    token_nll = flat_nll.view_as(next_tokens)

    counted = keep.sum().clamp_min(1.0)
    return (token_nll * keep).sum() / counted


def _sequence_logprobs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-sequence SUM of next-token log-probs over response tokens.

    This is the accounting standard DPO uses. The result has one entry per
    row of ``input_ids``; positions where ``response_mask`` is zero
    contribute nothing.
    """
    # Align predictions with targets: logits at t score the token at t + 1.
    shifted_logits = model(input_ids=input_ids).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]

    vocab_logprobs = shifted_logits.log_softmax(dim=-1)
    # Pick out the log-prob assigned to the token that actually followed.
    picked = torch.gather(
        vocab_logprobs, -1, next_tokens.unsqueeze(-1)
    ).squeeze(-1)

    keep = response_mask[:, 1:].float()
    return (picked * keep).sum(dim=-1)


def _avg_sequence_logprobs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-sequence MEAN next-token log-prob over response tokens.

    SimPO-style accounting: the masked sum is divided by the number of
    response tokens so that longer sequences are not penalized for their
    length. The divisor is floored at 1, so a sequence with an empty mask
    yields 0 rather than a division by zero.
    """
    # Align predictions with targets: logits at t score the token at t + 1.
    shifted_logits = model(input_ids=input_ids).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]

    vocab_logprobs = shifted_logits.log_softmax(dim=-1)
    # Pick out the log-prob assigned to the token that actually followed.
    picked = torch.gather(
        vocab_logprobs, -1, next_tokens.unsqueeze(-1)
    ).squeeze(-1)

    keep = response_mask[:, 1:].float()
    per_seq_sum = (picked * keep).sum(dim=-1)
    per_seq_count = keep.sum(dim=-1).clamp_min(1.0)
    return per_seq_sum / per_seq_count


# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["compose_loss", "LossComponents"]