queryforge / tasks.py
Prithvigg's picture
Upload folder using huggingface_hub
039839b verified
"""
SQL task definitions and runtime task registry for the QueryForge environment.
Built-in tasks:
    easy   — fix three misspelled SQL keywords
    medium — fix a cartesian JOIN producing wrong results
    hard   — rewrite a correlated subquery as a CTE
Custom tasks can be added at runtime via REGISTRY.register() or
POST /tasks on the running server.
"""
import json
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List, Optional
# ── Data classes ──────────────────────────────────────────────────────────────
@dataclass
class TestCase:
    """A single test case: expected output rows for correctness grading.

    The class name starts with ``Test``, so ``__test__`` is set to ``False``
    to keep pytest from trying to collect it as a test class (and warning
    about its ``__init__``) whenever it is imported into a test module.
    """

    # Marks the class as "not a test" for pytest. It carries no annotation,
    # so the dataclass machinery does not turn it into a field.
    __test__ = False

    description: str  # human-readable summary of what the case checks
    expected_rows: List[Dict[str, Any]]  # one dict per expected row
    order_by: Optional[str] = None  # comma-separated columns to sort by
@dataclass
class SQLTask:
    """Full definition of one SQL challenge.

    Bundles the prompt text shown to the agent (description, broken query,
    error message, hint), the DuckDB schema and seed data, the expected
    results, and a reference solution.
    """
    id: str  # unique task identifier; used as the registry key
    level: str  # "easy" | "medium" | "hard" | "expert" | "custom"
    title: str
    description: str
    schema_ddl: str  # DDL + seed INSERT statements for DuckDB
    broken_query: str  # broken/slow query the agent must fix
    error_message: str  # error or performance warning shown to agent
    hint: str
    test_cases: List[TestCase]
    solution_query: str  # reference solution used by the AI judge
    max_steps: int = 5  # step budget for the task; defaults to 5
