Spaces:
Sleeping
Sleeping
| import httpx | |
| import json | |
| import os | |
| from typing import Any, Dict, Optional | |
| from dataclasses import dataclass | |
@dataclass
class NL2SQLAction:
    """Action submitted to the environment via ``NL2SQLEnv.step``.

    ``@dataclass`` generates the ``__init__`` so callers can write
    ``NL2SQLAction(query="SELECT ...")``; without it the class accepts no
    constructor arguments.
    """

    # SQL text sent to the server's /step endpoint as {"action": {"query": ...}}.
    query: str
@dataclass
class NL2SQLObservation:
    """Observation reported by the environment after ``reset`` or ``step``.

    Instances are built by ``NL2SQLEnv._parse_result`` using keyword
    arguments, which requires the ``__init__`` generated by ``@dataclass``.
    """

    # Question text for the current task (presumably natural language -- confirm).
    question: str
    # Schema description supplied by the server; its format is not visible here.
    schema_context: str
    # Identifier of the active task (e.g. the "simple-filter" default used by reset).
    task_name: str
    # Most recently submitted query as echoed by the server ("" if absent).
    last_query: str
    # Result rows of the last query; row shape is not visible here -- confirm.
    last_result: list
    # Error message from the last query, or None when there was none.
    last_error: Optional[str]
    # Column names for last_result (presumably) -- confirm against the server.
    result_columns: list
    # Current step counter within the episode.
    step: int
    # Step budget for the episode; the client defaults it to 5 when missing.
    max_steps: int
    # True once the episode has ended.
    done: bool
    # Reward for the latest step.
    reward: float
    # Score reported by the server alongside the reward.
    score: float
@dataclass
class StepResult:
    """Result of one ``reset``/``step`` call on ``NL2SQLEnv``.

    ``@dataclass`` supplies the keyword ``__init__`` that
    ``NL2SQLEnv._parse_result`` relies on.
    """

    # Full observation parsed from the server response.
    observation: NL2SQLObservation
    # Reward for this call (duplicated from observation.reward for convenience).
    reward: float
    # Episode-finished flag (duplicated from observation.done).
    done: bool
class NL2SQLEnv:
    """Async HTTP client for a remote NL2SQL environment server.

    ``reset`` and ``step`` POST JSON to the server's ``/reset`` and ``/step``
    endpoints and convert the JSON response into a ``StepResult``.

    The underlying ``httpx.AsyncClient`` should be released when done. Either
    use the instance as an async context manager
    (``async with NL2SQLEnv() as env:``) or ``await env.close()``.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create the client.

        Args:
            base_url: Server root URL. A trailing slash is stripped so the
                relative paths "/reset" and "/step" join cleanly.
        """
        self.base_url = base_url.rstrip("/")
        # 120 s per-request timeout; presumably generous because the server
        # may execute/grade SQL before replying -- TODO confirm.
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=120.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self.client.aclose()

    async def reset(self, task_name: Optional[str] = None) -> StepResult:
        """Start a new episode on the server.

        Args:
            task_name: Task to load. When omitted, the ``NL2SQL_DEFAULT_TASK``
                environment variable is used, falling back to
                ``"simple-filter"``.

        Returns:
            The initial StepResult parsed from the server response.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        if task_name is None:
            task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        # task_name is sent only as a top-level field of the JSON body.
        # NOTE(review): some openenv-core versions may read it from an action
        # wrapper instead -- confirm against the server version in use.
        payload = {"task_name": task_name}
        resp = await self.client.post("/reset", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit one SQL query and return the resulting StepResult.

        Raises:
            httpx.HTTPStatusError: If the server responds with an error status.
        """
        # The server parses the body into its action class, so the query must
        # be wrapped as {"action": {"query": ...}} (OpenEnv protocol). Sending
        # {"query": ...} at the top level bypasses action parsing -> 0 reward.
        payload = {"action": {"query": action.query}}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    @staticmethod
    def _value_or(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or ``default`` when the key is missing or null."""
        value = data.get(key)
        return default if value is None else value

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a /reset or /step JSON response into a StepResult.

        Observation fields are read from ``payload["observation"]`` when that
        is a dict, otherwise from ``payload`` itself. Missing or null fields
        fall back to defaults: empty strings/lists, step 0, max_steps 5.

        ``reward`` is taken from the top level first and then from the
        observation. ``done`` is true if either level reports it.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            # Flat responses (or "observation": null) carry fields at top level.
            obs_data = payload

        # OpenEnv puts reward at the top level; fall back to the nested copy.
        raw_reward = payload.get("reward")
        if raw_reward is None:
            raw_reward = obs_data.get("reward")
        safe_reward = float(raw_reward) if raw_reward is not None else 0.0
        safe_score = float(obs_data.get("score") or 0.0)
        safe_done = bool(payload.get("done") or obs_data.get("done"))

        value_or = self._value_or
        obs = NL2SQLObservation(
            question=value_or(obs_data, "question", ""),
            schema_context=value_or(obs_data, "schema_context", ""),
            task_name=value_or(obs_data, "task_name", ""),
            last_query=value_or(obs_data, "last_query", ""),
            last_result=value_or(obs_data, "last_result", []),
            last_error=obs_data.get("last_error"),
            result_columns=value_or(obs_data, "result_columns", []),
            step=value_or(obs_data, "step", 0),
            max_steps=value_or(obs_data, "max_steps", 5),
            done=safe_done,
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=safe_done,
        )