# nl2sql-bench / data_factory/validator.py
# Uploaded by ritvik360 via huggingface_hub - a39d8ef (verified)
"""
data_factory/validator.py
==========================
SQL execution validation layer.
GUARANTEE: Every record that passes this validator has a SQL that:
1. Runs without error against the actual seeded SQLite schema
2. Returns at least one row (non-empty result)
3. Produces a result set whose column names are recorded on the DataRecord
No LLM-generated SQL ever reaches this validator — SQL always comes from
the human-verified template library. This validator is an extra safety net
to catch any copy-paste or formatting regressions.
"""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass, field
from typing import Any, Optional
from data_factory.schemas import build_connection, SCHEMA_CONTEXT
from data_factory.templates import Template
# ─────────────────────────────────────────────────────────────────────────────
# DATA CLASSES
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class ValidationResult:
    """Outcome of executing one SQL string against a seeded connection.

    ``passed`` tells whether the statement executed successfully; when it
    did not, ``error`` carries the reason. ``row_count`` and ``columns``
    describe the fetched result set and keep their defaults otherwise.
    """

    passed: bool                                      # execution succeeded?
    sql: str                                          # the SQL that was checked
    error: Optional[str] = field(default=None)        # failure reason, if any
    row_count: int = field(default=0)                 # number of rows fetched
    columns: list[str] = field(default_factory=list)  # result column names
@dataclass
class DataRecord:
    """One training example ready to be written to JSONL/Parquet.

    Field order is part of the positional constructor signature; keep it.
    """
    domain: str
    difficulty: str
    sql: str
    nl_question: str          # The NL paraphrase used as prompt
    persona: str              # ceo | chatty | lazy_typist | non_techie | analyst | augmented
    has_order: bool
    schema_context: str
    row_count: int            # From validation run
    columns: list[str]        # From validation run
    source: str               # "template_base" | "vllm_persona" | "rule_augmented"
    template_id: int          # Index into ALL_TEMPLATES

    def to_training_dict(self) -> dict[str, Any]:
        """
        Returns the dictionary that will be written to the output dataset.

        Format is compatible with TRL / HuggingFace `datasets`:
            prompt  : chat-format messages list (system + user)
            sql     : ground-truth SQL (label / reward reference)
            metadata: auxiliary fields for curriculum or filtering

        The returned dict does not share the ``columns`` list with this
        record, so callers may mutate it freely.
        """
        # The em dash is written as an escape (U+2014) so that a file
        # re-encoding cannot turn it into mojibake inside every prompt.
        system_msg = (
            "You are an expert SQL analyst. "
            "Write a single SELECT query that answers the question. "
            "Output ONLY the SQL query \u2014 no markdown, no explanation, no backticks."
        )
        user_msg = (
            f"DATABASE SCHEMA\n"
            f"---------------\n"
            f"{self.schema_context}\n\n"
            f"QUESTION: {self.nl_question}"
        )
        return {
            "prompt": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "sql": self.sql,
            "metadata": {
                "domain": self.domain,
                "difficulty": self.difficulty,
                "persona": self.persona,
                "has_order": self.has_order,
                "row_count": self.row_count,
                # Copy so edits to the output dict never alias this record.
                "columns": list(self.columns),
                "source": self.source,
                "template_id": self.template_id,
            },
        }
