Codeseys's picture
Wave 19: production-grade SDPO via ComposerDataCollator + adapter + collator fixes
03bf323
"""data_collator.py — ComposerDataCollator: raw trace → trainer-ready batch.
Pipeline:
1. Take a frozen agentic trace + N-teacher DPO pairs (from spike 002 + 003).
2. Tokenize each turn of the trace.
3. Detect error sites (turns where a tool call failed) using the `_is_error_turn` predicate (a turn whose `tool_error` is not None).
4. At each error site, build ctx_teacher = ctx_student with hint inserted at the error-turn boundary.
5. Pad/align ctx_student and ctx_teacher so SDPO logits compare position-by-position.
6. Construct sdpo_loss_mask = 1 at post-hint tokens of the error turn, `ignore_index` (-100 by default) elsewhere.
7. Tokenize DPO chosen/rejected pairs, build response masks, leave ref_logprobs as a precompute step.
The output dict is what `ComposerReplicationTrainer._compute_loss` expects in its
`inputs` argument. See `trl_path/composer_trainer.py` for the consumer side.
Architectural note (verified via spike 005 test_opsd_loss.py): generalized_jsd_loss
requires student_logits and teacher_logits to have the SAME (B, T, V) shape — that's
why we pad/align here rather than inside the loss function. The post-hint section of
ctx_teacher must have token-by-token alignment with the same section of ctx_student.
"""
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict, runtime_checkable

import torch
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
class TraceTurn(TypedDict, total=False):
    """One turn of an agentic trace.

    Declared with ``total=False``, so every key is optional. A turn counts as
    an error site when ``tool_error`` is present and not None; turns with an
    empty or missing ``content`` are skipped during tokenization.
    """

    role: str  # "user" | "assistant" | "tool"
    content: str  # text or tool result
    tool_call: dict | None  # parsed tool call, if assistant-issued
    tool_error: str | None  # error_kind from the env, e.g. "tool_not_found"
    error_meta: dict  # extra info for hint generator (available_tools, etc.)
class TraceExample(TypedDict, total=False):
    """One training example: a (trace, optional DPO pairs) tuple.

    Declared with ``total=False``, so every key is optional at the type
    level; the collator indexes ``turns`` directly and reads the reward via
    ``CollatorConfig.rlvr_reward_key`` (defaulting to 0.0 when absent).
    """

    trace_id: str
    turns: list[TraceTurn]
    final_reward: float  # RLVR scalar (test-pass etc.) at trajectory end
    # From teacher_replay.extract_dpo_pairs. Each pair is read with the keys
    # "state_messages" (optional), "chosen" and "rejected".
    dpo_pairs: list[dict] | None
