File size: 9,173 Bytes
b74674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Recursive backend abstractions for repl_env.

This module keeps direct LM calls and recursive child spawning out of the
runner and environment. The runner owns the iterative loop; the backend owns
query/query_batched/child-recursion behavior.
"""

from __future__ import annotations

import threading
import time
from concurrent.futures import as_completed, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Callable, Protocol


# Chat-completion callable. DirectLMBackend.query invokes it with a list of
# message dicts ({"role": "user", "content": prompt}) followed by a model
# name, and retries with the message list alone if that call raises
# TypeError. The return value is treated as the response text.
ChatFn = Callable[..., str]


class RecursiveBackend(Protocol):
    """Structural interface shared by the query backends in this module.

    An implementation exposes its recursion position (``depth``), the depth
    ceiling it was configured with (``max_depth``), and the ``ChildTrace``
    records it has collected, together with single-prompt and batched
    variants of a plain query and a recursive query. Every method accepts an
    optional ``model`` name.
    """

    max_depth: int
    depth: int
    child_traces: list["ChildTrace"]

    def query(self, prompt: str, model: str | None = None) -> str:
        """Take one prompt and return one response string."""

    def query_batched(
        self,
        prompts: list[str],
        model: str | None = None,
    ) -> list[str]:
        """Take a list of prompts and return a list of response strings."""

    def recursive_query(self, prompt: str, model: str | None = None) -> str:
        """Recursive counterpart of ``query`` for a single prompt."""

    def recursive_query_batched(
        self,
        prompts: list[str],
        model: str | None = None,
    ) -> list[str]:
        """Recursive counterpart of ``query_batched`` for a list of prompts."""


@dataclass
class BackendLimits:
    max_depth: int = 1
    max_batch_workers: int = 8
    max_children_total: int | None = None
    max_children_per_batch: int | None = None
    result_truncation_limit: int | None = None
    # Cooperative timeout: checked between iterations, not during LLM calls.
    # A slow LLM call within an iteration will not be interrupted — the timeout
    # fires at the next iteration boundary. For mid-call cancellation, use
    # process-based isolation instead.
    per_child_timeout_s: float | None = None
    # Tree-global child counter shared across all recursion depths
    _children_spawned: int = field(default=0, init=False, repr=False)
    _children_lock: threading.Lock = field(
        default_factory=threading.Lock, init=False, repr=False
    )


@dataclass
class ChildTrace:
    """Record of one child run, appended by ``LocalChildRLMBackend.recursive_query``."""

    # Recursion depth the child ran at (the spawning backend's depth + 1).
    depth: int
    # Elapsed time of the child call, measured with ``time.perf_counter``.
    duration_s: float
    # First 80 characters of the prompt given to the child.
    prompt_preview: str
    # First 80 characters of the child's (truncated) final answer; None when
    # the answer was empty or the child raised.
    result_preview: str | None
    # ``str(exc)`` of the exception the child raised, or None on success.
    error: str | None


class DirectLMBackend:
    """Backend that forwards every prompt straight to the chat callable.

    The ``recursive_*`` entry points simply delegate to their direct
    counterparts, so this backend never spawns children and
    ``child_traces`` stays empty.
    """

    def __init__(
        self,
        llm_chat_fn: ChatFn,
        *,
        depth: int = 0,
        limits: BackendLimits | None = None,
    ) -> None:
        """Store the chat callable, the recursion depth and the limits.

        Args:
            llm_chat_fn: Callable used for every LM request.
            depth: Recursion depth this backend sits at.
            limits: Limits to apply; a falsy value (normally ``None``) is
                replaced by a default ``BackendLimits()``.
        """
        self.limits = limits if limits else BackendLimits()
        self.llm_chat_fn = llm_chat_fn
        self.depth = depth
        self.max_depth = self.limits.max_depth
        self.child_traces: list[ChildTrace] = []

    def query(self, prompt: str, model: str | None = None) -> str:
        """Send one user message to the chat callable and return its reply.

        The callable is first invoked as ``fn(messages, model)``. If that
        raises ``TypeError`` it is invoked again as ``fn(messages)``.

        NOTE(review): a ``TypeError`` raised from inside the callable also
        triggers the second invocation, not only a signature mismatch.

        Returns:
            The reply, cut to ``limits.result_truncation_limit`` characters
            when that limit is set.
        """
        try:
            reply = self.llm_chat_fn([{"role": "user", "content": prompt}], model)
        except TypeError:
            reply = self.llm_chat_fn([{"role": "user", "content": prompt}])
        return self._truncate(reply)

    def query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
        """Run ``query`` for every prompt on a thread pool.

        Returns:
            One entry per prompt, in input order. A prompt whose query raised
            an ``Exception`` is represented by ``"Error: <message>"``.
        """
        if not prompts:
            return []
        pool_size = min(len(prompts), self.limits.max_batch_workers)
        answers: list[str] = []
        with ThreadPoolExecutor(max_workers=pool_size) as pool:
            # Submit everything up front, then gather in submission order so
            # the answers line up with the prompts without an index map.
            pending = [pool.submit(self.query, text, model) for text in prompts]
            for job in pending:
                try:
                    answers.append(job.result())
                except Exception as exc:
                    answers.append(f"Error: {exc}")
        return answers

    def recursive_query(self, prompt: str, model: str | None = None) -> str:
        """Delegate to ``query``; no child is spawned."""
        return self.query(prompt, model)

    def recursive_query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]:
        """Delegate to ``query_batched``; no children are spawned."""
        return self.query_batched(prompts, model)

    def _truncate(self, result: str) -> str:
        """Cut ``result`` to ``limits.result_truncation_limit`` characters."""
        cap = self.limits.result_truncation_limit
        if cap is None or len(result) <= cap:
            return result
        return result[:cap]


class LocalChildRLMBackend(DirectLMBackend):
    """Recursive backend that spawns child LocalRLMRunner instances.

    A recursive query builds a child runner through ``runner_factory`` one
    level deeper, until ``max_depth`` is reached, at which point it falls
    back to a direct LM call. Children receive a backend created by
    ``_child_backend_factory`` that shares this backend's ``limits`` object,
    so the spawn counter applies to the whole tree.
    """

    def __init__(
        self,
        llm_chat_fn: ChatFn,
        *,
        runner_factory: Callable[..., object],
        system_prompt: str,
        max_iterations: int,
        env_max_iterations_multiplier: int,
        depth: int = 0,
        limits: BackendLimits | None = None,
        on_subcall_start: Callable[[int, str, str], None] | None = None,
        on_subcall_complete: Callable[[int, str, float, str | None], None]
        | None = None,
    ) -> None:
        """Store the child-runner configuration on top of the base backend.

        Args:
            llm_chat_fn: Callable used for every LM request.
            runner_factory: Called to build each child runner; the returned
                object must provide ``run(...)``.
            system_prompt: Forwarded unchanged to every child runner.
            max_iterations: Forwarded unchanged to every child runner.
            env_max_iterations_multiplier: Forwarded unchanged to every
                child runner.
            depth: Recursion depth this backend sits at.
            limits: Shared limits; defaults to a fresh ``BackendLimits()``.
            on_subcall_start: Optional hook called as
                ``(child_depth, model_name, prompt_preview)`` before a child
                runs.
            on_subcall_complete: Optional hook called as
                ``(child_depth, model_name, duration_s, error)`` after a
                child finishes or fails.
        """
        super().__init__(llm_chat_fn, depth=depth, limits=limits)
        self.runner_factory = runner_factory
        self.system_prompt = system_prompt
        self.max_iterations = max_iterations
        self.env_max_iterations_multiplier = env_max_iterations_multiplier
        self.on_subcall_start = on_subcall_start
        self.on_subcall_complete = on_subcall_complete

    def recursive_query(self, prompt: str, model: str | None = None) -> str:
        """Answer ``prompt`` with a child runner, or directly at max depth.

        Returns:
            The child's final answer (``""`` when it has none), truncated to
            the configured limit; the direct query result when the next
            depth would reach ``max_depth``; or the string
            ``"Error: max_children_total exceeded"`` when the tree-wide
            child budget is used up. In that last case no trace is recorded
            and no hook is called.

        Raises:
            Exception: Whatever the runner factory or the child run raises
                is re-raised after the trace and completion hook are handled.
        """
        next_depth = self.depth + 1
        if next_depth >= self.max_depth:
            return self.query(prompt, model)
        # Check and reserve a slot atomically so concurrent batch workers
        # cannot overshoot the tree-wide budget.
        with self.limits._children_lock:
            if self.limits.max_children_total is not None:
                if self.limits._children_spawned >= self.limits.max_children_total:
                    return "Error: max_children_total exceeded"
            self.limits._children_spawned += 1
        start = time.perf_counter()
        error: str | None = None
        result_text = ""
        resolved_model = model or "default"
        if self.on_subcall_start is not None:
            # Hooks are best-effort observers: a failing hook must not
            # abort the query itself.
            try:
                self.on_subcall_start(next_depth, str(resolved_model), prompt[:80])
            except Exception:
                pass
        try:
            child = self.runner_factory(
                self.llm_chat_fn,
                system_prompt=self.system_prompt,
                max_iterations=self.max_iterations,
                max_depth=self.max_depth,
                depth=next_depth,
                env_max_iterations_multiplier=self.env_max_iterations_multiplier,
                max_batch_workers=self.limits.max_batch_workers,
                backend_factory=self._child_backend_factory,
                on_subcall_start=self.on_subcall_start,
                on_subcall_complete=self.on_subcall_complete,
            )
            # NOTE(review): the prompt is passed for both positional
            # arguments of ``run`` — confirm their meaning against the runner.
            result = child.run(
                prompt, prompt, model=model, timeout_s=self.limits.per_child_timeout_s
            )
            result_text = self._truncate(result.final_answer or "")
            return result_text
        except Exception as exc:
            error = str(exc)
            raise
        finally:
            # Runs on success and on failure, so every spawned child leaves
            # a trace and a completion notification.
            duration = time.perf_counter() - start
            self.child_traces.append(
                ChildTrace(
                    depth=next_depth,
                    duration_s=duration,
                    prompt_preview=prompt[:80],
                    result_preview=(result_text[:80] if result_text else None),
                    error=error,
                )
            )
            if self.on_subcall_complete is not None:
                try:
                    self.on_subcall_complete(
                        next_depth,
                        str(resolved_model),
                        duration,
                        error,
                    )
                except Exception:
                    pass

    def recursive_query_batched(
        self, prompts: list[str], model: str | None = None
    ) -> list[str]:
        """Run ``recursive_query`` for a batch of prompts on a thread pool.

        At most ``limits.max_children_per_batch`` prompts, taken from the
        front of the list, are executed; a negative limit is treated as
        zero. The result always has one entry per input prompt so callers
        can pair results with prompts by index.

        Returns:
            For each prompt, in input order: the query result,
            ``"Error: <message>"`` if the query raised, or
            ``"Error: max_children_per_batch exceeded"`` if the prompt fell
            beyond the batch limit.
        """
        if not prompts:
            return []
        batch_limit = self.limits.max_children_per_batch
        if batch_limit is None:
            accepted = prompts
        else:
            accepted = prompts[: max(batch_limit, 0)]
        # Pre-fill every slot with the rejection marker; accepted prompts
        # overwrite theirs below, so the output stays index-aligned.
        results: list[str] = ["Error: max_children_per_batch exceeded"] * len(prompts)
        if not accepted:
            # Nothing to run; also avoids ThreadPoolExecutor(max_workers=0),
            # which raises ValueError.
            return results
        max_workers = min(len(accepted), self.limits.max_batch_workers)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_idx = {
                executor.submit(self.recursive_query, prompt, model): idx
                for idx, prompt in enumerate(accepted)
            }
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    results[idx] = future.result()
                except Exception as exc:
                    results[idx] = f"Error: {exc}"
        return results

    def _child_backend_factory(
        self, llm_chat_fn: ChatFn, **kwargs
    ) -> "LocalChildRLMBackend":
        """Build the backend handed to a child runner.

        ``kwargs`` must contain ``system_prompt``, ``max_iterations``,
        ``env_max_iterations_multiplier`` and ``depth`` (a missing key raises
        ``KeyError``). The runner factory, hooks and ``limits`` object are
        taken from this backend, so the child counts against the same
        budget.
        """
        return LocalChildRLMBackend(
            llm_chat_fn,
            runner_factory=self.runner_factory,
            system_prompt=kwargs["system_prompt"],
            max_iterations=kwargs["max_iterations"],
            env_max_iterations_multiplier=kwargs["env_max_iterations_multiplier"],
            depth=kwargs["depth"],
            limits=self.limits,
            on_subcall_start=self.on_subcall_start,
            on_subcall_complete=self.on_subcall_complete,
        )