# nl2sql-bench/server/environment.py
# (uploaded by ritvik360 via huggingface_hub, commit 46e0615, verified)
"""
nl2sql-bench/server/environment.py
====================================
NL2SQL-Bench core environment — implements the OpenEnv Environment interface.
Episode flow
------------
1. reset(task_name?) → picks a task + question, returns the initial observation
2. step(action) → executes the SQL, grades it, returns an observation carrying reward and done
3. state (property) → returns episode metadata
4. Episode ends when: exact_match OR step count reaches max_steps
The environment manages its own SQLite connection (in-memory, seeded
deterministically). One connection per Environment instance; the FastAPI
server creates one Environment per WebSocket session.
"""
from __future__ import annotations
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Optional
from openenv.core.env_server import Environment
# Import after openenv so path is correct regardless of working directory
_HERE = Path(__file__).parent
# Lazy import of task registry (avoids circular imports)
from tasks import get_task, all_task_names, BaseTask
from tasks.base import TaskExample
from grader import (
GradeResult,
compute_ground_truth,
execute_query,
grade,
has_order_by,
)
# We import our models from one level up (models.py at project root)
import sys
sys.path.insert(0, str(_HERE.parent))
from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
# ── Constants ──────────────────────────────────────────────────────────────
# Task used when reset() receives no task name or an unknown one.
DEFAULT_TASK: str = os.environ.get("NL2SQL_DEFAULT_TASK", "simple-filter")
# Number of graded attempts after which an unsolved episode ends.
MAX_STEPS: int = int(os.environ.get("NL2SQL_MAX_STEPS", "5"))
# Upper bound on result rows echoed back to the agent in each observation.
RESULT_LIMIT: int = 10
class NL2SQLEnvironment(Environment):
    """
    OpenEnv-compliant environment for NL-to-SQL query generation.

    One instance per WebSocket session (created by create_fastapi_app).
    Each instance owns a private in-memory SQLite database, built and seeded
    once in ``__init__`` and reused by every episode of that instance.
    """

    def __init__(self) -> None:
        """Initialise placeholder episode state and build the database."""
        # Let the OpenEnv base class set up whatever it manages itself.
        super().__init__()
        self._conn: Optional[sqlite3.Connection] = None
        self._task: Optional[BaseTask] = None
        self._example: Optional[TaskExample] = None
        # Rows returned by the reference SQL of the current example.
        self._ground_truth: list = []
        # True when the reference SQL has an ORDER BY, so row order is graded.
        self._order_sensitive: bool = False
        # Per-step rewards of the current episode (basis of the mean score).
        self._episode_rewards: list = []
        # Placeholders until the first reset(); max_steps comes from MAX_STEPS
        # (previously hard-coded to 5, ignoring NL2SQL_MAX_STEPS).
        self._state = self._new_state(episode_id=None)
        self._last_obs = self._make_observation(
            question="", schema_context="", task_name=""
        )
        self._setup_db()

    # ── DB lifecycle ───────────────────────────────────────────────────────
    def _setup_db(self) -> None:
        """Create the in-memory SQLite DB, apply ``db/schema.sql`` and seed it.

        If schema creation or seeding fails, the half-built connection is
        closed before the exception propagates.
        """
        schema_path = _HERE / "db" / "schema.sql"
        from db.seed import seed_database  # local import after sys.path setup

        # check_same_thread=False allows use from threads other than the
        # creating one (presumably server worker threads — confirm).
        conn = sqlite3.connect(":memory:", check_same_thread=False)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA foreign_keys = ON")
            conn.executescript(schema_path.read_text(encoding="utf-8"))
            seed_database(conn)
        except Exception:
            conn.close()
            raise
        self._conn = conn

    # ── Construction helpers ───────────────────────────────────────────────
    @staticmethod
    def _new_state(
        episode_id: Optional[str],
        task_name: str = "",
        task_difficulty: str = "",
        question: str = "",
    ) -> NL2SQLState:
        """Return a state with zeroed step count and rewards, not yet solved."""
        return NL2SQLState(
            episode_id=episode_id,
            step_count=0,
            task_name=task_name,
            task_difficulty=task_difficulty,
            question=question,
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )

    @staticmethod
    def _make_observation(
        *,
        question: str,
        schema_context: str,
        task_name: str,
        last_query: str = "",
        last_result: Optional[list] = None,
        last_error=None,
        result_columns: Optional[list] = None,
        step: int = 0,
        done: bool = False,
        reward: Optional[float] = None,
        score: float = 0.0,
    ) -> NL2SQLObservation:
        """Build an observation; ``max_steps`` always reports MAX_STEPS."""
        return NL2SQLObservation(
            question=question,
            schema_context=schema_context,
            task_name=task_name,
            last_query=last_query,
            last_result=[] if last_result is None else last_result,
            last_error=last_error,
            result_columns=[] if result_columns is None else result_columns,
            step=step,
            max_steps=MAX_STEPS,
            done=done,
            reward=reward,
            score=score,
        )

    # ── OpenEnv interface ──────────────────────────────────────────────────
    def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation:
        """
        Start a new episode.

        task_name: a registered task name, e.g. 'simple-filter',
            'join-aggregation', 'analytics-window'. A missing or unknown name
            falls back to the NL2SQL_DEFAULT_TASK env-var ('simple-filter').

        Returns the initial observation (step 0, no query executed yet).
        """
        task_name = task_name or DEFAULT_TASK
        if task_name not in all_task_names():
            task_name = DEFAULT_TASK
        self._task = get_task(task_name)
        self._example = self._task.next_example()
        self._order_sensitive = has_order_by(self._example.sql)
        # Pre-compute ground truth once per episode; steps only run the
        # agent's query.
        self._ground_truth = compute_ground_truth(self._conn, self._example.sql)
        self._episode_rewards = []
        self._state = self._new_state(
            episode_id=str(uuid.uuid4()),
            task_name=self._task.name,
            task_difficulty=self._task.difficulty,
            question=self._example.question,
        )
        obs = self._make_observation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
        )
        self._last_obs = obs
        return obs

    def step(self, action: NL2SQLAction) -> NL2SQLObservation:
        """
        Execute the agent's SQL, grade it and return the graded observation.

        If called before reset(), an episode with the default task is started
        first. The episode is done on an exact match or once the step count
        reaches MAX_STEPS.

        openenv-core expects ONLY the observation from step(): the framework
        reads ``obs.reward`` and ``obs.done`` itself — do NOT return a tuple.
        """
        if self._task is None or self._example is None:
            # Called before reset — auto-reset.
            self.reset()

        # NOTE(review): nothing rejects a step after done=True; the counter and
        # rewards keep accumulating. Confirm the server always resets first.
        self._state.step_count += 1
        current_step = self._state.step_count

        rows, error = execute_query(self._conn, action.query)
        result: GradeResult = grade(
            actual_rows=rows,
            ground_truth_rows=self._ground_truth,
            error=error,
            step=current_step,
            order_sensitive=self._order_sensitive,
        )

        reward = result.reward
        self._episode_rewards.append(reward)
        self._state.cumulative_reward += reward
        self._state.best_reward = max(self._state.best_reward, reward)

        done = False
        if result.exact_match:
            self._state.solved = True
            done = True
        elif current_step >= MAX_STEPS:
            done = True

        # Truncate for agent readability. Column names are taken before the
        # dict conversion so duplicate column names are still reported.
        truncated = (rows or [])[:RESULT_LIMIT]
        result_columns = list(truncated[0].keys()) if truncated else []
        # Convert sqlite3.Row objects (if execute_query returns them) to dicts.
        display_rows = [dict(r) for r in truncated]

        # Score = mean per-step reward of this episode, clamped to [0, 1].
        # The list is non-empty because this step's reward was just appended.
        mean_reward = self._state.cumulative_reward / len(self._episode_rewards)
        score = round(min(max(mean_reward, 0.0), 1.0), 4)

        obs = self._make_observation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query=action.query,
            last_result=display_rows,
            last_error=error,
            result_columns=result_columns,
            step=current_step,
            done=done,
            reward=reward,
            score=score,
        )
        self._last_obs = obs
        return obs

    @property
    def state(self) -> NL2SQLState:
        """Metadata of the current episode (a placeholder before reset())."""
        return self._state

    # ── Helpers ────────────────────────────────────────────────────────────
    def available_tasks(self) -> list:
        """Return the names of all registered tasks."""
        return all_task_names()