Codeseys's picture
feat(wave-a): close ADR-011 (SDPO alignment indices) + ADR-012 (review findings)
d02d724
"""hint_generator.py — Template-based hint generator (v0.1 starter).
Composer 2.5 inserts text hints at error-turn sites:
"Reminder: Available tools are: …" (when a tool-call refs a non-existent tool)
"Reminder: tool arguments must be valid JSON" (on JSONDecodeError)
... etc.
This module provides a registry of hint templates keyed by error_kind. The
data collator (in trl_path/data_collator.py) calls dispatch(error_kind, ctx)
to get the hint text to splice into ctx_teacher.
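v0.2 will replace these templates with an LLM-driven hint generator (likely
Sonnet 4.6 or Opus 4.7 via OpenRouter) for cases where templates are too rigid
(style violations, wasteful explanations). A first, optional version of that
generator already ships in this module as LLMJudgeHintGenerator, which
default_composite() tries after the templates rather than instead of them.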
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TypedDict
class HintContext(TypedDict, total=False):
    """Per-error context the hint generator can use.

    Declared with ``total=False``, so every key is optional. The templates in
    this module read keys with ``.get`` and fall back to generic wording when
    a key is absent.
    """

    error_kind: str  # e.g. "tool_not_found", "json_decode", "type_error"
    error_message: str  # raw error from the env
    available_tools: list[str]  # for tool_not_found
    tool_name: str  # the failing tool, if known
    tool_schema: dict  # the schema, if known
    intent: str  # student's apparent intent, if extractable
# ---------------------------------------------------------------------------
# Hint templates
# ---------------------------------------------------------------------------
def hint_tool_not_found(ctx: HintContext) -> str:
    """Build the hint for a call to a tool that does not exist.

    When ``ctx["available_tools"]`` is present and non-empty, every tool name
    is listed in backticks; otherwise a generic reminder is returned.
    """
    known_tools = ctx.get("available_tools", [])
    if not known_tools:
        return "Reminder: the tool you tried to call does not exist. Use only available tools."
    rendered = ", ".join([f"`{tool}`" for tool in known_tools])
    return f"Reminder: Available tools are: {rendered}. Please use one of these."
def hint_json_decode(ctx: HintContext) -> str:
    """Build the hint for tool arguments that failed to parse as JSON.

    The text is fixed; ``ctx`` is accepted only so the function has the same
    signature as the other templates and is otherwise ignored.
    """
    reminder = "Reminder: tool arguments must be valid JSON. Common mistakes: "
    common_mistakes = (
        "single quotes (use double), trailing commas, unescaped newlines in strings."
    )
    return reminder + common_mistakes
def hint_type_error(ctx: HintContext) -> str:
    """Build the hint for arguments that do not match a tool's expected types.

    When both ``tool_name`` and ``tool_schema`` are present and truthy, the
    schema is embedded in the hint (rendered with its default string form);
    otherwise a generic reminder is returned.
    """
    tool, spec = ctx.get("tool_name"), ctx.get("tool_schema")
    if not (tool and spec):
        return "Reminder: tool arguments do not match the expected types. Check the schema."
    heading = f"Reminder: `{tool}` expects arguments matching this schema:\n"
    schema_line = f" {spec}\n"
    closing = "Re-issue the call with arguments matching the schema."
    return heading + schema_line + closing
def hint_runtime_error(ctx: HintContext) -> str:
    """Build the hint for a tool call that raised at runtime.

    The raw ``error_message`` from ``ctx`` is quoted in the hint. If the key
    is missing, or its value is empty / ``None`` / whitespace-only, the
    generic phrase "an exception" is used instead so the sentence never reads
    "raised . " or "raised None".
    """
    raw = ctx.get("error_message")
    # A dict default only covers a *missing* key; a present-but-blank value
    # must fall back too. Surrounding whitespace (e.g. a trailing newline from
    # a traceback) is stripped so the sentence stays on one line.
    msg = str(raw).strip() if raw else ""
    if not msg:
        msg = "an exception"
    return (
        f"Reminder: the previous tool call raised {msg}. "
        "Reconsider the inputs or read the relevant code first to understand state."
    )
def hint_repeated_failure(ctx: HintContext) -> str:
    """Triggered when the same kind of error happens 3+ times in a row.

    The text is fixed and ``ctx`` is not consulted.
    """
    sentences = (
        "Reminder: this approach has failed multiple times. ",
        "Step back and consider an alternative approach: read more files, ",
        "search for similar patterns elsewhere, or break the task down differently.",
    )
    return "".join(sentences)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Maps an error_kind string to the template function that renders its hint.
# dispatch() looks kinds up here; register() adds or overwrites entries, so
# this dict is mutable module state shared by every caller.
HINT_TEMPLATES: dict[str, Callable[[HintContext], str]] = {
    "tool_not_found": hint_tool_not_found,
    "json_decode": hint_json_decode,
    "type_error": hint_type_error,
    "runtime_error": hint_runtime_error,
    "repeated_failure": hint_repeated_failure,
}
def dispatch(error_kind: str, ctx: HintContext | None = None) -> str | None:
    """Render the hint registered for ``error_kind``.

    Returns None when no template is registered for the kind. A missing or
    empty ``ctx`` is passed to the template as an empty dict.
    """
    template = HINT_TEMPLATES.get(error_kind)
    if template is not None:
        context = ctx if ctx else {}
        return template(context)
    return None
def register(error_kind: str, fn: Callable[[HintContext], str]) -> None:
    """Add a custom hint template.

    Stores ``fn`` in the module-level ``HINT_TEMPLATES`` dict under
    ``error_kind``. An existing entry for the same kind (including a built-in
    one) is replaced without warning.
    """
    HINT_TEMPLATES[error_kind] = fn
# ===========================================================================
# Layered HintGenerator architecture (ADR-009)
# ===========================================================================
#
# Composer 2.5 inserts a natural-language hint at each error turn; the
# hint-conditioned forward becomes the SDPO teacher. HOW Cursor generates the
# hint is unstated in every Cursor artifact (both blogs + the Composer 2 tech
# report, arXiv:2603.24477 — confirmed absent in research/10). So this is our
# design problem. The cited papers bracket the answer: OPSD conditions the
# teacher on ground-truth; SDPO generalizes to environment feedback and the
# "successful sibling rollout as implicit feedback" trick.
#
# We implement a layered generator, tried cheapest-first:
# 1. TemplateHintGenerator — the registry above (free, deterministic;
# covers tool-error classes). The first layer.
# 2. RawErrorHintGenerator — wrap the raw env/tool error text as the hint
# (free; covers any error with a message but unmatched by a template;
# default_composite() routes only tool/runtime error kinds to it).
# 3. LLMJudgeHintGenerator — an LLM produces a <=2-sentence corrective hint
# (cost ~$0.0005/site; covers style/communication/effort sites templates
# can't). Cached on disk; optional; OFF unless a client is provided.
# 4. (sibling-bootstrap) — RL-rollout-path only; not a HintContext-driven
# layer (needs sibling rollouts), exposed as a flag for the trainer to use.
#
# All layers satisfy the HintGenerator Protocol and compose via
# CompositeHintGenerator, whose .as_collator_hook() returns a callable matching
# the collator's existing `hint_generator: Callable[[str, dict], str | None]`
# hook — ZERO collator change.
from typing import Protocol, runtime_checkable
@runtime_checkable
class HintGenerator(Protocol):
    """A hint source. Returns hint text for an error context, or None to defer
    to the next layer.

    Structural interface: any object with a compatible ``generate`` method
    qualifies, no inheritance needed. ``@runtime_checkable`` allows
    ``isinstance`` checks, which only verify that the method exists, not its
    signature.
    """

    def generate(self, error_kind: str, error_meta: dict) -> str | None: ...
class TemplateHintGenerator:
    """Layer 1: the existing template registry. Free, deterministic.

    Delegates to the module-level `dispatch()`, so kinds without a registered
    template yield None and existing callers and tests see no change.
    """

    def generate(self, error_kind: str, error_meta: dict) -> str | None:
        """Render the registered template for ``error_kind``, or return None."""
        # Work on a shallow copy so the caller's dict is never mutated, and
        # make sure templates that read `error_kind` can find it without
        # overriding a value the caller already supplied.
        context: HintContext = dict(error_meta)  # type: ignore[assignment]
        if "error_kind" not in context:
            context["error_kind"] = error_kind
        return dispatch(error_kind, context)
class RawErrorHintGenerator:
    """Layer 2: use the raw env/tool error text itself as the hint.

    Covers any error site that carries a message but isn't matched by a
    template. Free. SDPO's "environment feedback as the conditioning signal"
    (arXiv:2601.20802) — the rawest form of that.

    Args:
        max_chars: Maximum number of characters of the error text to embed in
            the hint; longer messages are cut off at this length.
    """

    # Fields consulted for the error text, in priority order.
    _MESSAGE_FIELDS = ("error_message", "error")

    def __init__(self, max_chars: int = 500) -> None:
        self.max_chars = max_chars

    def generate(self, error_kind: str, error_meta: dict) -> str | None:
        """Wrap the first non-blank error text found in ``error_meta``.

        ``error_kind`` is not used by this layer. Returns None when neither
        ``error_message`` nor ``error`` holds any non-whitespace text, so the
        next layer gets a chance.
        """
        for field in self._MESSAGE_FIELDS:
            value = error_meta.get(field)
            if not value:
                continue
            # Strip before deciding: a whitespace-only `error_message` is
            # truthy and would otherwise hide a usable `error` field.
            text = str(value).strip()
            if text:
                truncated = text[: self.max_chars]
                return (
                    "Reminder: the previous action produced this error:\n"
                    f"{truncated}\nReconsider and retry."
                )
        return None
# ---------------------------------------------------------------------------
# Error-kind routing (ADR-012 finding #2)
# ---------------------------------------------------------------------------
#
# The default composite is template -> raw-error -> judge. The raw-error layer
# fires for ANY kind carrying a message — including style/communication/effort
# sites, which are EXACTLY what the LLM judge exists to cover. So we route:
# tool/runtime error kinds may use the raw-error layer; style/communication/
# effort kinds skip it and fall through to the judge.
# Error kinds that genuinely describe a tool/runtime failure whose raw text is a
# useful, self-contained hint. The explicit registry-template kinds are included
# so behavior is unchanged for them.
_TOOL_RUNTIME_KINDS: frozenset[str] = frozenset(
    (
        "tool_not_found",
        "json_decode",
        "type_error",
        "runtime_error",
        "repeated_failure",
    )
)

# Substrings that mark a kind as tool/runtime-ish even when it is not listed
# explicitly, so generic "*_error" / "*_exception" sites keep flowing through
# the raw-error layer, where their raw text belongs.
_TOOL_RUNTIME_MARKERS: tuple[str, ...] = (
    "error",
    "exception",
    "fail",
    "decode",
    "timeout",
    "traceback",
    "exit_code",
    "nonzero",
    "syntax",
    "import",
    "assertion",
    "tool",
    "runtime",
    "crash",
    "exec",
)

# Substrings that mark a kind as a style/communication/effort site, which is
# the judge's domain. These win over everything else: a kind containing any of
# them skips the raw-error layer.
_STYLE_KINDS_MARKERS: tuple[str, ...] = (
    "style",
    "communic",
    "verbose",
    "effort",
    "concise",
    "tone",
    "format",
    "wordy",
    "rambl",
    "explanation",
    "etiquette",
    "clarity",
)


def is_tool_runtime_kind(error_kind: str) -> bool:
    """Decide whether the raw-error layer may serve ``error_kind``.

    Matching is case-insensitive and substring-based. Style markers are
    checked first and force False (-> judge); otherwise the kind qualifies if
    it is one of the explicit tool/runtime kinds or contains a tool/runtime
    marker. Empty or None kinds return False.
    """
    kind = error_kind.lower() if error_kind else ""
    for marker in _STYLE_KINDS_MARKERS:
        if marker in kind:
            return False
    if kind in _TOOL_RUNTIME_KINDS:
        return True
    for marker in _TOOL_RUNTIME_MARKERS:
        if marker in kind:
            return True
    return False
class RoutingHintGenerator:
    """Gate an inner layer behind an error-kind predicate.

    Used to wrap the raw-error layer so it only fires for tool/runtime error
    kinds. For any kind the predicate rejects (style/communication/effort
    kinds with the default predicate) this layer returns None, letting the
    composite fall through to the judge (ADR-012 finding #2).
    """

    def __init__(self, inner: HintGenerator, route=is_tool_runtime_kind) -> None:
        # `route` is a callable taking the error kind; a truthy result means
        # "let the inner layer handle it".
        self.inner = inner
        self.route = route

    def generate(self, error_kind: str, error_meta: dict) -> str | None:
        """Delegate to the inner layer if the kind is routed, else return None."""
        routed = self.route(error_kind)
        return self.inner.generate(error_kind, error_meta) if routed else None
class LLMJudgeHintGenerator:
    """Layer 3: an LLM produces a short corrective hint.

    Covers style/communication/effort sites that templates can't. Optional and
    OFF unless a `complete` callable is provided. Results are cached in memory
    and, when `cache_dir` is set, on disk, keyed on a hash of the error
    context (so repeated identical sites cost nothing after the first).

    `complete(prompt: str) -> str` is an injected text-completion callable
    (e.g. an OpenRouter chat wrapper). Kept abstract so this module has no hard
    network dependency and is unit-testable with a stub.
    """

    PROMPT_TEMPLATE = (
        "An autonomous coding agent made a mistake at one step of a trajectory. "
        "Write a SHORT (<=2 sentences) corrective hint that, if the agent had "
        "seen it, would steer it to the right behavior for THIS step only. Do "
        "not solve the whole task; just correct the local mistake.\n\n"
        "Error kind: {error_kind}\n"
        "Error / context:\n{error_message}\n\n"
        "Corrective hint:"
    )

    # Bump when PROMPT_TEMPLATE or the underlying judge model changes so stale
    # cached hints are invalidated rather than silently reused.
    _CACHE_VERSION = 2

    # Hard cap on a generated hint, INCLUDING the trailing ellipsis added when
    # a hint is cut. The judge is asked for <=2 sentences but nothing enforced
    # it (cross-family review 2026-05-29) — a runaway judge could emit a full
    # solution / prompt-leak / megabyte of text straight into the SDPO teacher
    # conditioning. Clamp defensively.
    _MAX_HINT_CHARS = 600

    def __init__(
        self,
        complete: Callable[[str], str] | None = None,
        *,
        cache_dir: str | None = None,
    ) -> None:
        """Create the judge layer.

        Args:
            complete: Text-completion callable; None disables the layer.
            cache_dir: Directory for the on-disk cache; None/empty keeps the
                cache in memory only.
        """
        self.complete = complete
        self._cache_dir = cache_dir
        self._mem_cache: dict[str, str] = {}

    def _cache_key(self, error_kind: str, error_meta: dict) -> str:
        """Return a 32-hex-char key derived from the kind and full context."""
        import hashlib
        import json
        import re

        # Strip volatile object reprs (e.g. "<Exception at 0x7f8b...>") so the
        # key is stable across runs/restarts. Cross-family review 2026-05-29:
        # `default=str` on raw Exception/context objects embedded a memory
        # address in the key, guaranteeing a 0% cross-process cache-hit rate and
        # unbounded judge cost. Also version the key so prompt/model changes
        # invalidate stale hints rather than serving them.
        blob = json.dumps(
            {"v": self._CACHE_VERSION, "k": error_kind, "m": error_meta},
            sort_keys=True,
            default=str,
        )
        blob = re.sub(r"0x[0-9a-fA-F]+", "0xADDR", blob)
        blob = re.sub(r"\bat 0xADDR\b", "", blob)
        return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:32]

    def _disk_get(self, key: str) -> str | None:
        """Read a cached hint from disk; None on a miss or when disk is off."""
        if not self._cache_dir:
            return None
        from pathlib import Path

        # EAFP instead of exists()+read: the file can disappear between the
        # two calls (e.g. a cache cleanup running alongside training).
        try:
            text = (Path(self._cache_dir) / f"{key}.txt").read_text(encoding="utf-8")
        except (FileNotFoundError, NotADirectoryError):
            return None
        # A blank entry is not a usable hint; treat it as a miss so an empty
        # string is never served as the hint for this site.
        return text if text.strip() else None

    def _disk_put(self, key: str, value: str) -> None:
        """Write a hint to the disk cache (no-op when disk caching is off)."""
        if not self._cache_dir:
            return
        import os
        from pathlib import Path

        d = Path(self._cache_dir)
        d.mkdir(parents=True, exist_ok=True)
        # Atomic write: concurrent DDP workers writing the same key would
        # otherwise interleave and corrupt the file (cross-family review).
        tmp = d / f"{key}.txt.{os.getpid()}.tmp"
        tmp.write_text(value, encoding="utf-8")
        os.replace(tmp, d / f"{key}.txt")

    def generate(self, error_kind: str, error_meta: dict) -> str | None:
        """Return a judge-written hint, or None to defer.

        Returns None when no `complete` callable was provided or when the
        judge produces no text. Otherwise the hint comes from the memory
        cache, then the disk cache, then a fresh `complete` call whose result
        is stripped, clamped to `_MAX_HINT_CHARS`, and stored in both caches.
        Exceptions raised by `complete` propagate to the caller.
        """
        if self.complete is None:
            return None  # judge disabled — defer
        key = self._cache_key(error_kind, error_meta)
        if key in self._mem_cache:
            return self._mem_cache[key]
        cached = self._disk_get(key)
        if cached is not None:
            self._mem_cache[key] = cached
            return cached
        message = (
            error_meta.get("error_message")
            or error_meta.get("error")
            or "(no message)"
        )
        prompt = self.PROMPT_TEMPLATE.format(
            error_kind=error_kind,
            error_message=str(message)[:1000],
        )
        # Tolerate a client that returns None for "no answer".
        hint = (self.complete(prompt) or "").strip()
        if not hint:
            return None
        # Clamp to a sane length so a runaway judge can't inject a full solution
        # or megabyte blob into the SDPO teacher conditioning (cross-family
        # review). Reserve one character for the ellipsis so the final string
        # never exceeds the cap.
        cap = self._MAX_HINT_CHARS
        if len(hint) > cap:
            hint = hint[: max(cap - 1, 0)].rstrip() + "…"
        self._mem_cache[key] = hint
        self._disk_put(key, hint)
        return hint
class CompositeHintGenerator:
    """Chain of hint layers; the first layer to return a non-None hint wins.

    Layers are meant to be ordered cost-ascending: templates (free) -> raw
    error (free) -> LLM judge (paid, optional), so the common tool-error case
    never reaches the LLM. Layers after the winning one are not called.
    """

    def __init__(self, layers: list[HintGenerator]) -> None:
        self.layers = layers

    def generate(self, error_kind: str, error_meta: dict) -> str | None:
        """Ask each layer in order and return the first non-None hint."""
        # The generator is lazy, so a layer is only invoked if every earlier
        # layer deferred.
        answers = (layer.generate(error_kind, error_meta) for layer in self.layers)
        return next((answer for answer in answers if answer is not None), None)

    def as_collator_hook(self) -> Callable[[str, dict], str | None]:
        """Expose `generate` as a plain callable.

        The result matches CollatorConfig.hint_generator's signature
        (error_kind, error_meta) -> str | None. ZERO collator change.
        """
        return self.generate
def default_composite(
    *,
    llm_complete: Callable[[str], str] | None = None,
    cache_dir: str | None = None,
    enable_raw_error: bool = True,
) -> CompositeHintGenerator:
    """Assemble the recommended stack: templates -> raw-error -> judge.

    The template layer is always present. The raw-error layer is added unless
    `enable_raw_error` is False, and is wrapped in a RoutingHintGenerator so it
    only fires for tool/runtime error kinds; style/communication/effort kinds
    skip it and fall through to the LLM judge (ADR-012 finding #2). The judge
    layer is added only when `llm_complete` is provided; `cache_dir` is passed
    through to it.
    """
    stack: list[HintGenerator] = [TemplateHintGenerator()]
    if enable_raw_error:
        stack += [RoutingHintGenerator(RawErrorHintGenerator())]
    if llm_complete is not None:
        stack += [LLMJudgeHintGenerator(llm_complete, cache_dir=cache_dir)]
    return CompositeHintGenerator(stack)
# Public API of this module; limits what `from <module> import *` exports.
__all__ = [
    # Template registry
    "dispatch",
    "register",
    "HintContext",
    "HINT_TEMPLATES",
    # Layered architecture (ADR-009)
    "HintGenerator",
    "TemplateHintGenerator",
    "RawErrorHintGenerator",
    "RoutingHintGenerator",
    "is_tool_runtime_kind",
    "LLMJudgeHintGenerator",
    "CompositeHintGenerator",
    "default_composite",
]