nl2sql-bench / client.py
ritvik360's picture
Upload folder using huggingface_hub
1a16689 verified
import httpx
import json
import os
from typing import Any, Dict, Optional
from dataclasses import dataclass
@dataclass
class NL2SQLAction:
    """Action submitted to the environment for a single step.

    ``NL2SQLEnv.step`` sends the ``query`` value to the server wrapped as
    ``{"action": {"query": ...}}``.
    """

    # Query text forwarded verbatim to the server (presumably SQL -- confirm).
    query: str
@dataclass
class NL2SQLObservation:
    """Observation decoded from a server reply.

    Instances are built by ``NL2SQLEnv._parse_result`` from the JSON returned
    by the ``/reset`` and ``/step`` endpoints; each field mirrors the key of
    the same name in that payload. The per-field notes below are inferred from
    the field names -- confirm against the server's observation schema.
    """

    # --- task description ---
    question: str            # presumably the natural-language question
    schema_context: str      # presumably a description of the database schema
    task_name: str

    # --- outcome of the most recent query ---
    last_query: str
    last_result: list
    last_error: Optional[str]  # None when the server reports no error
    result_columns: list

    # --- episode progress ---
    step: int
    max_steps: int
    done: bool

    # --- scoring ---
    reward: float
    score: float
@dataclass
class StepResult:
    """Value returned by ``NL2SQLEnv.reset`` and ``NL2SQLEnv.step``.

    ``reward`` and ``done`` carry the same values that
    ``NL2SQLEnv._parse_result`` stores on ``observation``; they are repeated
    here so callers can read them without going through the observation.
    """

    # Full observation decoded from the server reply.
    observation: NL2SQLObservation
    # Reward for this transition (0.0 when the server supplied none).
    reward: float
    # True when the server flagged the episode as finished.
    done: bool
class NL2SQLEnv:
    """Async HTTP client for an NL2SQL environment server.

    Wraps an ``httpx.AsyncClient`` and exposes ``reset`` / ``step`` coroutines
    that POST to the server's ``/reset`` and ``/step`` endpoints and convert
    the JSON reply into a ``StepResult``. The instance can be used as an async
    context manager; the underlying HTTP client is closed on exit.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Create the client.

        Args:
            base_url: Root URL of the environment server. Trailing slashes are
                stripped before it is handed to the HTTP client.
        """
        self.base_url = base_url.rstrip("/")
        self.client = httpx.AsyncClient(base_url=self.base_url, timeout=120.0)

    async def __aenter__(self):
        """Enter the async context; returns this environment unchanged."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying HTTP client when leaving the async context."""
        await self.client.aclose()

    async def reset(self) -> StepResult:
        """Start a new episode and return the initial observation.

        The task is taken from the ``NL2SQL_DEFAULT_TASK`` environment
        variable, defaulting to ``"simple-filter"``.

        Returns:
            The parsed server reply.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        # task_name is sent as a top-level field of the request body.
        # NOTE(review): an earlier comment claimed it was sent "both ways"
        # (body and action wrapper) to cover different openenv-core versions,
        # but only this top-level field is actually sent -- confirm whether the
        # server needs it anywhere else.
        payload = {"task_name": task_name}
        resp = await self.client.post("/reset", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit one action and return the resulting observation.

        Args:
            action: The action whose ``query`` is sent to the server.

        Returns:
            The parsed server reply.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        # The server's action_cls=NL2SQLAction expects the payload wrapped in
        # {"action": {"query": ...}} per the OpenEnv protocol. Sending
        # {"query": ...} at the top level bypasses action parsing -> 0 reward.
        payload = {"action": {"query": action.query}}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    @staticmethod
    def _get_or_default(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or ``default`` if the key is missing or null."""
        value = data.get(key)
        return default if value is None else value

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a decoded JSON reply into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` when that
        value is a dict, otherwise from ``payload`` itself. Missing or null
        fields fall back to type-appropriate defaults, except ``last_error``,
        which stays ``None``.

        Args:
            payload: Decoded JSON body of a ``/reset`` or ``/step`` response.

        Returns:
            A ``StepResult`` whose ``reward`` and ``done`` match the values
            stored on its observation.
        """
        obs_data = payload.get("observation")
        if not isinstance(obs_data, dict):
            # "observation" is absent, or present but null / not an object;
            # .get(key, default) alone would return None here and crash below.
            obs_data = payload

        # Reward: prefer the top-level payload (OpenEnv puts it there), then
        # fall back to the nested observation dict.
        raw_reward = payload.get("reward")
        if raw_reward is None:
            raw_reward = obs_data.get("reward")
        safe_reward = float(raw_reward) if raw_reward is not None else 0.0

        safe_score = float(obs_data.get("score") or 0.0)
        safe_done = bool(payload.get("done") or obs_data.get("done") or False)

        # JSON null decodes to None, which dict.get(key, default) passes
        # through untouched; _get_or_default keeps the typed fields typed.
        field = self._get_or_default
        obs = NL2SQLObservation(
            question=field(obs_data, "question", ""),
            schema_context=field(obs_data, "schema_context", ""),
            task_name=field(obs_data, "task_name", ""),
            last_query=field(obs_data, "last_query", ""),
            last_result=field(obs_data, "last_result", []),
            last_error=obs_data.get("last_error"),
            result_columns=field(obs_data, "result_columns", []),
            step=field(obs_data, "step", 0),
            max_steps=field(obs_data, "max_steps", 5),
            done=safe_done,
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=safe_done,
        )