# ---------------------------------------------------------------------------
# Tokenizer protocol — duck-typed against HF AutoTokenizer
# ---------------------------------------------------------------------------
@runtime_checkable
class TokenizerLike(Protocol):
    """Structural protocol describing what the collator needs from a tokenizer.

    Any object exposing these members satisfies the protocol without having
    to inherit from it: HuggingFace `AutoTokenizer` instances (the typical
    case) as well as simpler stubs used for unit-testing. Because the class
    is `runtime_checkable`, `isinstance(obj, TokenizerLike)` performs a
    member-presence check (signatures are not verified at runtime).
    """

    # Token id used for padding.
    pad_token_id: int

    def __call__(self, text: str | list[str], **kwargs: Any) -> dict[str, list]:  # pragma: no cover
        """Tokenize ``text``; the result must expose an ``"input_ids"`` entry."""
        ...

    def apply_chat_template(  # pragma: no cover
        self, messages: list[dict], **kwargs: Any
    ) -> str | list[int]:
        """Render (and optionally tokenize) a list of chat messages."""
        ...
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class CollatorConfig:
    """Settings that steer how ComposerDataCollator assembles a batch.

    Field order is part of the interface (positional construction), so new
    fields must be appended at the end.
    """

    # --- Sequence budget and padding -------------------------------------
    max_seq_len: int = 4096  # cap for trace / teacher-context tensors
    max_dpo_seq_len: int = 2048  # cap for DPO chosen / rejected tensors
    pad_token_id: int = 0  # id written into padded positions
    ignore_index: int = -100  # standard HF "ignore in loss" sentinel

    # --- SDPO hint-distillation ------------------------------------------
    enable_sdpo: bool = True
    # Callable (error_kind, error_meta) -> hint_text, or None to skip the
    # site. With no generator configured, no SDPO fields are produced.
    hint_generator: Callable[[str, dict], str | None] | None = None

    # --- Trace-replay DPO ------------------------------------------------
    enable_replay_dpo: bool = True

    # --- Reward shaping --------------------------------------------------
    rlvr_reward_key: str = "final_reward"  # example key holding the scalar reward
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _is_error_turn(turn: TraceTurn) -> bool:
"""Predicate: is this turn an error site that should trigger SDPO?"""
return turn.get("tool_error") is not None
def _build_chat_messages(turns: Sequence[TraceTurn]) -> list[dict]:
"""Convert TraceTurns to OpenAI-style chat messages for tokenizer.apply_chat_template."""
return [
{"role": t["role"], "content": t["content"]}
for t in turns if t.get("content")
]
def _pad_or_truncate(seq: list[int], target_len: int, pad_id: int) -> list[int]:
"""Right-pad with pad_id, or right-truncate to target_len."""
if len(seq) >= target_len:
return seq[:target_len]
return seq + [pad_id] * (target_len - len(seq))
# ---------------------------------------------------------------------------
# The collator
# ---------------------------------------------------------------------------
@dataclass
class ComposerDataCollator:
    """Build trainer-ready batches from raw traces + optional DPO pairs.

    Usage:
        collator = ComposerDataCollator(tokenizer=tok, config=CollatorConfig())
        batch = collator([trace_example_0, trace_example_1, ...])
        # batch is a dict[str, torch.Tensor] ready for ComposerReplicationTrainer

    The dict contains:
        # Channel 1 (GRPO/RLVR — handled by the parent GRPOTrainer)
        - input_ids:      (B, T_max)
        - attention_mask: (B, T_max)
        - response_mask:  (B, T_max)
        - rewards:        (B,)
        # Channel 2 (SDPO hint-distill) — present only when `enable_sdpo` is
        # set, a `hint_generator` is configured, and at least one error turn
        # in the batch yields a hint. When present, the three Channel 1
        # sequence tensors above are replaced by hint-aligned student
        # versions whose width equals the teacher's.
        - ctx_teacher_input_ids: (B, T_max)
        - sdpo_loss_mask:        (B, T_max), 1 at post-hint error-turn tokens,
                                 `ignore_index` everywhere else (padding too)
        # Channel 3 (trace-replay DPO) — present when `enable_replay_dpo` is
        # set and at least one example carries dpo_pairs
        - dpo_chosen_input_ids:       (B', T_dpo)
        - dpo_chosen_response_mask:   (B', T_dpo)
        - dpo_rejected_input_ids:     (B', T_dpo)
        - dpo_rejected_response_mask: (B', T_dpo)
        # ref_logprobs are NOT computed here — the trainer's reference-policy
        # forward pass at training time produces them.
    """

    # Tokenizer used for both plain-text and chat-template tokenization.
    tokenizer: TokenizerLike
    # Tunables; a fresh default CollatorConfig is created per instance.
    config: CollatorConfig = field(default_factory=CollatorConfig)
