File size: 5,914 Bytes
8097081
 
 
 
 
 
 
 
 
7c54da3
8097081
7c54da3
8097081
7c54da3
8097081
 
 
1435892
be50021
1435892
 
 
 
 
 
 
 
 
8097081
 
5b04645
 
 
 
 
 
8097081
7c54da3
8097081
 
 
5b04645
 
7c54da3
 
5b04645
7c54da3
 
 
 
 
 
 
 
 
 
 
8097081
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
5b04645
 
 
 
1435892
7c54da3
5b04645
7c54da3
5b04645
7c54da3
1435892
 
 
 
7c54da3
 
5b04645
7c54da3
5b04645
 
 
 
 
7c54da3
 
 
 
 
5b04645
 
 
 
be50021
5b04645
 
 
 
 
 
 
7c54da3
 
 
 
 
 
 
 
 
 
 
5b04645
 
7c54da3
be50021
7c54da3
5b04645
7c54da3
 
 
 
 
 
 
 
5b04645
 
 
 
 
 
 
7c54da3
5b04645
 
 
 
 
 
8097081
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# inference.py
import asyncio
import json
import os
import sys
from typing import List

import httpx
from openai import OpenAI

# Runtime configuration, read once from the environment at import time.
# OpenAI-compatible endpoint and model id used for action generation.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
# First non-empty value of HF_TOKEN / API_KEY; otherwise OPENAI_API_KEY,
# which falls back to "dummy" only when that variable is unset.
API_KEY = (
    os.getenv("HF_TOKEN")
    or os.getenv("API_KEY")
    or os.getenv("OPENAI_API_KEY", "dummy")
)
# Base URL of the environment server (the /reset and /step endpoints).
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
# Comma-separated task ids; each one is run as its own episode.
TASKS = os.getenv("TASKS", "easy,medium,hard")
MAX_STEPS = int(os.getenv("MAX_STEPS", "5"))
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.7"))
MAX_TOTAL_REWARD = float(os.getenv("MAX_TOTAL_REWARD", "1.0"))
# Kept as the raw string; _parse_seed turns it into an int (or None).
SEED = os.getenv("SEED")
# Reward recorded for a step that failed client-side (model or step error).
MIN_LOG_REWARD = 0.01


def _parse_seed(value: str | None) -> int | None:
    if value is None:
        return None
    try:
        return int(value)
    except ValueError:
        return None


def _sanitize_field(value: object) -> str:
    text = str(value)
    text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ")
    return " ".join(text.split())


def log_start(task, env, model):
    """Print the ``[START]`` line announcing a new episode (stdout, flushed)."""
    fields = (f"task={task}", f"env={env}", f"model={model}")
    print("[START] " + " ".join(fields), flush=True)


def log_step(step, action, reward, done, error):
    """Print one ``[STEP]`` line to stdout (flushed).

    *action* and *error* are collapsed onto a single line via
    ``_sanitize_field``. A ``None`` error is printed as ``null``. *done* is
    printed as ``true``/``false`` according to its truthiness, and *reward*
    with two decimals.
    """
    action_part = f"action={_sanitize_field(action)}"
    error_part = "error=" + ("null" if error is None else _sanitize_field(error))
    done_part = "done=" + str(bool(done)).lower()
    parts = (f"step={step}", action_part, f"reward={reward:.2f}", done_part, error_part)
    print("[STEP] " + " ".join(parts), flush=True)


def log_end(success, steps, rewards):
    """Print the closing ``[END]`` line for an episode (stdout, flushed).

    *success* is printed as ``true``/``false``. *rewards* is printed as a
    comma-separated list (no spaces) with two decimals per entry; an empty
    list prints nothing after ``rewards=``.
    """
    status = "true" if success else "false"
    reward_list = ",".join(map("{:.2f}".format, rewards))
    print(f"[END] success={status} steps={steps} rewards={reward_list}", flush=True)


def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
    """Ask the model for the next debugging action and return its raw reply.

    The single user prompt contains the JSON reply schema, the valid action
    and bug-type names, the JSON-encoded *observation* (cut to its first 8000
    characters) and the ``repr`` of *history*. The request uses
    ``MODEL_NAME`` with ``temperature=0`` and ``max_tokens=500``.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` when that content is empty or ``None``.
    """
    observation_excerpt = json.dumps(observation)[:8000]
    prompt = f"""
You are debugging a PyTorch training job. Respond ONLY with valid JSON matching this exact schema:
{{
  "current_hypothesis": {{"bug_type": "<string>", "affected_file": "<string>", "confidence": <0.0-1.0>}},
  "investigation_action": {{"action": "reveal_file", "target": "<filename>"}},
  "commit_diagnosis": false,
  "final_diagnosis": null
}}

Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion, wrong_loss_function, amp_overflow

Observation:
{observation_excerpt}
History: {history}
"""
    chat_messages = [{"role": "user", "content": prompt}]
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=chat_messages,
        temperature=0,
        max_tokens=500,
    )
    reply = response.choices[0].message.content
    return reply.strip() if reply else ""


