Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """compose_loss.py — free 3-channel loss composer for verification smokes. | |
| This is a verification-harness mirror of `ComposerReplicationTrainer._compute_loss` | |
| that does NOT depend on TRL's GRPOTrainer parent. The GRPO channel is replaced | |
| with standard LM next-token-prediction cross-entropy, which is the limit GRPO | |
| converges to under deterministic rewards. | |
| Use it for: | |
| - CPU smokes on real HF models (Spike 006) | |
| - Unit tests of loss composition without spinning up TRL | |
| - Anywhere we want to verify gradient flow through the 3-channel sum | |
| without paying TRL's full machinery cost | |
| Do NOT use it as the production training loss. Production = ComposerReplicationTrainer | |
| (a real GRPOTrainer subclass) which uses TRL's reward + advantage estimation. | |
| Total loss: | |
| total = lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo | |
| Channels: | |
| - lm_ce: standard cross-entropy on assistant-response tokens (GRPO stub) | |
| - sdpo_jsd: generalized JSD between student and hint-conditioned-teacher logits | |
| - trace_replay_dpo: DPO loss over (chosen, rejected) teacher-disagreement pairs | |
| ADR-007 extensions | |
| ------------------ | |
| Three pluggable distillation losses can swap the default DPO/SDPO channels: | |
| - ``dpo_variant="simpo"`` — channel 3 uses SimPO (reference-free DPO with | |
| margin) instead of standard DPO. Reference logprobs are no longer required. | |
| - ``sdpo_wrapper="taid"`` — channel 2 replaces SDPO with TAID (Temporally | |
| Adaptive Interpolated Distillation, SakanaAI port). Requires ``taid_t`` | |
| (the current interpolation coefficient in ``[0, 1]``). The schedule that | |
| produces ``taid_t`` is the trainer's responsibility — typically a | |
| :class:`composer_replication.distillation.taid.TAIDScheduler` instance | |
| driven by the per-step distillation loss. | |
| - ``sdpo_wrapper="entropy_opd"`` — channel 2 uses Entropy-Aware OPD, a | |
| per-token gated forward/reverse KL. | |
| All three default to off; passing the new kwargs at their defaults is | |
| bit-exact equivalent to the legacy 3-channel composition. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Literal | |
| import torch | |
| import torch.nn.functional as F | |
| from composer_replication.opsd import generalized_jsd_loss | |
@dataclass
class LossComponents:
    """Per-channel breakdown of the total loss for logging + ablation.

    ``compose_loss`` fills the three channel fields with their *unweighted*
    values and ``total`` with the weighted sum that is back-propagated.
    """

    lm_ce: torch.Tensor  # channel 1: response-token cross-entropy (GRPO stub)
    sdpo_jsd: torch.Tensor  # channel 2: SDPO JSD, or its TAID / Entropy-Aware OPD replacement
    trace_replay_dpo: torch.Tensor  # channel 3: DPO or SimPO preference loss
    total: torch.Tensor  # lm_ce + alpha * sdpo_jsd + beta * trace_replay_dpo

    def detached(self) -> dict[str, float]:
        """Return every field as a plain Python float, cut from the autograd graph.

        Intended for logging. Each tensor must hold exactly one element, since
        ``float()`` rejects multi-element tensors.
        """
        return {
            "lm_ce": float(self.lm_ce.detach()),
            "sdpo_jsd": float(self.sdpo_jsd.detach()),
            "trace_replay_dpo": float(self.trace_replay_dpo.detach()),
            "total": float(self.total.detach()),
        }
def compose_loss(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    alpha_sdpo: float = 0.1,
    beta_replay: float = 0.05,
    sdpo_jsd_beta: float = 0.5,
    sdpo_temperature: float = 1.0,
    sdpo_token_clip: float | None = None,
    replay_dpo_beta: float = 0.1,
    lm_ce_label_smoothing: float = 0.0,
    # ADR-007 extensions ------------------------------------------------
    dpo_variant: Literal["dpo", "simpo"] = "dpo",
    sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
    taid_t: float | None = None,
    # SimPO knobs (only used when dpo_variant="simpo") ------------------
    simpo_beta: float = 2.0,
    simpo_gamma: float = 1.0,
    # Entropy-Aware OPD knobs (only used when sdpo_wrapper="entropy_opd")
    entropy_opd_h_max: float | None = None,
) -> LossComponents:
    """Compute total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo.

    Required keys in `inputs`:
      - input_ids:      (B, T_s) student rollout
      - response_mask:  (B, T_s) 1 on assistant-response tokens, 0 elsewhere

    Optional keys. Channels 2 and 3 are zero when their gating key is
    missing, ``None`` or empty, OR when their weight is <= 0:
      SDPO (gated by ctx_teacher_input_ids, weight alpha_sdpo):
        - ctx_teacher_input_ids: (B, T_t) hint-conditioned context
        - sdpo_loss_mask:        (B, T_t) 1 at error-turn tokens
        The channel is also silently zero when student and teacher logits
        have different shapes; the data collator is expected to align them.
      DPO (dpo_variant="dpo"; gated by dpo_chosen_input_ids, weight beta_replay):
        - dpo_chosen_input_ids, dpo_chosen_response_mask
        - dpo_rejected_input_ids, dpo_rejected_response_mask
        - dpo_chosen_ref_logprobs, dpo_rejected_ref_logprobs (precomputed;
          moved to the policy device and detached before use)
      SimPO (dpo_variant="simpo"):
        - dpo_chosen_input_ids, dpo_chosen_response_mask
        - dpo_rejected_input_ids, dpo_rejected_response_mask
        (reference logprobs not required and silently ignored)
      TAID (sdpo_wrapper="taid"):
        - taid_t kwarg: scalar float in [0, 1] giving the current
          interpolation coefficient. The trainer is responsible for the
          schedule (use TAIDScheduler from
          composer_replication.distillation.taid for the paper-default
          adaptive scheme, or any custom schedule of your choosing).
        - sdpo_loss_mask, if present, restricts scoring to those tokens.
      Entropy-Aware OPD (sdpo_wrapper="entropy_opd"):
        - sdpo_loss_mask, if present, is passed as the token labels/mask;
          knobs are entropy_opd_h_max and sdpo_temperature.

    Returns:
        LossComponents holding each channel's unweighted value plus the
        weighted total.

    Raises:
        ValueError: unknown dpo_variant / sdpo_wrapper, or
            sdpo_wrapper="taid" with taid_t missing or outside [0, 1].
        KeyError: a required key, or a companion key of an enabled channel,
            is missing from `inputs`.
    """
    _validate_variant_options(dpo_variant, sdpo_wrapper, taid_t)
    device = _device_of(model)

    # Channel 1 (GRPO stub): LM cross-entropy on response tokens. Always on.
    lm_ce = _lm_response_ce(
        model,
        inputs["input_ids"],
        inputs["response_mask"],
        label_smoothing=lm_ce_label_smoothing,
    )

    # Channel 2 (SDPO): generalized JSD on the hint-conditioned forward,
    # optionally wrapped by TAID or replaced by Entropy-Aware OPD.
    sdpo_jsd = _zero(device)
    if alpha_sdpo > 0.0 and _has_nonempty(inputs, "ctx_teacher_input_ids"):
        distill = _sdpo_channel(
            model,
            inputs,
            sdpo_wrapper=sdpo_wrapper,
            sdpo_jsd_beta=sdpo_jsd_beta,
            sdpo_temperature=sdpo_temperature,
            sdpo_token_clip=sdpo_token_clip,
            taid_t=taid_t,
            entropy_opd_h_max=entropy_opd_h_max,
        )
        # None => student/teacher shape mismatch; keep the differentiable zero.
        if distill is not None:
            sdpo_jsd = distill

    # Channel 3 (trace-replay): DPO on teacher-disagreement pairs, or SimPO
    # (reference-free) when dpo_variant="simpo".
    trace_replay_dpo = _zero(device)
    if beta_replay > 0.0 and _has_nonempty(inputs, "dpo_chosen_input_ids"):
        trace_replay_dpo = _replay_channel(
            model,
            inputs,
            dpo_variant=dpo_variant,
            replay_dpo_beta=replay_dpo_beta,
            simpo_beta=simpo_beta,
            simpo_gamma=simpo_gamma,
        )

    total = lm_ce + alpha_sdpo * sdpo_jsd + beta_replay * trace_replay_dpo
    return LossComponents(
        lm_ce=lm_ce,
        sdpo_jsd=sdpo_jsd,
        trace_replay_dpo=trace_replay_dpo,
        total=total,
    )


def _validate_variant_options(
    dpo_variant: str,
    sdpo_wrapper: str,
    taid_t: float | None,
) -> None:
    """Reject unknown channel variants and a missing/out-of-range TAID coefficient."""
    if dpo_variant not in ("dpo", "simpo"):
        raise ValueError(
            f"dpo_variant must be 'dpo' or 'simpo', got {dpo_variant!r}"
        )
    if sdpo_wrapper not in ("none", "taid", "entropy_opd"):
        raise ValueError(
            f"sdpo_wrapper must be 'none', 'taid', or 'entropy_opd', "
            f"got {sdpo_wrapper!r}"
        )
    if sdpo_wrapper == "taid":
        if taid_t is None:
            raise ValueError(
                "sdpo_wrapper='taid' requires taid_t (float in [0, 1]). "
                "Drive it from a TAIDScheduler or pass a fixed value."
            )
        # NaN fails both comparisons, so it is rejected here too.
        if not (0.0 <= float(taid_t) <= 1.0):
            raise ValueError(
                f"taid_t must be in [0, 1], got {taid_t}"
            )


def _has_nonempty(inputs: dict[str, torch.Tensor], key: str) -> bool:
    """True when `inputs[key]` exists, is not None, and has at least one element."""
    tensor = inputs.get(key)
    return tensor is not None and tensor.numel() > 0


def _sdpo_channel(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    sdpo_wrapper: str,
    sdpo_jsd_beta: float,
    sdpo_temperature: float,
    sdpo_token_clip: float | None,
    taid_t: float | None,
    entropy_opd_h_max: float | None,
) -> torch.Tensor | None:
    """Channel 2: student-vs-hint-conditioned-teacher distillation loss.

    Returns None when student and teacher logits differ in shape, so the
    caller can keep its zero fallback.
    """
    # NOTE(review): the student forward on input_ids also runs inside
    # _lm_response_ce; sharing those logits would save one forward pass.
    student_logits = model(input_ids=inputs["input_ids"]).logits
    # The teacher is the same model on the hint-conditioned context, frozen.
    with torch.no_grad():
        teacher_logits = model(input_ids=inputs["ctx_teacher_input_ids"]).logits
    if student_logits.shape != teacher_logits.shape:
        return None

    loss_mask = inputs.get("sdpo_loss_mask")

    if sdpo_wrapper == "none":
        return generalized_jsd_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            labels=loss_mask,
            beta=sdpo_jsd_beta,
            temperature=sdpo_temperature,
            token_clip=sdpo_token_clip,
            reduction="batchmean",
        )

    if sdpo_wrapper == "taid":
        from composer_replication.distillation import taid_loss

        # taid_t was validated non-None and in-range by the caller.
        assert taid_t is not None
        # Reuse the SDPO loss-mask (if provided) so only error-turn tokens
        # are scored; otherwise every token is scored.
        if loss_mask is not None:
            loss_mask = loss_mask.to(student_logits.device).float()
        return taid_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            mask=loss_mask,
            t=float(taid_t),
        )

    # sdpo_wrapper == "entropy_opd" (the only remaining validated value).
    from composer_replication.distillation import entropy_aware_opd_loss

    return entropy_aware_opd_loss(
        student_logits=student_logits,
        teacher_logits=teacher_logits,
        labels=loss_mask,
        h_max=entropy_opd_h_max,
        temperature=sdpo_temperature,
        reduction="batchmean",
    )


def _replay_channel(
    model: torch.nn.Module,
    inputs: dict[str, torch.Tensor],
    *,
    dpo_variant: str,
    replay_dpo_beta: float,
    simpo_beta: float,
    simpo_gamma: float,
) -> torch.Tensor:
    """Channel 3: preference loss over (chosen, rejected) teacher-disagreement pairs."""
    chosen_ids = inputs["dpo_chosen_input_ids"]
    chosen_mask = inputs["dpo_chosen_response_mask"]
    rejected_ids = inputs["dpo_rejected_input_ids"]
    rejected_mask = inputs["dpo_rejected_response_mask"]

    if dpo_variant == "simpo":
        from composer_replication.distillation import simpo_loss

        # SimPO is reference-free and length-normalized: it scores average,
        # not summed, response logprobs.
        chosen_avg_lp = _avg_sequence_logprobs(model, chosen_ids, chosen_mask)
        rejected_avg_lp = _avg_sequence_logprobs(model, rejected_ids, rejected_mask)
        return simpo_loss(
            chosen_avg_logprobs=chosen_avg_lp,
            rejected_avg_logprobs=rejected_avg_lp,
            beta=simpo_beta,
            gamma=simpo_gamma,
        )

    chosen_lp = _sequence_logprobs(model, chosen_ids, chosen_mask)
    rejected_lp = _sequence_logprobs(model, rejected_ids, rejected_mask)
    # Reference logprobs are precomputed constants. Move them next to the
    # policy logprobs (they may have been collated on CPU) and detach them so
    # no gradient can reach the reference.
    ref_chosen = torch.as_tensor(
        inputs["dpo_chosen_ref_logprobs"], device=chosen_lp.device
    ).detach()
    ref_rejected = torch.as_tensor(
        inputs["dpo_rejected_ref_logprobs"], device=rejected_lp.device
    ).detach()
    dpo_logits = replay_dpo_beta * (
        (chosen_lp - ref_chosen) - (rejected_lp - ref_rejected)
    )
    return -F.logsigmoid(dpo_logits).mean()
| # ---------------------------------------------------------------------- | |
| # Helpers | |
| # ---------------------------------------------------------------------- | |
| def _zero(device: torch.device) -> torch.Tensor: | |
| """Differentiable zero — safe to add into a sum without breaking backward.""" | |
| return torch.zeros(1, device=device, requires_grad=True).squeeze() | |
| def _device_of(model: torch.nn.Module) -> torch.device: | |
| return next(model.parameters()).device | |
| def _lm_response_ce( | |
| model: torch.nn.Module, | |
| input_ids: torch.Tensor, | |
| response_mask: torch.Tensor, | |
| *, | |
| label_smoothing: float = 0.0, | |
| ) -> torch.Tensor: | |
| """Standard next-token-prediction cross-entropy on response tokens only. | |
| Mirrors what GRPO converges to under deterministic rewards (the policy | |
| gradient devolves to behavior cloning of high-reward rollouts). | |
| """ | |
| outputs = model(input_ids=input_ids) | |
| # Shift: logits[t] predicts input_ids[t+1] | |
| logits = outputs.logits[:, :-1, :] | |
| targets = input_ids[:, 1:] | |
| mask = response_mask[:, 1:].float() | |
| loss_per_token = F.cross_entropy( | |
| logits.reshape(-1, logits.size(-1)), | |
| targets.reshape(-1), | |
| reduction="none", | |
| label_smoothing=label_smoothing, | |
| ).view_as(targets) | |
| masked = loss_per_token * mask | |
| n_tokens = mask.sum().clamp_min(1.0) | |
| return masked.sum() / n_tokens | |
| def _sequence_logprobs( | |
| model: torch.nn.Module, | |
| input_ids: torch.Tensor, | |
| response_mask: torch.Tensor, | |
| ) -> torch.Tensor: | |
| """Sum of next-token logprobs over response tokens (standard DPO accounting).""" | |
| outputs = model(input_ids=input_ids) | |
| logits = outputs.logits[:, :-1, :] | |
| targets = input_ids[:, 1:] | |
| log_probs = F.log_softmax(logits, dim=-1) | |
| token_lp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1) | |
| masked = token_lp * response_mask[:, 1:].float() | |
| return masked.sum(dim=-1) | |
def _avg_sequence_logprobs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Per-sequence AVERAGE next-token logprob over response tokens.

    SimPO accounting: the summed response logprob from ``_sequence_logprobs``
    divided by the number of scored response tokens, so long sequences
    aren't penalized for length. Rows with no response tokens divide by 1
    and therefore return 0.
    """
    # Reuse the DPO summation rather than duplicating the forward/gather logic.
    summed = _sequence_logprobs(model, input_ids, response_mask)
    # Same shifted mask _sequence_logprobs applies: position t scores token t+1.
    n_tokens = response_mask[:, 1:].float().sum(dim=-1).clamp_min(1.0)
    return summed / n_tokens
# Public API of this module: the loss entry point and its per-channel result type.
__all__ = ["compose_loss", "LossComponents"]