def __call__(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
out: dict[str, torch.Tensor] = {}
# --- Channel 1: GRPO core fields ---
out.update(self._build_grpo_fields(batch))
# --- Channel 2: SDPO hint-distill fields ---
if self.config.enable_sdpo:
sdpo = self._build_sdpo_fields(batch)
if sdpo is not None:
out.update(sdpo)
# Reconcile student vs teacher shapes for compose_loss's
# `student_logits.shape == teacher_logits.shape` gate.
#
# CRITICAL: hint injection adds tokens IN THE MIDDLE of
# the teacher sequence (before the recovery turn). The
# recovery turn lives at teacher positions
# [hint_end .. hint_end + len(recovery)] but at student
# positions [recovery_start .. recovery_start + len(recovery)]
# where recovery_start < hint_end. Right-padding student
# to teacher length WOULD ALIAS PAD TOKENS to the
# sdpo_loss_mask region — gives a degenerate ~ln(2)
# JSD signal that LOOKS healthy but is meaningless
# (Gemini W19 R1 BLOCKER).
#
# Correct alignment requires walking turns in lock-step,
# padding student WHERE the teacher has hint tokens so
# post-hint positions land at the same indices in both.
# That reshape lives in `_build_aligned_student_for_sdpo`.
aligned = self._build_aligned_student_for_sdpo(
batch, teacher_len=out["ctx_teacher_input_ids"].shape[1]
)
if aligned is not None:
out["input_ids"] = aligned["input_ids"]
out["attention_mask"] = aligned["attention_mask"]
out["response_mask"] = aligned["response_mask"]
# --- Channel 3: trace-replay DPO fields ---
if self.config.enable_replay_dpo:
dpo = self._build_dpo_fields(batch)
if dpo is not None:
out.update(dpo)
return out
# ----------------------------------------------------------------------
# Channel 1: standard GRPO inputs
# ----------------------------------------------------------------------
def _build_grpo_fields(self, batch: Sequence[TraceExample]) -> dict[str, torch.Tensor]:
input_ids_list: list[list[int]] = []
response_masks_list: list[list[int]] = []
rewards: list[float] = []
for ex in batch:
ids, resp_mask = self._tokenize_trace(ex["turns"])
input_ids_list.append(ids)
response_masks_list.append(resp_mask)
rewards.append(float(ex.get(self.config.rlvr_reward_key, 0.0)))
max_len = min(self.config.max_seq_len, max(len(s) for s in input_ids_list))
input_ids = torch.tensor(
[_pad_or_truncate(s, max_len, self.config.pad_token_id) for s in input_ids_list],
dtype=torch.long,
)
response_mask = torch.tensor(
[_pad_or_truncate(m, max_len, 0) for m in response_masks_list],
dtype=torch.long,
)
attention_mask = (input_ids != self.config.pad_token_id).long()
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"response_mask": response_mask,
"rewards": torch.tensor(rewards, dtype=torch.float),
}
# ----------------------------------------------------------------------
# Channel 2: SDPO hint-distill inputs
# ----------------------------------------------------------------------
def _build_sdpo_fields(
    self, batch: Sequence[TraceExample]
) -> dict[str, torch.Tensor] | None:
    """Assemble the hint-injected teacher context and its loss mask.

    Every example is passed through ``_build_hint_injected_trace``. The
    resulting sequences are right-padded / right-truncated to the longest
    teacher sequence in the batch, capped at ``config.max_seq_len``. Teacher
    ids are padded with ``pad_token_id``, the mask with ``ignore_index``.

    Returns:
        ``{"ctx_teacher_input_ids", "sdpo_loss_mask"}``, or None when no
        hint generator is configured or no example produced an error site
        (SDPO is then a no-op for this step).
    """
    cfg = self.config
    if cfg.hint_generator is None:
        return None  # nothing to do without a hint generator

    per_example = [self._build_hint_injected_trace(ex["turns"]) for ex in batch]
    if not any(found for _, _, found in per_example):
        return None  # batch has no error sites

    width = min(cfg.max_seq_len, max(len(ids) for ids, _, _ in per_example))
    teacher_rows = [
        _pad_or_truncate(ids, width, cfg.pad_token_id) for ids, _, _ in per_example
    ]
    mask_rows = [
        _pad_or_truncate(mask, width, cfg.ignore_index) for _, mask, _ in per_example
    ]
    return {
        "ctx_teacher_input_ids": torch.tensor(teacher_rows, dtype=torch.long),
        "sdpo_loss_mask": torch.tensor(mask_rows, dtype=torch.long),
    }