def _strip_code_fence(text: str) -> str:
    """Return *text* stripped of whitespace and of a surrounding Markdown fence.

    Chat models often wrap a JSON reply in a fence such as ```json ... ```
    even when asked for bare JSON. ``json.loads`` rejects the fence lines,
    which would turn an otherwise valid action into a ``step_error``. Text
    that does not start with a fence is only whitespace-stripped.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line (it may carry a language tag such as "json").
    body = stripped.splitlines()[1:]
    # Drop the closing fence line when present.
    if body and body[-1].strip() == "```":
        body = body[:-1]
    return "\n".join(body).strip()


async def _run_task(task: str, client: OpenAI) -> None:
    """Run one episode of *task* against the environment server and log it.

    Prints a ``[START]`` line, one ``[STEP]`` line per attempted step and, in
    every case, a closing ``[END]`` line on stdout. The episode stops when the
    server reports ``done``, after ``MAX_STEPS`` steps, or on the first
    model/step error. Such an error is logged as a done step with reward
    ``MIN_LOG_REWARD``. Success means the last step's reward, clamped to
    ``[0, 1]``, is at least ``SUCCESS_SCORE_THRESHOLD``.

    Failures outside a step (e.g. a failed ``/reset``) are reported on stderr
    rather than raised, so the remaining tasks still run.

    Args:
        task: Task id sent to ``/reset`` as the ``task_id`` query parameter.
        client: OpenAI-compatible client used to ask the model for actions.
    """
    rewards: List[float] = []
    history: List[str] = []
    steps_taken = 0
    seed_value = _parse_seed(SEED)

    log_start(task=task, env="pytorch-debug-env", model=MODEL_NAME)

    try:
        async with httpx.AsyncClient(timeout=60.0) as session:
            reset_params = {"task_id": task}
            if seed_value is not None:
                reset_params["seed"] = seed_value
            reset_resp = await session.post(f"{ENV_URL}/reset", params=reset_params)
            reset_resp.raise_for_status()
            result = reset_resp.json()

            session_id = result.get("session_id")
            observation = result.get("observation")
            if not session_id:
                raise RuntimeError("Missing session_id in reset response")
            if observation is None:
                raise RuntimeError("Missing observation in reset response")

            for step in range(1, MAX_STEPS + 1):
                if result.get("done"):
                    break

                action_text = "null"
                try:
                    action_text = get_model_message(client, observation, history)
                except Exception as exc:
                    reward = MIN_LOG_REWARD
                    done = True
                    error = f"model_error: {exc}"
                    rewards.append(reward)
                    steps_taken = step
                    log_step(step=step, action=action_text, reward=reward, done=done, error=error)
                    break

                try:
                    action_json = json.loads(_strip_code_fence(action_text))
                    step_resp = await session.post(
                        f"{ENV_URL}/step",
                        params={"session_id": session_id},
                        json=action_json,
                    )
                    step_resp.raise_for_status()
                    result = step_resp.json()
                    # Coerce here, inside the try: a JSON null or a string reward
                    # would otherwise crash the ":.2f" formatting and the score
                    # clamp below. A non-numeric value becomes a step_error.
                    raw_reward = result.get("reward")
                    reward = 0.0 if raw_reward is None else float(raw_reward)
                    done = result.get("done", False)
                    error = result.get("error")
                    observation = result.get("observation", observation)
                except Exception as exc:
                    reward = MIN_LOG_REWARD
                    done = True
                    error = f"step_error: {exc}"

                rewards.append(reward)
                steps_taken = step
                log_step(step=step, action=action_text, reward=reward, done=done, error=error)
                history.append(f"step={step} reward={reward:.3f}")

                if done:
                    break
    except Exception as exc:
        # Best effort: the episode is over, but surface why. stdout carries the
        # structured [START]/[STEP]/[END] lines, so the report goes to stderr.
        print(f"task {task!r} aborted: {exc!r}", file=sys.stderr, flush=True)
    finally:
        # Always close the episode with an [END] line, even on cancellation.
        score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
        log_end(success=success, steps=steps_taken, rewards=rewards)


async def main():
    """Run every task listed in ``TASKS`` in order, sharing one model client.

    Empty entries (e.g. from a trailing comma) are skipped. The OpenAI client
    is explicitly closed afterwards, even if a task raises, instead of relying
    on garbage collection to release its HTTP connections.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    try:
        tasks = [task.strip() for task in TASKS.split(",") if task.strip()]
        for task in tasks:
            await _run_task(task, client)
    finally:
        client.close()


if __name__ == "__main__":
    asyncio.run(main())