# ─────────────────────────────────────────────────────────────────────────────
# VALIDATOR
# ─────────────────────────────────────────────────────────────────────────────
class SQLValidator:
    """
    Validates SQL against a seeded in-memory SQLite connection.

    One validator per domain to reuse the same connection for all templates
    in that domain (performance optimization). Because that connection is
    shared by every validation, it is switched to ``PRAGMA query_only`` right
    after seeding so no validated statement can change the data that later
    validations see.

    Can be used as a context manager; ``close()`` is called on exit.
    """

    # Leading keywords rejected before execution with a clear message.
    # This check only inspects the first word; statements that start
    # differently (e.g. ``WITH ... DELETE``) are stopped by query_only.
    _FORBIDDEN_KEYWORDS = frozenset({
        "insert", "update", "delete", "drop", "alter", "create",
        "replace", "truncate", "pragma",
        "attach", "detach", "vacuum", "reindex", "analyze",
    })

    def __init__(self, domain: str, seed: int = 42) -> None:
        self.domain = domain
        self._conn = build_connection(domain, seed=seed)
        # Seeding is complete: make the shared connection read-only.
        self._conn.execute("PRAGMA query_only = ON")

    def validate(self, sql: str) -> ValidationResult:
        """
        Execute SQL and return a ValidationResult.

        Never raises for SQL-level failures — syntax errors, missing tables,
        write attempts and multi-statement input all produce a result with
        ``passed=False`` and ``error`` set.

        The SQL is stripped of surrounding whitespace and trailing semicolons;
        the stripped form is what is executed and stored on the result.
        """
        sql = sql.strip().rstrip(";")
        if not sql:
            return ValidationResult(passed=False, sql=sql, error="Empty SQL string.")

        # sql is non-empty after strip(), so split() yields at least one token.
        first_word = sql.split(maxsplit=1)[0].lower()
        if first_word in self._FORBIDDEN_KEYWORDS:
            return ValidationResult(
                passed=False, sql=sql,
                error=f"Write operation '{first_word.upper()}' is not permitted."
            )

        try:
            cur = self._conn.execute(sql)
            cols = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # sqlite3.Warning is NOT a subclass of sqlite3.Error; older Python
            # versions raise it for multi-statement input.
            return ValidationResult(passed=False, sql=sql, error=str(exc))

        return ValidationResult(
            passed=True,
            sql=sql,
            row_count=len(rows),
            columns=cols,
        )

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def __enter__(self) -> SQLValidator:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()
def validate_template(template: Template, seed: int = 42) -> ValidationResult:
    """
    Convenience function: validate a single template.

    Builds a fresh SQLValidator for ``template["domain"]``, validates
    ``template["sql"]`` and always closes the validator's connection, even
    if validation raises unexpectedly (e.g. a missing ``"sql"`` key).
    """
    v = SQLValidator(template["domain"], seed=seed)
    try:
        return v.validate(template["sql"])
    finally:
        v.close()
def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
    """
    Run validation across all templates. Returns a summary dict.
    Used during CI / smoke testing.

    One SQLValidator (one seeded connection) is built lazily for each domain
    that actually appears in ``templates``. Every validator created here is
    closed before returning, even if an exception propagates.

    Returns a dict with keys:
        total    : number of templates checked
        passed   : number that validated successfully
        failed   : number that failed
        failures : list of {"index", "domain", "sql" (first 80 chars), "error"}
    """
    validators: dict[str, SQLValidator] = {}
    n_passed = 0
    failures: list[dict[str, Any]] = []
    try:
        for i, t in enumerate(templates):
            domain = t["domain"]
            v = validators.get(domain)
            if v is None:
                v = validators[domain] = SQLValidator(domain, seed)
            result = v.validate(t["sql"])
            if result.passed:
                n_passed += 1
            else:
                failures.append({"index": i, "domain": domain,
                                 "sql": t["sql"][:80], "error": result.error})
    finally:
        for v in validators.values():
            v.close()
    return {
        "total": len(templates),
        "passed": n_passed,
        "failed": len(failures),
        "failures": failures,
    }
def build_record(
    template: Template,
    template_idx: int,
    nl_question: str,
    persona: str,
    source: str,
    validator: SQLValidator,
) -> Optional[DataRecord]:
    """
    Validate the template SQL and, if it passes, build a DataRecord.

    Parameters
    ----------
    template     : The source template (contains SQL, domain, difficulty).
    template_idx : Index of template in ALL_TEMPLATES (for deduplication).
    nl_question  : The NL paraphrase to use as the prompt.
    persona      : Which persona/strategy generated this NL.
    source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
    validator    : Pre-built SQLValidator for this domain.

    Returns None if the SQL fails to execute or returns no rows.
    """
    vr = validator.validate(template["sql"])
    # Enforce the module-level guarantee: the SQL must execute without error
    # AND return at least one row; an empty result is a useless label.
    if not vr.passed or vr.row_count == 0:
        return None
    return DataRecord(
        domain=template["domain"],
        difficulty=template["difficulty"],
        sql=template["sql"],
        nl_question=nl_question,
        persona=persona,
        has_order=template["has_order"],
        schema_context=SCHEMA_CONTEXT[template["domain"]],
        row_count=vr.row_count,
        columns=vr.columns,
        source=source,
        template_id=template_idx,
    )