def _build_hint_injected_trace(
    self, turns: Sequence[TraceTurn]
) -> tuple[list[int], list[int], bool]:
    """Walk the trace; at each error-turn boundary, inject a hint and mark
    the post-hint tokens as in-loss.

    For every error turn whose hint generator returns a non-empty hint, a
    ``system`` message carrying the hint is inserted directly before the
    turn's own message, and that turn's message becomes a loss span.

    The mask is computed against the SAME tokenization that produces the
    ids: the token span of message ``i`` is taken as the difference between
    the tokenized lengths of the message prefixes ``[:i]`` and ``[:i + 1]``.
    Tokenizing each content string on its own would ignore whatever
    ``_tokenize_messages`` adds around it (chat-template role markers,
    separators), shifting every loss span towards the start of the
    sequence. A loss span therefore also covers the markers that the
    tokenization attaches to that message.

    NOTE(review): this assumes tokenizing a message prefix yields a prefix
    of the full tokenization — TODO confirm for the chat template in use.
    Boundaries are clamped so a violation degrades the mask, not the run.

    Returns:
        (ctx_teacher_ids, sdpo_loss_mask, any_error_sites). The mask has
        exactly ``len(ctx_teacher_ids)`` entries: 1 on loss tokens and
        ``config.ignore_index`` everywhere else.
    """
    hint_generator = self.config.hint_generator
    if hint_generator is None:
        # Caller responsibility — short-circuited by the dispatch. Two
        # distinct lists are returned so ids and mask never alias.
        return [], [], False

    teacher_messages: list[dict] = []
    loss_flags: list[bool] = []  # parallel to teacher_messages
    any_errors = False

    for turn in turns:
        hint_text = None
        if _is_error_turn(turn):
            hint_text = hint_generator(
                turn.get("tool_error", "unknown"),
                turn.get("error_meta", {}),
            )
        if hint_text:
            any_errors = True
            # Inject hint as a system-style addendum BEFORE the turn itself.
            teacher_messages.append({"role": "system", "content": hint_text})
            loss_flags.append(False)
        content = turn.get("content")
        if content:
            teacher_messages.append({
                "role": turn.get("role", "assistant"),
                "content": content,
            })
            # Only turns that actually received a hint are loss-bearing;
            # everything else (including error turns without a hint) is
            # plain context.
            loss_flags.append(bool(hint_text))

    # Tokenize the full teacher conversation.
    teacher_ids = self._tokenize_messages(teacher_messages)
    total = len(teacher_ids)
    sdpo_mask = [self.config.ignore_index] * total

    # Token offset after the first n messages, computed lazily: only the
    # boundaries around loss spans cost an extra tokenization.
    boundary_cache: dict[int, int] = {0: 0, len(teacher_messages): total}

    def boundary(n_messages: int) -> int:
        if n_messages not in boundary_cache:
            boundary_cache[n_messages] = len(
                self._tokenize_messages(teacher_messages[:n_messages])
            )
        return boundary_cache[n_messages]

    for idx, is_loss in enumerate(loss_flags):
        if not is_loss:
            continue
        start = min(boundary(idx), total)
        end = min(max(boundary(idx + 1), start), total)
        sdpo_mask[start:end] = [1] * (end - start)

    return teacher_ids, sdpo_mask, any_errors
