File size: 10,187 Bytes
b74674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Local recursive RLM runner for repl_env.

This keeps the iterative prompting/orchestration layer outside the environment,
following the same separation used by the official RLM implementation and DSPy:
- `REPLEnvironment` executes code and exposes tools
- `LocalRLMRunner` owns prompting, message history, and recursive child runs
"""

from __future__ import annotations

import re
import time
from dataclasses import dataclass
from typing import Callable

from .local import LocalREPLEnv
from .prompts import (
    build_rlm_system_prompt,
    build_user_prompt,
    extract_code_blocks,
    format_observations,
    QueryMetadata,
    RLM_SYSTEM_PROMPT,
)
from .recursive_backends import BackendLimits, LocalChildRLMBackend, RecursiveBackend


ChatFn = Callable[..., str]


@dataclass
class RLMRunResult:
    """Outcome of a single `LocalRLMRunner.run()` invocation."""

    # Answer extracted from FINAL(...)/FINAL_VAR(...), an error string on
    # timeout, or None when no answer could be produced at all.
    final_answer: str | None
    # Full chat transcript (role/content dicts) accumulated during the run.
    messages: list[dict[str, str]]
    # Number of REPL iterations actually consumed by the run.
    iterations: int
    # Recursion depth of the runner that produced this result (0 = root).
    depth: int
    # Traces gathered from recursive child runs, as exposed by the backend's
    # `child_traces` attribute (empty when the backend records none).
    child_traces: list[object]


class LocalRLMRunner:
    """Local recursive RLM orchestrator built on top of LocalREPLEnv.

    Owns the prompting loop: it builds the system/user message history,
    calls the chat LLM, executes any ```repl``` code blocks it emits inside a
    `LocalREPLEnv`, feeds the observations back, and recurses into child
    runners (via a `RecursiveBackend`) up to `max_depth`. The environment
    only executes code; all conversation state lives here.
    """

    def __init__(
        self,
        llm_chat_fn: ChatFn,
        *,
        system_prompt: str = RLM_SYSTEM_PROMPT,
        max_iterations: int = 30,
        max_depth: int = 2,
        depth: int = 0,
        env_max_iterations_multiplier: int = 5,
        max_batch_workers: int = 8,
        backend_factory: Callable[..., RecursiveBackend] | None = None,
        max_children_total: int | None = None,
        max_children_per_batch: int | None = None,
        result_truncation_limit: int | None = None,
        per_child_timeout_s: float | None = None,
        on_subcall_start: Callable[[int, str, str], None] | None = None,
        on_subcall_complete: Callable[[int, str, float, str | None], None]
        | None = None,
        verbose: bool = False,
    ) -> None:
        """Store configuration; no I/O or environment setup happens here.

        Args:
            llm_chat_fn: Chat callable; invoked as ``fn(messages, model)`` or,
                if that raises TypeError, as ``fn(messages)`` (see `_chat`).
            system_prompt: System prompt passed to `build_rlm_system_prompt`.
            max_iterations: Max chat/execute iterations in `run`.
            max_depth: Max recursion depth enforced via `BackendLimits`.
            depth: This runner's own depth (0 = root; children get depth+1
                via the backend's runner_factory).
            env_max_iterations_multiplier: The REPL env's iteration budget is
                ``max_iterations * env_max_iterations_multiplier``.
            max_batch_workers: Worker cap for batched child queries.
            backend_factory: Overrides `_default_backend_factory` when given.
            max_children_total / max_children_per_batch: Child-count limits
                forwarded into `BackendLimits` (None = unlimited).
            result_truncation_limit: Child-result truncation forwarded into
                `BackendLimits` (None = no truncation).
            per_child_timeout_s: Per-child cooperative timeout, forwarded both
                into `BackendLimits` and (by the backend) into child `run`s.
            on_subcall_start / on_subcall_complete: Optional observability
                hooks passed through to the backend.
            verbose: When True, `run` prints a per-iteration progress line.
        """
        self.llm_chat_fn = llm_chat_fn
        self.system_prompt = system_prompt
        self.max_iterations = max_iterations
        self.max_depth = max_depth
        self.depth = depth
        self.env_max_iterations_multiplier = env_max_iterations_multiplier
        self.max_batch_workers = max_batch_workers
        # Fall back to the built-in LocalChildRLMBackend factory when the
        # caller does not inject one.
        self.backend_factory = backend_factory or self._default_backend_factory
        self.max_children_total = max_children_total
        self.max_children_per_batch = max_children_per_batch
        self.result_truncation_limit = result_truncation_limit
        self.per_child_timeout_s = per_child_timeout_s
        self.on_subcall_start = on_subcall_start
        self.on_subcall_complete = on_subcall_complete
        self.verbose = verbose

    def _default_backend_factory(
        self, llm_chat_fn: ChatFn, **kwargs
    ) -> RecursiveBackend:
        """Build a `LocalChildRLMBackend` wired to spawn `LocalRLMRunner` children.

        The limit-related settings come from this runner's own attributes;
        the per-run settings (system prompt, iteration budget, depth) come
        from ``kwargs`` as supplied by `run` — the required keys are
        ``system_prompt``, ``max_iterations``,
        ``env_max_iterations_multiplier`` and ``depth``.
        """
        limits = BackendLimits(
            max_depth=self.max_depth,
            max_batch_workers=self.max_batch_workers,
            max_children_total=self.max_children_total,
            max_children_per_batch=self.max_children_per_batch,
            result_truncation_limit=self.result_truncation_limit,
            per_child_timeout_s=self.per_child_timeout_s,
        )
        return LocalChildRLMBackend(
            llm_chat_fn,
            runner_factory=LocalRLMRunner,
            system_prompt=kwargs["system_prompt"],
            max_iterations=kwargs["max_iterations"],
            env_max_iterations_multiplier=kwargs["env_max_iterations_multiplier"],
            depth=kwargs["depth"],
            limits=limits,
            on_subcall_start=self.on_subcall_start,
            on_subcall_complete=self.on_subcall_complete,
        )

    def run(
        self,
        context: str,
        task_prompt: str,
        *,
        model: str | None = None,
        timeout_s: float | None = None,
    ) -> RLMRunResult:
        """Run the full chat → execute → observe loop until FINAL or exhaustion.

        Args:
            context: Context string loaded into the REPL environment.
            task_prompt: The user task; also re-embedded in each iteration's
                user prompt via `build_user_prompt`.
            model: Optional model name forwarded to the env and to `_chat`.
            timeout_s: Cooperative wall-clock budget; checked only at the top
                of each iteration, so a slow chat call or code execution can
                overrun it before being noticed.

        Returns:
            An `RLMRunResult`. On timeout, `final_answer` is an
            ``"Error: child timeout after ...s"`` string rather than a real
            answer; on iteration exhaustion a one-shot `_default_answer`
            call is attempted first.
        """
        backend = self.backend_factory(
            self.llm_chat_fn,
            system_prompt=self.system_prompt,
            max_iterations=self.max_iterations,
            max_depth=self.max_depth,
            depth=self.depth,
            env_max_iterations_multiplier=self.env_max_iterations_multiplier,
        )
        # The env closes (and its subprocess/resources release) when this
        # `with` block exits, including on the early returns below.
        with LocalREPLEnv(
            llm_query_fn=backend.query,
            llm_batch_fn=backend.query_batched,
            subcall_fn=backend.recursive_query,
            subcall_batch_fn=backend.recursive_query_batched,
        ) as env:
            result = env.reset(
                context=context,
                task_prompt=task_prompt,
                # Env gets a larger budget than the chat loop: one chat turn
                # may emit several code blocks, each costing an env step.
                max_iterations=self.max_iterations * self.env_max_iterations_multiplier,
                llm_model=model,
            )
            obs = result.observation

            # Metadata about the (single, string) context shown to the model
            # in the system prompt.
            query_metadata = QueryMetadata(
                context_lengths=[obs.context_length],
                context_total_length=obs.context_length,
                context_type="str",
            )
            messages = build_rlm_system_prompt(self.system_prompt, query_metadata)
            messages.append(build_user_prompt(root_prompt=task_prompt, iteration=0))

            run_start = time.perf_counter()

            for iteration in range(1, self.max_iterations + 1):
                # Cooperative timeout check (matches official RLM pattern)
                if timeout_s is not None:
                    elapsed = time.perf_counter() - run_start
                    if elapsed >= timeout_s:
                        return RLMRunResult(
                            # Error string doubles as the "answer" so parents
                            # surface the timeout instead of None.
                            final_answer=f"Error: child timeout after {elapsed:.3f}s",
                            messages=messages,
                            # This iteration never ran; count completed ones.
                            iterations=iteration - 1,
                            depth=self.depth,
                            # getattr: custom backends may not record traces.
                            child_traces=list(getattr(backend, "child_traces", [])),
                        )

                response = self._chat(messages, model)
                code_blocks = extract_code_blocks(response)
                code_block_observations = []

                if self.verbose:
                    print(
                        f"[depth={self.depth}] iteration={iteration} code_blocks={len(code_blocks)}"
                    )

                # No code emitted: nudge the model and burn the iteration.
                if not code_blocks:
                    messages.append({"role": "assistant", "content": response})
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                "Please continue by writing Python code in ```repl``` blocks, "
                                "or submit the final answer with FINAL(...) / FINAL_VAR(...)."
                            ),
                        }
                    )
                    continue

                for code in code_blocks:
                    # NOTE: rebinds `result` from the reset() result above;
                    # only `.observation` is used afterwards.
                    result = env.execute(code)
                    code_block_observations.append(result.observation)

                # Check for FINAL after all blocks executed (matches official RLM).
                # The model expects all blocks to run — it often writes exploration
                # code first and FINAL last in the same response.
                if any(obs.done for obs in code_block_observations):
                    return RLMRunResult(
                        final_answer=env.state().final_answer,
                        # `messages +` builds a copy: the final assistant turn
                        # is included in the returned transcript without
                        # mutating the in-loop list.
                        messages=messages
                        + [{"role": "assistant", "content": response}],
                        iterations=iteration,
                        depth=self.depth,
                        child_traces=list(getattr(backend, "child_traces", [])),
                    )

                # Not done: fold observations + a fresh task reminder into the
                # next user turn.
                observation_text = format_observations(
                    code_block_observations, code_blocks=code_blocks
                )
                next_prompt = build_user_prompt(
                    root_prompt=task_prompt,
                    iteration=iteration,
                )
                messages.append({"role": "assistant", "content": response})
                messages.append(
                    {
                        "role": "user",
                        "content": observation_text + "\n\n" + next_prompt["content"],
                    }
                )

            # Max iterations exhausted — give the model one final chance to answer
            final_answer = env.state().final_answer
            if final_answer is None:
                final_answer = self._default_answer(messages, model)

            return RLMRunResult(
                final_answer=final_answer,
                messages=messages,
                iterations=self.max_iterations,
                depth=self.depth,
                child_traces=list(getattr(backend, "child_traces", [])),
            )

    def _default_answer(
        self, messages: list[dict[str, str]], model: str | None = None
    ) -> str | None:
        """Make one final LLM call asking for an answer when iterations are exhausted."""
        # `messages +` leaves the caller's transcript untouched.
        final_prompt = messages + [
            {
                "role": "user",
                "content": (
                    "You have run out of REPL iterations. Based on all your work above, "
                    "provide your best final answer now. Use FINAL(your answer) to submit it. "
                    "If you stored the answer in a variable, use FINAL_VAR(variable_name) instead. "
                    "Do not write any more code — just provide the final answer."
                ),
            }
        ]
        try:
            response = self._chat(final_prompt, model)
            # Try to extract FINAL(...) from the response.
            # NOTE(review): non-greedy match stops at the FIRST ')', so an
            # answer containing parentheses gets truncated — confirm whether
            # that matters for expected answer formats.
            match = re.search(r"FINAL\((.*?)\)", response, re.DOTALL)
            if match:
                return match.group(1).strip()
            # If no FINAL pattern, return the raw response as best-effort
            return response.strip() if response.strip() else None
        except Exception:
            # Best-effort by design: a failed last-chance call yields None
            # rather than crashing the run.
            return None

    def _chat(self, messages: list[dict[str, str]], model: str | None = None) -> str:
        """Call the chat fn, tolerating callables that take no `model` arg.

        NOTE(review): a TypeError raised *inside* llm_chat_fn (not from the
        call signature) also triggers the single-argument retry — confirm
        acceptable for the chat fns used in practice.
        """
        try:
            return self.llm_chat_fn(messages, model)
        except TypeError:
            return self.llm_chat_fn(messages)