# inference.py
"""Run an LLM agent against the PyTorch-debug environment over HTTP.

For every task listed in ``TASKS`` the script resets the environment at
``ENV_URL``, then alternates between asking the model for a JSON action and
posting that action to ``/step`` for at most ``MAX_STEPS`` steps.

Progress is reported on stdout as ``[START]``, ``[STEP]`` and ``[END]``
lines. Diagnostics for unexpected failures are written to stderr so the
stdout format stays untouched.
"""
import asyncio
import json
import os
import sys
from typing import List

from openai import OpenAI
import httpx

# Configuration, read once from the process environment at import time.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
# First non-empty of HF_TOKEN, API_KEY, OPENAI_API_KEY; "dummy" as last resort.
API_KEY = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY", "dummy")
)
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
TASKS = os.environ.get("TASKS", "easy,medium,hard")
MAX_STEPS = int(os.environ.get("MAX_STEPS", "5"))
SUCCESS_SCORE_THRESHOLD = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.7"))
# NOTE(review): not referenced anywhere else in this module; kept so that
# importing code and existing environment setups keep working.
MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))
SEED = os.environ.get("SEED")
# Reward recorded for a step that failed on the client side.
MIN_LOG_REWARD = 0.01


def _parse_seed(value: str | None) -> int | None:
    """Return *value* as an int, or ``None`` if it is missing or not an integer."""
    if value is None:
        return None
    try:
        return int(value)
    except ValueError:
        return None


def _sanitize_field(value: object) -> str:
    """Render *value* as a single line with whitespace runs collapsed to one space."""
    # str.split() with no argument splits on every whitespace character
    # (including newlines, carriage returns and tabs) and drops empty pieces.
    return " ".join(str(value).split())


def _coerce_reward(value: object, default: float = 0.0) -> float:
    """Return *value* as a float, or *default* when it cannot be converted.

    The environment's JSON may carry ``null`` or a non-numeric value under
    ``reward``; without coercion the ``:.2f`` formatting in the log helpers
    would raise.
    """
    try:
        return float(value)  # type: ignore[arg-type]
    except (TypeError, ValueError):
        return default


def _parse_action(text: str) -> object:
    """Decode the model reply as JSON.

    The reply is parsed as-is first. If that fails, the outermost ``{...}``
    span is tried, which covers replies wrapped in a Markdown code fence or
    surrounded by prose.

    Raises:
        json.JSONDecodeError: the error from parsing the full reply, when
            neither attempt yields valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end <= start:
            raise
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            # Report the failure for the reply the model actually sent.
            raise exc from None


def log_start(task: object, env: object, model: object) -> None:
    """Print the ``[START]`` line for one task run."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: object, reward: float, done: object, error: object) -> None:
    """Print one ``[STEP]`` line.

    *action* and *error* are flattened to a single line; a ``None`` error is
    printed as ``null`` and *done* is printed as ``true``/``false``.
    """
    safe_action = _sanitize_field(action)
    err = "null" if error is None else _sanitize_field(error)
    done_str = "true" if done else "false"
    print(
        f"[STEP] step={step} action={safe_action} reward={reward:.2f} done={done_str} error={err}",
        flush=True,
    )


def log_end(success: object, steps: int, rewards: List[float]) -> None:
    """Print the ``[END]`` line with the comma-separated per-step rewards."""
    success_str = "true" if success else "false"
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={success_str} steps={steps} rewards={rewards_str}",
        flush=True,
    )


def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
    """Ask the model for the next action and return its raw, stripped reply.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        observation: Latest environment observation; serialized to JSON and
            truncated to 8000 characters inside the prompt.
        history: Per-step summaries, interpolated into the prompt as a list.

    Returns:
        The message content with surrounding whitespace removed, or ``""``
        when the completion has no content.
    """
    prompt = f"""
You are debugging a PyTorch training job.
Respond ONLY with valid JSON matching this exact schema:
{{
  "current_hypothesis": {{"bug_type": "", "affected_file": "", "confidence": <0.0-1.0>}},
  "investigation_action": {{"action": "reveal_file", "target": ""}},
  "commit_diagnosis": false,
  "final_diagnosis": null
}}

Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion, wrong_loss_function, amp_overflow

Observation: {json.dumps(observation)[:8000]}
History: {history}
"""
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=500,
    )
    return (completion.choices[0].message.content or "").strip()


async def _run_task(task: str, client: OpenAI) -> None:
    """Run one episode of *task* and print its START/STEP/END log lines.

    The episode stops when the environment reports ``done``, when the model
    call or the step request fails (logged with ``MIN_LOG_REWARD``), or after
    ``MAX_STEPS`` steps. The ``[END]`` line is always printed; success means
    the last recorded reward, clamped to [0, 1], reaches
    ``SUCCESS_SCORE_THRESHOLD``.
    """
    rewards: List[float] = []
    history: List[str] = []
    steps_taken = 0
    seed_value = _parse_seed(SEED)

    log_start(task=task, env="pytorch-debug-env", model=MODEL_NAME)

    try:
        async with httpx.AsyncClient(timeout=60.0) as session:
            reset_params = {"task_id": task}
            if seed_value is not None:
                reset_params["seed"] = seed_value
            reset_resp = await session.post(f"{ENV_URL}/reset", params=reset_params)
            reset_resp.raise_for_status()
            result = reset_resp.json()
            session_id = result.get("session_id")
            observation = result.get("observation")
            if not session_id:
                raise RuntimeError("Missing session_id in reset response")
            if observation is None:
                raise RuntimeError("Missing observation in reset response")

            for step in range(1, MAX_STEPS + 1):
                if result.get("done"):
                    break

                action_text = "null"
                try:
                    # The OpenAI client used here is synchronous; run it in a
                    # worker thread so the event loop is not blocked.
                    action_text = await asyncio.to_thread(
                        get_model_message, client, observation, history
                    )
                except Exception as exc:
                    reward = MIN_LOG_REWARD
                    done = True
                    error = f"model_error: {exc}"
                    rewards.append(reward)
                    steps_taken = step
                    log_step(step=step, action=action_text, reward=reward, done=done, error=error)
                    break

                try:
                    action_json = _parse_action(action_text)
                    step_resp = await session.post(
                        f"{ENV_URL}/step",
                        params={"session_id": session_id},
                        json=action_json,
                    )
                    step_resp.raise_for_status()
                    result = step_resp.json()
                    # The key may be present but null or non-numeric.
                    reward = _coerce_reward(result.get("reward", 0.0))
                    done = result.get("done", False)
                    error = result.get("error")
                    observation = result.get("observation", observation)
                except Exception as exc:
                    reward = MIN_LOG_REWARD
                    done = True
                    error = f"step_error: {exc}"

                rewards.append(reward)
                steps_taken = step
                log_step(step=step, action=action_text, reward=reward, done=done, error=error)
                history.append(f"step={step} reward={reward:.3f}")

                if done:
                    break
    except Exception as exc:
        # Best effort: the run still finishes with an [END] line, but the
        # cause is reported on stderr instead of being dropped silently.
        print(
            f"task {task} aborted: {type(exc).__name__}: {_sanitize_field(exc)}",
            file=sys.stderr,
            flush=True,
        )

    score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, rewards=rewards)


async def main() -> None:
    """Run every task named in the comma-separated ``TASKS`` setting, in order."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    tasks = [task.strip() for task in TASKS.split(",") if task.strip()]
    for task in tasks:
        await _run_task(task, client)


if __name__ == "__main__":
    asyncio.run(main())