def _build_aligned_student_for_sdpo(
self,
batch: Sequence[TraceExample],
teacher_len: int,
) -> dict[str, torch.Tensor] | None:
"""Build student input_ids that align position-by-position with the
hint-injected teacher sequence.
For SDPO the gate `student_logits.shape == teacher_logits.shape`
must pass AND the sdpo_loss_mask positions (built relative to the
teacher) must point to the SAME content tokens in the student.
Strategy: build student MESSAGES that mirror the teacher messages
EXCEPT the hint system-message is replaced with a placeholder
system-message whose `content` tokenizes to the same length as
the hint. Both sides go through `apply_chat_template`, so the
chat-template markers (<|im_start|>system\\n, <|im_end|>\\n, etc.)
are added identically. The recovery-turn tokens then land at the
same indices in both tensors and `sdpo_loss_mask` selects
identical content positions.
Returns None if no error sites exist.
"""
if self.config.hint_generator is None:
return None
student_ids_list: list[list[int]] = []
response_mask_list: list[list[int]] = []
any_errors = False
for ex in batch:
ids, resp_mask, has_errors = self._build_aligned_student_one(ex["turns"])
student_ids_list.append(ids)
response_mask_list.append(resp_mask)
any_errors = any_errors or has_errors
if not any_errors:
return None
max_len = teacher_len # match teacher exactly
pad_id = self.config.pad_token_id
input_ids = torch.tensor(
[_pad_or_truncate(s, max_len, pad_id) for s in student_ids_list],
dtype=torch.long,
)
response_mask = torch.tensor(
[_pad_or_truncate(m, max_len, 0) for m in response_mask_list],
dtype=torch.long,
)
attention_mask = (input_ids != pad_id).long()
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"response_mask": response_mask,
}
def _make_placeholder_for_hint_length(self, hint_text: str) -> str:
"""Build a placeholder string whose tokenization length matches hint_text's.
We start with a short repeating filler ('. ') and grow it until the
tokenized length matches or exceeds the hint's. If we overshoot,
we trim. This is necessarily approximate at the character-to-token
boundary; we accept ±1 token tolerance and pad/truncate the final
student tensor to match teacher length.
"""
target_len = len(self._tokenize_text(hint_text))
if target_len == 0:
return ""
# Use a content-free placeholder that tokenizes predictably.
placeholder = ". " * target_len
ph_len = len(self._tokenize_text(placeholder))
# Trim or extend via binary-search-ish refinement (at most 6 iters).
for _ in range(6):
if ph_len == target_len:
break
if ph_len > target_len:
# Trim char-by-char
while placeholder and ph_len > target_len:
placeholder = placeholder[:-1]
ph_len = len(self._tokenize_text(placeholder))
else:
placeholder = placeholder + ". "
ph_len = len(self._tokenize_text(placeholder))
return placeholder
def _build_aligned_student_one(
    self, turns: Sequence[TraceTurn]
) -> tuple[list[int], list[int], bool]:
    """Walk one trace's turns, building a STUDENT messages list that
    mirrors the TEACHER messages list except hint system-messages are
    replaced with placeholder system-messages of the same token length.

    The response mask is computed against the SAME tokenization that
    produces the ids: the token span of message ``i`` is the difference
    between the tokenized lengths of the message prefixes ``[:i]`` and
    ``[:i + 1]``. Tokenizing each content string on its own would ignore
    whatever ``_tokenize_messages`` adds around it (chat-template role
    markers, separators) and shift every response span towards the start
    of the sequence. A response span therefore also covers the markers
    that the tokenization attaches to that message.

    NOTE(review): this assumes tokenizing a message prefix yields a prefix
    of the full tokenization — TODO confirm for the chat template in use.
    Boundaries are clamped so a violation degrades the mask, not the run.
    One extra tokenization per boundary next to an assistant message is
    required.

    Returns:
        (student_ids, response_mask, any_error_sites). The mask has exactly
        ``len(student_ids)`` entries: 1 on tokens of turns whose ``role``
        is "assistant", 0 elsewhere (placeholders included).
    """
    hint_generator = self.config.hint_generator
    if hint_generator is None:
        return [], [], False

    student_messages: list[dict] = []
    response_flags: list[bool] = []  # parallel to student_messages
    any_errors = False

    for turn in turns:
        if _is_error_turn(turn):
            hint_text = hint_generator(
                turn.get("tool_error", "unknown"),
                turn.get("error_meta", {}),
            )
            if hint_text:
                any_errors = True
                # Student gets a placeholder system-msg at the SAME slot
                # the teacher gets the hint system-msg.
                placeholder = self._make_placeholder_for_hint_length(hint_text)
                student_messages.append({"role": "system", "content": placeholder})
                response_flags.append(False)
        content = turn.get("content")
        if content:
            student_messages.append({
                "role": turn.get("role", "assistant"),
                "content": content,
            })
            response_flags.append(turn.get("role") == "assistant")

    # Tokenize the full student conversation via the same path as the
    # teacher so any chat-template markers are identical on both sides.
    student_ids = self._tokenize_messages(student_messages)
    total = len(student_ids)
    resp_mask = [0] * total

    # Token offset after the first n messages, computed lazily and cached
    # because adjacent assistant messages share a boundary.
    boundary_cache: dict[int, int] = {0: 0, len(student_messages): total}

    def boundary(n_messages: int) -> int:
        if n_messages not in boundary_cache:
            boundary_cache[n_messages] = len(
                self._tokenize_messages(student_messages[:n_messages])
            )
        return boundary_cache[n_messages]

    for idx, is_response in enumerate(response_flags):
        if not is_response:
            continue
        start = min(boundary(idx), total)
        end = min(max(boundary(idx + 1), start), total)
        resp_mask[start:end] = [1] * (end - start)

    return student_ids, resp_mask, any_errors