# ── Built-in tasks ────────────────────────────────────────────────────────────
# Easy built-in task: three misspelled keywords (SELEC / FORM / WEHRE) stop the
# query from parsing. The goal additionally asks for ORDER BY name ASC, which
# the broken query does not contain.
_TASK_EASY = SQLTask(
    id="task_easy_syntax",
    level="easy",
    title="Fix the Syntax Errors",
    description="""\
TASK: Fix the syntax errors in the query below so it runs correctly.
SCHEMA:
users(id INTEGER, name VARCHAR, age INTEGER, city VARCHAR)
BROKEN QUERY:
SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'
ERROR:
Parser Error: syntax error at or near "SELEC"
GOAL: Return a valid SQL query that retrieves `name` and `age`
of users who are older than 30 AND live in New York.
Order by name ASC.""",
    schema_ddl="""\
CREATE TABLE users (
id INTEGER,
name VARCHAR,
age INTEGER,
city VARCHAR
);
INSERT INTO users VALUES
(1, 'Alice', 35, 'New York'),
(2, 'Bob', 28, 'New York'),
(3, 'Carol', 42, 'Chicago'),
(4, 'Dave', 31, 'New York'),
(5, 'Eve', 25, 'New York'),
(6, 'Frank', 38, 'New York');
""",
    broken_query="SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'",
    error_message='Parser Error: syntax error at or near "SELEC"',
    hint="Three SQL keywords are misspelled: SELEC → SELECT, FORM → FROM, WEHRE → WHERE.",
    test_cases=[
        TestCase(
            description="Users over 30 living in New York, ordered by name",
            # Carol is over 30 but lives in Chicago; Bob and Eve are too young.
            expected_rows=[
                {"name": "Alice", "age": 35},
                {"name": "Dave", "age": 31},
                {"name": "Frank", "age": 38},
            ],
            order_by="name",
        )
    ],
    solution_query=(
        "SELECT name, age FROM users "
        "WHERE age > 30 AND city = 'New York' "
        "ORDER BY name ASC"
    ),
)
# Medium built-in task: an implicit (comma) join omits the orders -> products
# condition, so every order row is paired with every product.
_TASK_MEDIUM = SQLTask(
    id="task_medium_join",
    level="medium",
    title="Fix the Cartesian JOIN",
    description="""\
TASK: The query below produces wildly inflated totals because a JOIN condition
is missing, creating a cartesian product with the `products` table. Fix it.
SCHEMAS:
users(id INTEGER, name VARCHAR, age INTEGER)
products(id INTEGER, title VARCHAR, price DECIMAL)
orders(id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL)
BROKEN QUERY:
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o, users u, products p
WHERE o.user_id = u.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC
PROBLEM:
Missing join condition `o.product_id = p.id`.
Every order row is multiplied by ALL products, inflating every total by 3×.
GOAL: Rewrite using explicit INNER JOIN … ON syntax with all correct join
conditions. Return user name, product title, and true total amount spent per
(user, product) pair, ordered by total_spent DESC.""",
    schema_ddl="""\
CREATE TABLE users (id INTEGER, name VARCHAR, age INTEGER);
CREATE TABLE products (id INTEGER, title VARCHAR, price DECIMAL);
CREATE TABLE orders (id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL);
INSERT INTO users VALUES (1,'Alice',30),(2,'Bob',25),(3,'Carol',35);
INSERT INTO products VALUES (1,'Laptop',999.99),(2,'Phone',599.99),(3,'Tablet',399.99);
INSERT INTO orders VALUES
(1,1,1,999.99),(2,1,2,599.99),
(3,2,1,999.99),(4,2,3,399.99),
(5,3,2,599.99),(6,3,1,999.99);
""",
    broken_query="""\
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o, users u, products p
WHERE o.user_id = u.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC""",
    error_message=(
        "Query runs but produces WRONG results: totals are 3× too high "
        "because every order is joined to every product (cartesian product)."
    ),
    hint=(
        "Use INNER JOIN … ON for every table. "
        "You need both: o.user_id = u.id AND o.product_id = p.id."
    ),
    test_cases=[
        TestCase(
            description="Correct per-(user, product) totals",
            # Rows are listed in name,title order (the order_by columns), not
            # in the total_spent DESC order the query itself produces.
            # NOTE(review): presumably the grader re-sorts by order_by before
            # comparing — confirm against the grading code.
            expected_rows=[
                {"name": "Alice", "title": "Laptop", "total_spent": 999.99},
                {"name": "Alice", "title": "Phone", "total_spent": 599.99},
                {"name": "Bob", "title": "Laptop", "total_spent": 999.99},
                {"name": "Bob", "title": "Tablet", "total_spent": 399.99},
                {"name": "Carol", "title": "Laptop", "total_spent": 999.99},
                {"name": "Carol", "title": "Phone", "total_spent": 599.99},
            ],
            order_by="name,title",
        )
    ],
    solution_query="""\
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o
INNER JOIN users u ON o.user_id = u.id
INNER JOIN products p ON o.product_id = p.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC""",
)
# Hard built-in task: the starting query is already correct; the agent must
# rewrite its correlated subquery as a CTE that yields the same rows.
_TASK_HARD = SQLTask(
    id="task_hard_cte",
    level="hard",
    title="Rewrite Correlated Subquery as CTE",
    description="""\
TASK: The query below is semantically correct but executes the inner AVG(salary)
once per employee row — O(N) full scans. Rewrite it using a WITH (CTE) so the
department averages are computed exactly once.
SCHEMAS:
departments(id INTEGER, dept_name VARCHAR)
employees(id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL)
SLOW QUERY:
SELECT e.name, e.department_id, e.salary
FROM employees e
WHERE e.salary > (
SELECT AVG(e2.salary)
FROM employees e2
WHERE e2.department_id = e.department_id
)
ORDER BY e.department_id, e.salary DESC
PERFORMANCE WARNING:
For 1 M employees the inner subquery executes 1 M times.
DuckDB's EXPLAIN shows: 'FILTER ... (subquery)' with nested loop.
GOAL: Rewrite using a CTE that computes per-department average salary once,
then join it to employees and filter. The result must be identical:
employees who earn strictly above their own department's average salary,
ordered by department_id ASC, salary DESC.""",
    schema_ddl="""\
CREATE TABLE departments (id INTEGER, dept_name VARCHAR);
CREATE TABLE employees (id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL);
INSERT INTO departments VALUES (1,'Engineering'),(2,'Marketing'),(3,'Sales');
INSERT INTO employees VALUES
(1,'Alice', 1, 95000),(2,'Bob', 1, 75000),(3,'Carol', 1, 85000),
(4,'Dave', 2, 65000),(5,'Eve', 2, 70000),(6,'Frank', 2, 60000),
(7,'Grace', 3, 55000),(8,'Hank', 3, 72000),(9,'Iris', 3, 58000);
""",
    broken_query="""\
SELECT e.name, e.department_id, e.salary
FROM employees e
WHERE e.salary > (
SELECT AVG(e2.salary)
FROM employees e2
WHERE e2.department_id = e.department_id
)
ORDER BY e.department_id, e.salary DESC""",
    error_message=(
        "PERFORMANCE: Correlated subquery re-executes AVG() for every row. "
        "On large tables this is O(N²). Rewrite as a CTE for O(N) execution."
    ),
    hint=(
        "WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary "
        "FROM employees GROUP BY department_id) — then JOIN employees to dept_avg "
        "and filter WHERE e.salary > d.avg_salary."
    ),
    test_cases=[
        TestCase(
            description="Employees strictly above their department's average salary",
            # Carol (85000) and Dave (65000) earn exactly their department
            # average, so the strict ">" comparison excludes them.
            expected_rows=[
                {"name": "Alice", "department_id": 1, "salary": 95000.0},
                {"name": "Eve", "department_id": 2, "salary": 70000.0},
                {"name": "Hank", "department_id": 3, "salary": 72000.0},
            ],
            order_by="department_id,name",
        )
    ],
    solution_query="""\
WITH dept_avg AS (
SELECT department_id, AVG(salary) AS avg_salary
FROM employees
GROUP BY department_id
)
SELECT e.name, e.department_id, e.salary
FROM employees e
JOIN dept_avg d ON e.department_id = d.department_id
WHERE e.salary > d.avg_salary
ORDER BY e.department_id, e.salary DESC""",
    max_steps=6,
)
# ── Expert tasks ──────────────────────────────────────────────────────────────
# Expert task: two bugs in one window function. ROW_NUMBER() drops tied top
# earners and the ASC sort picks the lowest revenue; the reference solution
# uses RANK() with ORDER BY revenue DESC.
_TASK_EXPERT_RANK = SQLTask(
    id="task_expert_rank",
    level="expert",
    title="Fix the Tie-Breaking Window Function",
    description="""\
TASK: The query below attempts to find the top-earning sales rep per region,
but it returns wrong results. Debug it.
SCHEMA:
sales_reps(id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL)
BROKEN QUERY:
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue ASC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name
PROBLEM:
The query returns 2 rows but the expected answer has 4.
The output values are also wrong — it seems to pick the lowest revenue per region
instead of the highest.
GOAL: Return ALL reps whose revenue is the highest in their region.
Order by region ASC, name ASC.""",
    schema_ddl="""\
CREATE TABLE sales_reps (id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL);
INSERT INTO sales_reps VALUES
(1, 'Alice', 'North', 95000),
(2, 'Bob', 'North', 87000),
(3, 'Carol', 'North', 95000),
(4, 'Dave', 'South', 88000),
(5, 'Eve', 'South', 88000),
(6, 'Frank', 'South', 75000);
""",
    broken_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue ASC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name""",
    error_message=(
        "Query runs but returns wrong results: only 2 rows (one per region) "
        "with the LOWEST revenue instead of the HIGHEST. Expected 4 rows."
    ),
    hint="There are two bugs. Think about both the ranking function and the sort order.",
    test_cases=[
        TestCase(
            description="All reps tied at rank 1 per region",
            # Both regions have a two-way tie for the top revenue.
            expected_rows=[
                {"name": "Alice", "region": "North", "revenue": 95000.0},
                {"name": "Carol", "region": "North", "revenue": 95000.0},
                {"name": "Dave", "region": "South", "revenue": 88000.0},
                {"name": "Eve", "region": "South", "revenue": 88000.0},
            ],
            order_by="region,name",
        )
    ],
    solution_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS rk
FROM sales_reps
) ranked
WHERE rk = 1
ORDER BY region, name""",
    max_steps=6,
)
# Expert task: the broken query selects VP Eng himself as the starting row and
# only reaches his direct reports; the reference solution is a recursive CTE
# anchored on manager_id = 3 that walks every level below him.
_TASK_EXPERT_RECURSIVE = SQLTask(
    id="task_expert_recursive",
    level="expert",
    title="Traverse Org Chart with Recursive CTE",
    description="""\
TASK: The query below attempts to find all subordinates of the VP of Engineering
(id=3), but it returns wrong results. Debug and fix it.
SCHEMA:
employees(id INTEGER, name VARCHAR, manager_id INTEGER)
DATA (partial):
CEO (id=1)
VP Eng (id=3, reports to CEO)
Lead A (id=5), Lead B (id=6) report to VP Eng
Dev 1..4 (id=8..11) report to Leads
Junior 1..2 (id=13..14) report to Dev 1
BROKEN QUERY:
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id
PROBLEM:
The query returns some results but the row count and values don't match
the expected output. Inspect what the anchor condition selects and whether
the query reaches all depths of the org tree.
GOAL: Return ALL 8 subordinates of VP Eng (id=3) at any depth.
Do NOT include VP Eng himself — only his reports.
Return id, name, manager_id columns, ordered by id ASC.""",
    schema_ddl="""\
CREATE TABLE employees (id INTEGER, name VARCHAR, manager_id INTEGER);
INSERT INTO employees VALUES
(1, 'CEO', NULL),
(2, 'CFO', 1),
(3, 'VP Eng', 1),
(4, 'VP Sales', 1),
(5, 'Lead A', 3),
(6, 'Lead B', 3),
(7, 'Sales Mgr',4),
(8, 'Dev 1', 5),
(9, 'Dev 2', 5),
(10, 'Dev 3', 6),
(11, 'Dev 4', 6),
(12, 'Sales Rep',7),
(13, 'Junior 1', 8),
(14, 'Junior 2', 8);
""",
    broken_query="""\
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id""",
    error_message=(
        "Query returns wrong results. Check carefully: does the anchor condition "
        "select the right starting rows? Does the query traverse all depths?"
    ),
    hint="There are multiple issues. Think about what the anchor selects and how deep the query reaches.",
    test_cases=[
        TestCase(
            description="All 8 subordinates of VP Eng at any depth",
            # Three levels below id=3: leads (5, 6), devs (8-11), juniors (13, 14).
            expected_rows=[
                {"id": 5, "name": "Lead A", "manager_id": 3},
                {"id": 6, "name": "Lead B", "manager_id": 3},
                {"id": 8, "name": "Dev 1", "manager_id": 5},
                {"id": 9, "name": "Dev 2", "manager_id": 5},
                {"id": 10, "name": "Dev 3", "manager_id": 6},
                {"id": 11, "name": "Dev 4", "manager_id": 6},
                {"id": 13, "name": "Junior 1", "manager_id": 8},
                {"id": 14, "name": "Junior 2", "manager_id": 8},
            ],
            order_by="id",
        )
    ],
    solution_query="""\
WITH RECURSIVE subordinates AS (
SELECT id, name, manager_id
FROM employees
WHERE manager_id = 3
UNION ALL
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN subordinates s ON e.manager_id = s.id
)
SELECT id, name, manager_id
FROM subordinates
ORDER BY id""",
    max_steps=7,
)
# Expert task: both window functions in the broken query lack
# PARTITION BY region, so the running total never resets and the rank is
# computed across all rows. The reference solution adds the partition to both.
_TASK_EXPERT_WINDOW = SQLTask(
    id="task_expert_window",
    level="expert",
    title="Fix Broken Window Functions: Running Total and Revenue Rank",
    # Prompt text shown to the agent; it restates the schema, data and goal.
    description="""\
TASK: The query below computes a cumulative running total and a within-region
revenue rank for each quarter, but the results are wrong. Debug and fix it.
SCHEMA:
quarterly_sales(region VARCHAR, quarter INTEGER, revenue DECIMAL)
DATA:
East: Q1=15000, Q2=18000, Q3=12000, Q4=20000
West: Q1=11000, Q2=14000, Q3=16000, Q4=16000 (note: Q3 and Q4 are tied)
BROKEN QUERY:
SELECT region, quarter, revenue,
SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,
RANK() OVER (ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter
PROBLEM:
The query returns wrong values for both running_total and revenue_rank.
Compare your output against the expected results carefully.
GOAL: running_total should be a cumulative sum per region (reset each region,
ordered by quarter). revenue_rank should rank revenue within each region
(ordered by revenue DESC), handling ties correctly (tied values must get
the same rank).
Final output: ORDER BY region ASC, quarter ASC.""",
    # Two regions with four quarters each; West Q3 and Q4 tie at 16000.
    schema_ddl="""\
CREATE TABLE quarterly_sales (region VARCHAR, quarter INTEGER, revenue DECIMAL);
INSERT INTO quarterly_sales VALUES
('East', 1, 15000),
('East', 2, 18000),
('East', 3, 12000),
('East', 4, 20000),
('West', 1, 11000),
('West', 2, 14000),
('West', 3, 16000),
('West', 4, 16000);
""",
    broken_query="""\
SELECT region, quarter, revenue,
SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,
RANK() OVER (ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter""",
    error_message=(
        "Query runs but both computed columns are wrong. "
        "running_total does not reset per region. "
        "revenue_rank is a global ranking across all rows instead of per-region."
    ),
    hint="Multiple issues exist. Think about partitioning and how tied values should be ranked.",
    test_cases=[
        TestCase(
            description="Per-region running total and within-region revenue rank with ties",
            # The tied West quarters share rank 1, so the next rank used in
            # that region is 3 (no rank 2).
            expected_rows=[
                {"region": "East", "quarter": 1, "revenue": 15000.0, "running_total": 15000.0, "revenue_rank": 3},
                {"region": "East", "quarter": 2, "revenue": 18000.0, "running_total": 33000.0, "revenue_rank": 2},
                {"region": "East", "quarter": 3, "revenue": 12000.0, "running_total": 45000.0, "revenue_rank": 4},
                {"region": "East", "quarter": 4, "revenue": 20000.0, "running_total": 65000.0, "revenue_rank": 1},
                {"region": "West", "quarter": 1, "revenue": 11000.0, "running_total": 11000.0, "revenue_rank": 4},
                {"region": "West", "quarter": 2, "revenue": 14000.0, "running_total": 25000.0, "revenue_rank": 3},
                {"region": "West", "quarter": 3, "revenue": 16000.0, "running_total": 41000.0, "revenue_rank": 1},
                {"region": "West", "quarter": 4, "revenue": 16000.0, "running_total": 57000.0, "revenue_rank": 1},
            ],
            order_by="region,quarter",
        )
    ],
    # Reference answer: same query with PARTITION BY region in both windows.
    solution_query="""\
SELECT region, quarter, revenue,
SUM(revenue) OVER (PARTITION BY region ORDER BY quarter) AS running_total,
RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter""",
    max_steps=6,
)
# ── Task Registry ─────────────────────────────────────────────────────────────
class TaskRegistry:
    """
    Thread-safe registry of SQL tasks, shared across all environment sessions.

    Every public method takes the internal lock. Tasks are stored in
    insertion order, which is also the order cycle_next() walks through.

    Built-in tasks (easy / medium / hard / expert, listed in _BUILTIN_IDS)
    cannot be removed via unregister().
    Custom tasks can be added via register(), load_from_json(), or POST /tasks.
    """

    _BUILTIN_IDS: frozenset = frozenset([
        "task_easy_syntax", "task_medium_join", "task_hard_cte",
        "task_expert_rank", "task_expert_recursive", "task_expert_window",
    ])

    def __init__(self, initial_tasks: List["SQLTask"]) -> None:
        """Seed the registry with *initial_tasks*, keyed by each task's ``id``."""
        self._lock = Lock()
        # Insertion-ordered dict preserves cycling order
        self._tasks: Dict[str, "SQLTask"] = {t.id: t for t in initial_tasks}
        # Monotonic counter; reduced modulo the current size on every cycle.
        self._cycle_index: int = 0

    # ── CRUD ─────────────────────────────────────────────────────────────────
    def register(self, task: "SQLTask") -> None:
        """Add or replace a task. Replaces silently if the ID already exists."""
        with self._lock:
            self._tasks[task.id] = task

    def unregister(self, task_id: str) -> None:
        """
        Remove a custom task.

        Raises ValueError for built-in tasks, KeyError if not found.
        """
        if task_id in self._BUILTIN_IDS:
            raise ValueError(f"Built-in task '{task_id}' cannot be removed.")
        with self._lock:
            # ``del`` raises KeyError(task_id) by itself when the ID is absent.
            del self._tasks[task_id]

    def get(self, task_id: str) -> "SQLTask":
        """Return a task by ID. Raises KeyError with available IDs if not found."""
        with self._lock:
            try:
                return self._tasks[task_id]
            except KeyError:
                available = ", ".join(self._tasks)
                # ``from None`` keeps the traceback to the one helpful error.
                raise KeyError(
                    f"Task '{task_id}' not found. "
                    f"Available: {available}"
                ) from None

    def list_all(self) -> List["SQLTask"]:
        """Return all registered tasks in insertion order (as a new list)."""
        with self._lock:
            return list(self._tasks.values())

    def ids(self) -> List[str]:
        """Return all task IDs in insertion order (as a new list)."""
        with self._lock:
            return list(self._tasks)

    # ── Cycling ───────────────────────────────────────────────────────────────
    def cycle_next(self) -> "SQLTask":
        """
        Return the next task in round-robin order (wraps at end).

        Raises LookupError if the registry holds no tasks.
        """
        with self._lock:
            tasks = list(self._tasks.values())
            if not tasks:
                # Without this guard the modulo below divides by zero.
                raise LookupError("Cannot cycle: the task registry is empty.")
            task = tasks[self._cycle_index % len(tasks)]
            self._cycle_index += 1
            return task

    # ── Bulk loading ──────────────────────────────────────────────────────────
    def load_from_json(self, path: str) -> int:
        """
        Load tasks from a UTF-8 JSON file (a task spec object or a list of them).

        Returns the number of tasks loaded. All entries are converted before
        any is registered, so an invalid entry leaves the registry unchanged.

        Raises TypeError if the top-level JSON value is neither an object nor
        a list; errors from task_from_dict() propagate unchanged.

        Minimal required fields per task:
            id, schema_ddl, expected_rows

        Example file::

            [
                {
                    "id": "my_null_task",
                    "level": "medium",
                    "title": "Handle NULLs in aggregation",
                    "schema_ddl": "CREATE TABLE ...; INSERT ...",
                    "broken_query": "SELECT AVG(score) FROM ...",
                    "expected_rows": [{"avg_score": 72.5}],
                    "hint": "Use COALESCE to handle NULL scores."
                }
            ]
        """
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        raw = json.loads(Path(path).read_text(encoding="utf-8"))
        if isinstance(raw, dict):
            raw = [raw]
        elif not isinstance(raw, list):
            raise TypeError(
                "Task file must contain a JSON object or a list of objects, "
                f"got {type(raw).__name__}."
            )
        # Build every task first so a malformed entry cannot leave the
        # registry half-updated.
        tasks = [task_from_dict(item) for item in raw]
        for task in tasks:
            self.register(task)
        return len(tasks)

    # ── Helpers ───────────────────────────────────────────────────────────────
    def __len__(self) -> int:
        """Return the number of registered tasks."""
        with self._lock:
            return len(self._tasks)

    def __contains__(self, task_id: str) -> bool:
        """Return True if a task with *task_id* is registered."""
        with self._lock:
            return task_id in self._tasks
