Spaces:
Sleeping
Sleeping
| """ | |
| nl2sql-bench/server/environment.py | |
| ==================================== | |
| NL2SQL-Bench core environment — implements the OpenEnv Environment interface. | |
| Episode flow | |
| ------------ | |
| 1. reset(task_name?) → picks a task + question, returns initial observation | |
| 2. step(action) → executes the SQL, grades it, returns observation + reward | |
| 3. state() → returns episode metadata | |
| 4. Episode ends when: exact_match OR step count reaches max_steps | |
| The environment manages its own SQLite connection (in-memory, seeded | |
| deterministically). One connection per Environment instance; the FastAPI | |
| server creates one Environment per WebSocket session. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sqlite3 | |
| import uuid | |
| from pathlib import Path | |
| from typing import Optional | |
| from openenv.core.env_server import Environment | |
| # Import after openenv so path is correct regardless of working directory | |
| _HERE = Path(__file__).parent | |
| # Lazy import of task registry (avoids circular imports) | |
| from tasks import get_task, all_task_names, BaseTask | |
| from tasks.base import TaskExample | |
| from grader import ( | |
| GradeResult, | |
| compute_ground_truth, | |
| execute_query, | |
| grade, | |
| has_order_by, | |
| ) | |
| # We import our models from one level up (models.py at project root) | |
| import sys | |
| sys.path.insert(0, str(_HERE.parent)) | |
| from models import NL2SQLAction, NL2SQLObservation, NL2SQLState | |
# ── Constants ──────────────────────────────────────────────────────────────
# Task used by reset() when no task name is given or the name is unknown.
DEFAULT_TASK = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
# Per-episode step budget. Clamped to >= 1: an episode always ends after its
# first step when the budget is <= 1, so the clamp leaves termination
# unchanged while keeping the `max_steps` reported to the agent meaningful.
MAX_STEPS = max(1, int(os.getenv("NL2SQL_MAX_STEPS", "5")))
RESULT_LIMIT = 10  # Max rows shown to agent per step
class NL2SQLEnvironment(Environment):
    """
    OpenEnv-compliant environment for NL-to-SQL query generation.

    One instance per WebSocket session (created by create_fastapi_app).
    Each instance owns a private in-memory SQLite database that is created
    and seeded in ``__init__`` and released by ``close()``.
    """

    def __init__(self) -> None:
        # Let the base class initialise whatever state it keeps.
        super().__init__()

        self._conn: Optional[sqlite3.Connection] = None
        self._task: Optional[BaseTask] = None
        self._example: Optional[TaskExample] = None
        self._ground_truth: list = []
        self._order_sensitive: bool = False
        self._episode_rewards: list = []

        # Placeholder state/observation until the first reset().
        self._state = NL2SQLState(
            episode_id=None,
            step_count=0,
            task_name="",
            task_difficulty="",
            question="",
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )
        self._last_obs = NL2SQLObservation(
            question="",
            schema_context="",
            task_name="",
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            # Use the configured budget so the placeholder agrees with the
            # observations produced by reset() and step().
            max_steps=MAX_STEPS,
            done=False,
            reward=None,
            score=0.0,
        )
        self._setup_db()

    # ── DB lifecycle ───────────────────────────────────────────────────────
    def _setup_db(self) -> None:
        """Create the in-memory SQLite DB, load the schema and seed it.

        On success the connection is stored in ``self._conn``. If schema
        loading or seeding raises, the connection is closed and the
        exception is re-raised.
        """
        schema_path = _HERE / "db" / "schema.sql"
        from db.seed import seed_database  # local import after sys.path setup

        # check_same_thread=False: the connection may be used from a thread
        # other than the one that created it.
        conn = sqlite3.connect(":memory:", check_same_thread=False)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA foreign_keys = ON")
            # Explicit encoding so the schema reads the same on any locale.
            conn.executescript(schema_path.read_text(encoding="utf-8"))
            seed_database(conn)
        except Exception:
            # Do not leak the connection when setup fails part-way.
            conn.close()
            raise
        self._conn = conn

    def close(self) -> None:
        """Release the SQLite connection. Safe to call more than once.

        The environment must not be stepped or reset after it is closed.
        """
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    # ── OpenEnv interface ──────────────────────────────────────────────────
    def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation:
        """
        Start a new episode.

        task_name: one of 'simple-filter', 'join-aggregation', 'analytics-window'.
        Defaults to NL2SQL_DEFAULT_TASK env-var or 'simple-filter'. An unknown
        name silently falls back to the default task.

        Returns the initial observation (step 0, no query executed yet).
        """
        task_name = task_name or DEFAULT_TASK
        if task_name not in all_task_names():
            task_name = DEFAULT_TASK

        self._task = get_task(task_name)
        self._example = self._task.next_example()
        self._order_sensitive = has_order_by(self._example.sql)
        # Pre-compute ground truth once per episode
        self._ground_truth = compute_ground_truth(self._conn, self._example.sql)
        self._episode_rewards = []

        self._state = NL2SQLState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task_name=self._task.name,
            task_difficulty=self._task.difficulty,
            question=self._example.question,
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )
        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            max_steps=MAX_STEPS,
            done=False,
            reward=None,
            score=0.0,
        )
        self._last_obs = obs
        return obs

    def step(self, action: NL2SQLAction) -> NL2SQLObservation:
        """Execute the agent's SQL and return the graded observation.

        The episode is marked done when the grader reports an exact match or
        when the step count reaches MAX_STEPS. If called before reset(), a
        default episode is started first.
        """
        if self._task is None or self._example is None:
            # Called before reset — auto-reset
            self.reset()

        self._state.step_count += 1
        current_step = self._state.step_count

        # Execute the query
        rows, error = execute_query(self._conn, action.query)

        # Grade it
        result: GradeResult = grade(
            actual_rows=rows,
            ground_truth_rows=self._ground_truth,
            error=error,
            step=current_step,
            order_sensitive=self._order_sensitive,
        )
        reward = result.reward
        self._episode_rewards.append(reward)
        self._state.cumulative_reward += reward
        self._state.best_reward = max(self._state.best_reward, reward)

        if result.exact_match:
            self._state.solved = True
        done = result.exact_match or current_step >= MAX_STEPS

        # Prepare result rows for observation (truncated for agent readability)
        display_rows = (rows or [])[:RESULT_LIMIT]
        result_columns = list(display_rows[0].keys()) if display_rows else []
        # Convert sqlite3.Row objects if needed
        display_rows = [dict(r) for r in display_rows]

        # Mean reward over the episode so far, clamped to [0, 1]. A reward was
        # appended just above, so the divisor is always >= 1.
        score = self._state.cumulative_reward / len(self._episode_rewards)
        score = round(min(max(score, 0.0), 1.0), 4)

        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query=action.query,
            last_result=display_rows,
            last_error=error,
            result_columns=result_columns,
            step=current_step,
            max_steps=MAX_STEPS,
            done=done,
            reward=reward,
            score=score,
        )
        self._last_obs = obs
        # openenv-core expects ONLY the observation returned from step().
        # The framework reads obs.reward and obs.done itself — do NOT return a tuple.
        return obs

    def state(self) -> NL2SQLState:
        """Return the current episode metadata."""
        # NOTE(review): the OpenEnv base class may declare `state` as a
        # property rather than a method — confirm against the installed
        # openenv-core version before changing this.
        return self._state

    # ── Helpers ────────────────────────────────────────────────────────────
    def available_tasks(self) -> list:
        """Return the names of all registered tasks."""
        return all_task_names()