def _build_segment_mask(
self, segments: Sequence[tuple[bool, str]]
) -> list[int]:
"""For each (is_loss, text) segment, tokenize and emit per-token mask values.
Loss-active tokens get 1; non-loss tokens get -100 (ignore_index).
"""
out: list[int] = []
for is_loss, text in segments:
seg_ids = self._tokenize_text(text)
mask_value = 1 if is_loss else self.config.ignore_index
out.extend([mask_value] * len(seg_ids))
return out
# ----------------------------------------------------------------------
# Channel 3: trace-replay DPO inputs
# ----------------------------------------------------------------------
def _build_dpo_fields(
    self, batch: Sequence[TraceExample]
) -> dict[str, torch.Tensor] | None:
    """Turn every DPO pair in the batch into padded chosen/rejected tensors.

    For each pair the prompt (``state_messages``) is tokenized as chat
    messages and the two responses as plain text, giving:
      - chosen_input_ids   = prompt + chosen_response
      - rejected_input_ids = prompt + rejected_response
      - response masks that are 0 over the prompt and 1 over the response

    All rows are right-padded / right-truncated to the longest sequence
    among chosen and rejected, capped at ``config.max_dpo_seq_len``.

    Returns:
        The four ``dpo_*`` tensors, or None when the batch holds no pairs.
    """
    chosen_seqs: list[list[int]] = []
    rejected_seqs: list[list[int]] = []
    chosen_masks: list[list[int]] = []
    rejected_masks: list[list[int]] = []

    for ex in batch:
        for pair in ex.get("dpo_pairs") or []:
            prompt_ids = self._tokenize_messages(pair.get("state_messages", []))
            chosen_ids = self._tokenize_text(pair["chosen"])
            rejected_ids = self._tokenize_text(pair["rejected"])
            prompt_zeros = [0] * len(prompt_ids)

            chosen_seqs.append(prompt_ids + chosen_ids)
            rejected_seqs.append(prompt_ids + rejected_ids)
            chosen_masks.append(prompt_zeros + [1] * len(chosen_ids))
            rejected_masks.append(prompt_zeros + [1] * len(rejected_ids))

    if not chosen_seqs:
        return None  # no DPO pairs in this batch

    longest = max(len(seq) for seq in chosen_seqs + rejected_seqs)
    width = min(self.config.max_dpo_seq_len, longest)
    pad_id = self.config.pad_token_id

    def as_tensor(rows: list[list[int]], fill: int) -> torch.Tensor:
        # Pad/truncate every row to the shared width, then stack.
        return torch.tensor(
            [_pad_or_truncate(row, width, fill) for row in rows],
            dtype=torch.long,
        )

    return {
        "dpo_chosen_input_ids": as_tensor(chosen_seqs, pad_id),
        "dpo_chosen_response_mask": as_tensor(chosen_masks, 0),
        "dpo_rejected_input_ids": as_tensor(rejected_seqs, pad_id),
        "dpo_rejected_response_mask": as_tensor(rejected_masks, 0),
    }
# ----------------------------------------------------------------------
# Tokenization helpers
# ----------------------------------------------------------------------
def _tokenize_trace(self, turns: Sequence[TraceTurn]) -> tuple[list[int], list[int]]:
"""Tokenize an entire trace; return (ids, response_mask).
response_mask = 1 over assistant turns (those are the loss-bearing tokens
for GRPO), 0 over user/tool turns (prompt context).
"""
all_ids: list[int] = []
resp_mask: list[int] = []
for turn in turns:
if not turn.get("content"):
continue
ids = self._tokenize_text(turn["content"])
mask_value = 1 if turn.get("role") == "assistant" else 0
all_ids.extend(ids)
resp_mask.extend([mask_value] * len(ids))
return all_ids, resp_mask
def _tokenize_text(self, text: str) -> list[int]:
"""Tokenize plain text via the tokenizer's __call__."""
result = self.tokenizer(text, add_special_tokens=False)
ids = result["input_ids"]
if hasattr(ids, "tolist"):
ids = ids.tolist()
# HF tokenizers often return list[list[int]] when batch-shaped; flatten if so
if ids and isinstance(ids[0], list):
ids = ids[0]
return list(ids)
def _tokenize_messages(self, messages: Sequence[dict]) -> list[int]:
"""Tokenize a chat-formatted list of messages.
Tries apply_chat_template first; falls back to concatenated content if not available.
NOTE: HF tokenizers' `apply_chat_template(tokenize=True)` is not
consistently typed across families. Some return `list[int]`, others
a `BatchEncoding` (a dict-like with `input_ids` key) — Qwen2.5
returns the latter. Handle both shapes here.
"""
if not messages:
return []
try:
raw = self.tokenizer.apply_chat_template(
list(messages), tokenize=True, add_generation_prompt=False
)
except (AttributeError, NotImplementedError, TypeError):
# Stub tokenizer or no chat template defined — fall back to concatenated content
text = "\n".join(m.get("content", "") for m in messages)
return self._tokenize_text(text)
# BatchEncoding (Qwen2.5 etc.): extract input_ids and unwrap if batched.
if hasattr(raw, "keys") and "input_ids" in raw:
ids = raw["input_ids"]
else:
ids = raw
if hasattr(ids, "tolist"):
ids = ids.tolist()
# If we got list[list[int]] (batch shape), unwrap the single example.
if ids and isinstance(ids[0], list):
ids = ids[0]
return list(ids)
# Public API of this module. The underscore-prefixed helpers defined above
# are intentionally not listed.
__all__ = [
    "ComposerDataCollator",
    "CollatorConfig",
    "TraceTurn",
    "TraceExample",
    "TokenizerLike",
]