# ── Conversion helper ─────────────────────────────────────────────────────────
def task_from_dict(d: Dict[str, Any]) -> SQLTask:
    """
    Construct an SQLTask from a plain dict (JSON payload or loaded file).

    The resulting task always has exactly one TestCase, built from
    ``expected_rows``, ``order_by`` and ``test_description``.

    Required keys : id, schema_ddl, expected_rows
    Optional keys : level, title, description, broken_query, error_message,
                    hint, order_by, solution_query, test_description, max_steps

    Defaults: level "custom", title = the task id, max_steps 5, all other
    optional text fields "".

    Raises KeyError naming every required key that is missing.
    """
    # Report all missing required keys at once instead of failing on the first.
    missing = [key for key in ("id", "schema_ddl", "expected_rows") if key not in d]
    if missing:
        raise KeyError(f"Missing required task field(s): {', '.join(missing)}")

    task_id = d["id"]
    only_case = TestCase(
        description=d.get("test_description", "Custom test case"),
        expected_rows=d["expected_rows"],
        order_by=d.get("order_by"),
    )
    return SQLTask(
        id=task_id,
        level=d.get("level", "custom"),
        title=d.get("title", task_id),
        description=d.get("description", ""),
        schema_ddl=d["schema_ddl"],
        broken_query=d.get("broken_query", ""),
        error_message=d.get("error_message", ""),
        hint=d.get("hint", ""),
        test_cases=[only_case],
        solution_query=d.get("solution_query", ""),
        max_steps=d.get("max_steps", 5),
    )
# ── Global singleton ──────────────────────────────────────────────────────────
# Backwards-compat: snapshot of all built-in tasks at import time, in the
# order the registry cycles through them.
TASKS: List[SQLTask] = [
    _TASK_EASY,
    _TASK_MEDIUM,
    _TASK_HARD,
    _TASK_EXPERT_RANK,
    _TASK_EXPERT_RECURSIVE,
    _TASK_EXPERT_WINDOW,
]

# Process-wide registry seeded with the built-ins. The constructor copies the
# entries into its own dict, so later registry changes do not touch TASKS.
REGISTRY = TaskRegistry(TASKS)

# ID -> task lookup over the same import-time snapshot.
TASK_BY_ID: Dict[str, SQLTask] = {task.id: task for task in TASKS}