Spaces:
Sleeping
Sleeping
"""
SQL task definitions and runtime task registry for the QueryForge environment.

Built-in tasks:
    easy   — fix three misspelled SQL keywords
    medium — fix a cartesian JOIN producing wrong results
    hard   — rewrite a correlated subquery as a CTE
    expert — fix a tie-breaking window function, traverse an org chart with a
             recursive CTE, and fix per-region running totals and ranks

Custom tasks can be added at runtime via REGISTRY.register() or
POST /tasks on the running server.
"""
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from threading import Lock | |
| from typing import Any, Dict, List, Optional | |
# ── Data classes ──────────────────────────────────────────────────────────────
@dataclass
class TestCase:
    """A single test case: expected output rows for correctness grading.

    Declared as a dataclass so that instances can be built with keyword
    arguments (``TestCase(description=..., expected_rows=..., order_by=...)``),
    which is how every task definition in this module constructs them.
    """

    # Human-readable summary of what this case checks.
    description: str
    # Expected result rows, one dict per row.
    expected_rows: List[Dict[str, Any]]
    # Comma-separated columns to sort by; None when no ordering is given.
    order_by: Optional[str] = None
@dataclass
class SQLTask:
    """Full definition of one SQL challenge.

    Declared as a dataclass so that tasks can be built with keyword arguments,
    which is how the built-in task definitions and ``task_from_dict`` in this
    module construct them.
    """

    id: str
    level: str  # "easy" | "medium" | "hard" | "expert" | "custom"
    title: str
    description: str
    schema_ddl: str  # DDL + seed INSERT statements for DuckDB
    broken_query: str  # broken/slow query the agent must fix
    error_message: str  # error or performance warning shown to agent
    hint: str
    test_cases: List["TestCase"]
    solution_query: str  # reference solution used by the AI judge
    max_steps: int = 5
# ── Built-in tasks ────────────────────────────────────────────────────────────
# Easy task: the broken query has three misspelled keywords; the corrected
# query must return the New York users older than 30.
_TASK_EASY = SQLTask(
    id="task_easy_syntax",
    level="easy",
    title="Fix the Syntax Errors",
    description="""\
TASK: Fix the syntax errors in the query below so it runs correctly.
SCHEMA:
users(id INTEGER, name VARCHAR, age INTEGER, city VARCHAR)
BROKEN QUERY:
SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'
ERROR:
Parser Error: syntax error at or near "SELEC"
GOAL: Return a valid SQL query that retrieves `name` and `age`
of users who are older than 30 AND live in New York.
Order by name ASC.""",
    schema_ddl="""\
CREATE TABLE users (
id INTEGER,
name VARCHAR,
age INTEGER,
city VARCHAR
);
INSERT INTO users VALUES
(1, 'Alice', 35, 'New York'),
(2, 'Bob', 28, 'New York'),
(3, 'Carol', 42, 'Chicago'),
(4, 'Dave', 31, 'New York'),
(5, 'Eve', 25, 'New York'),
(6, 'Frank', 38, 'New York');
""",
    broken_query="SELEC name, age FORM users WEHRE age > 30 AND city = 'New York'",
    error_message='Parser Error: syntax error at or near "SELEC"',
    # Arrows repaired from mis-decoded UTF-8 so the hint renders correctly.
    hint="Three SQL keywords are misspelled: SELEC → SELECT, FORM → FROM, WEHRE → WHERE.",
    test_cases=[
        TestCase(
            description="Users over 30 living in New York, ordered by name",
            expected_rows=[
                {"name": "Alice", "age": 35},
                {"name": "Dave", "age": 31},
                {"name": "Frank", "age": 38},
            ],
            order_by="name",
        )
    ],
    solution_query=(
        "SELECT name, age FROM users "
        "WHERE age > 30 AND city = 'New York' "
        "ORDER BY name ASC"
    ),
)
# Medium task: the broken query omits the orders→products join condition;
# the fix must join all three tables explicitly.
_TASK_MEDIUM = SQLTask(
    id="task_medium_join",
    level="medium",
    title="Fix the Cartesian JOIN",
    # "3×" and "…" below are repaired from mis-decoded UTF-8.
    description="""\
TASK: The query below produces wildly inflated totals because a JOIN condition
is missing, creating a cartesian product with the `products` table. Fix it.
SCHEMAS:
users(id INTEGER, name VARCHAR, age INTEGER)
products(id INTEGER, title VARCHAR, price DECIMAL)
orders(id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL)
BROKEN QUERY:
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o, users u, products p
WHERE o.user_id = u.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC
PROBLEM:
Missing join condition `o.product_id = p.id`.
Every order row is multiplied by ALL products, inflating every total by 3×.
GOAL: Rewrite using explicit INNER JOIN … ON syntax with all correct join
conditions. Return user name, product title, and true total amount spent per
(user, product) pair, ordered by total_spent DESC.""",
    schema_ddl="""\
CREATE TABLE users (id INTEGER, name VARCHAR, age INTEGER);
CREATE TABLE products (id INTEGER, title VARCHAR, price DECIMAL);
CREATE TABLE orders (id INTEGER, user_id INTEGER, product_id INTEGER, amount DECIMAL);
INSERT INTO users VALUES (1,'Alice',30),(2,'Bob',25),(3,'Carol',35);
INSERT INTO products VALUES (1,'Laptop',999.99),(2,'Phone',599.99),(3,'Tablet',399.99);
INSERT INTO orders VALUES
(1,1,1,999.99),(2,1,2,599.99),
(3,2,1,999.99),(4,2,3,399.99),
(5,3,2,599.99),(6,3,1,999.99);
""",
    broken_query="""\
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o, users u, products p
WHERE o.user_id = u.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC""",
    error_message=(
        "Query runs but produces WRONG results: totals are 3× too high "
        "because every order is joined to every product (cartesian product)."
    ),
    hint=(
        "Use INNER JOIN … ON for every table. "
        "You need both: o.user_id = u.id AND o.product_id = p.id."
    ),
    test_cases=[
        TestCase(
            description="Correct per-(user, product) totals",
            expected_rows=[
                {"name": "Alice", "title": "Laptop", "total_spent": 999.99},
                {"name": "Alice", "title": "Phone", "total_spent": 599.99},
                {"name": "Bob", "title": "Laptop", "total_spent": 999.99},
                {"name": "Bob", "title": "Tablet", "total_spent": 399.99},
                {"name": "Carol", "title": "Laptop", "total_spent": 999.99},
                {"name": "Carol", "title": "Phone", "total_spent": 599.99},
            ],
            order_by="name,title",
        )
    ],
    solution_query="""\
SELECT u.name, p.title, SUM(o.amount) AS total_spent
FROM orders o
INNER JOIN users u ON o.user_id = u.id
INNER JOIN products p ON o.product_id = p.id
GROUP BY u.name, p.title
ORDER BY total_spent DESC""",
)
# Hard task: the query is already correct; the agent must rewrite the
# correlated subquery as a CTE that yields the same rows.
_TASK_HARD = SQLTask(
    id="task_hard_cte",
    level="hard",
    title="Rewrite Correlated Subquery as CTE",
    # Dashes and the "N²" superscript below are repaired from mis-decoded UTF-8.
    description="""\
TASK: The query below is semantically correct but executes the inner AVG(salary)
once per employee row — O(N) full scans. Rewrite it using a WITH (CTE) so the
department averages are computed exactly once.
SCHEMAS:
departments(id INTEGER, dept_name VARCHAR)
employees(id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL)
SLOW QUERY:
SELECT e.name, e.department_id, e.salary
FROM employees e
WHERE e.salary > (
SELECT AVG(e2.salary)
FROM employees e2
WHERE e2.department_id = e.department_id
)
ORDER BY e.department_id, e.salary DESC
PERFORMANCE WARNING:
For 1 M employees the inner subquery executes 1 M times.
DuckDB's EXPLAIN shows: 'FILTER ... (subquery)' with nested loop.
GOAL: Rewrite using a CTE that computes per-department average salary once,
then join it to employees and filter. The result must be identical:
employees who earn strictly above their own department's average salary,
ordered by department_id ASC, salary DESC.""",
    schema_ddl="""\
CREATE TABLE departments (id INTEGER, dept_name VARCHAR);
CREATE TABLE employees (id INTEGER, name VARCHAR, department_id INTEGER, salary DECIMAL);
INSERT INTO departments VALUES (1,'Engineering'),(2,'Marketing'),(3,'Sales');
INSERT INTO employees VALUES
(1,'Alice', 1, 95000),(2,'Bob', 1, 75000),(3,'Carol', 1, 85000),
(4,'Dave', 2, 65000),(5,'Eve', 2, 70000),(6,'Frank', 2, 60000),
(7,'Grace', 3, 55000),(8,'Hank', 3, 72000),(9,'Iris', 3, 58000);
""",
    broken_query="""\
SELECT e.name, e.department_id, e.salary
FROM employees e
WHERE e.salary > (
SELECT AVG(e2.salary)
FROM employees e2
WHERE e2.department_id = e.department_id
)
ORDER BY e.department_id, e.salary DESC""",
    error_message=(
        "PERFORMANCE: Correlated subquery re-executes AVG() for every row. "
        "On large tables this is O(N²). Rewrite as a CTE for O(N) execution."
    ),
    hint=(
        "WITH dept_avg AS (SELECT department_id, AVG(salary) AS avg_salary "
        "FROM employees GROUP BY department_id) — then JOIN employees to dept_avg "
        "and filter WHERE e.salary > d.avg_salary."
    ),
    test_cases=[
        TestCase(
            description="Employees strictly above their department's average salary",
            expected_rows=[
                {"name": "Alice", "department_id": 1, "salary": 95000.0},
                {"name": "Eve", "department_id": 2, "salary": 70000.0},
                {"name": "Hank", "department_id": 3, "salary": 72000.0},
            ],
            order_by="department_id,name",
        )
    ],
    solution_query="""\
WITH dept_avg AS (
SELECT department_id, AVG(salary) AS avg_salary
FROM employees
GROUP BY department_id
)
SELECT e.name, e.department_id, e.salary
FROM employees e
JOIN dept_avg d ON e.department_id = d.department_id
WHERE e.salary > d.avg_salary
ORDER BY e.department_id, e.salary DESC""",
    max_steps=6,
)
# ── Expert tasks ──────────────────────────────────────────────────────────────
# Expert task: the broken query uses ROW_NUMBER with ascending order, so it
# returns one lowest-revenue rep per region; the goal is every rep tied for
# the highest revenue in their region.
_TASK_EXPERT_RANK = SQLTask(
    id="task_expert_rank",
    level="expert",
    title="Fix the Tie-Breaking Window Function",
    # The dash in the PROBLEM paragraph is repaired from mis-decoded UTF-8.
    description="""\
TASK: The query below attempts to find the top-earning sales rep per region,
but it returns wrong results. Debug it.
SCHEMA:
sales_reps(id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL)
BROKEN QUERY:
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue ASC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name
PROBLEM:
The query returns 2 rows but the expected answer has 4.
The output values are also wrong — it seems to pick the lowest revenue per region
instead of the highest.
GOAL: Return ALL reps whose revenue is the highest in their region.
Order by region ASC, name ASC.""",
    schema_ddl="""\
CREATE TABLE sales_reps (id INTEGER, name VARCHAR, region VARCHAR, revenue DECIMAL);
INSERT INTO sales_reps VALUES
(1, 'Alice', 'North', 95000),
(2, 'Bob', 'North', 87000),
(3, 'Carol', 'North', 95000),
(4, 'Dave', 'South', 88000),
(5, 'Eve', 'South', 88000),
(6, 'Frank', 'South', 75000);
""",
    broken_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue ASC) AS rn
FROM sales_reps
) ranked
WHERE rn = 1
ORDER BY region, name""",
    error_message=(
        "Query runs but returns wrong results: only 2 rows (one per region) "
        "with the LOWEST revenue instead of the HIGHEST. Expected 4 rows."
    ),
    hint="There are two bugs. Think about both the ranking function and the sort order.",
    test_cases=[
        TestCase(
            description="All reps tied at rank 1 per region",
            expected_rows=[
                {"name": "Alice", "region": "North", "revenue": 95000.0},
                {"name": "Carol", "region": "North", "revenue": 95000.0},
                {"name": "Dave", "region": "South", "revenue": 88000.0},
                {"name": "Eve", "region": "South", "revenue": 88000.0},
            ],
            order_by="region,name",
        )
    ],
    solution_query="""\
SELECT name, region, revenue
FROM (
SELECT name, region, revenue,
RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS rk
FROM sales_reps
) ranked
WHERE rk = 1
ORDER BY region, name""",
    max_steps=6,
)
# Expert task: the broken query anchors on the VP row itself and only reaches
# one level down; the goal is every subordinate of id=3 at any depth.
_TASK_EXPERT_RECURSIVE = SQLTask(
    id="task_expert_recursive",
    level="expert",
    title="Traverse Org Chart with Recursive CTE",
    # The dash in the GOAL paragraph is repaired from mis-decoded UTF-8.
    description="""\
TASK: The query below attempts to find all subordinates of the VP of Engineering
(id=3), but it returns wrong results. Debug and fix it.
SCHEMA:
employees(id INTEGER, name VARCHAR, manager_id INTEGER)
DATA (partial):
CEO (id=1)
VP Eng (id=3, reports to CEO)
Lead A (id=5), Lead B (id=6) report to VP Eng
Dev 1..4 (id=8..11) report to Leads
Junior 1..2 (id=13..14) report to Dev 1
BROKEN QUERY:
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id
PROBLEM:
The query returns some results but the row count and values don't match
the expected output. Inspect what the anchor condition selects and whether
the query reaches all depths of the org tree.
GOAL: Return ALL 8 subordinates of VP Eng (id=3) at any depth.
Do NOT include VP Eng himself — only his reports.
Return id, name, manager_id columns, ordered by id ASC.""",
    schema_ddl="""\
CREATE TABLE employees (id INTEGER, name VARCHAR, manager_id INTEGER);
INSERT INTO employees VALUES
(1, 'CEO', NULL),
(2, 'CFO', 1),
(3, 'VP Eng', 1),
(4, 'VP Sales', 1),
(5, 'Lead A', 3),
(6, 'Lead B', 3),
(7, 'Sales Mgr',4),
(8, 'Dev 1', 5),
(9, 'Dev 2', 5),
(10, 'Dev 3', 6),
(11, 'Dev 4', 6),
(12, 'Sales Rep',7),
(13, 'Junior 1', 8),
(14, 'Junior 2', 8);
""",
    broken_query="""\
WITH direct AS (
SELECT id, name, manager_id FROM employees WHERE id = 3
),
level2 AS (
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN direct d ON e.manager_id = d.id
)
SELECT id, name, manager_id FROM direct
UNION ALL
SELECT id, name, manager_id FROM level2
ORDER BY id""",
    error_message=(
        "Query returns wrong results. Check carefully: does the anchor condition "
        "select the right starting rows? Does the query traverse all depths?"
    ),
    hint="There are multiple issues. Think about what the anchor selects and how deep the query reaches.",
    test_cases=[
        TestCase(
            description="All 8 subordinates of VP Eng at any depth",
            expected_rows=[
                {"id": 5, "name": "Lead A", "manager_id": 3},
                {"id": 6, "name": "Lead B", "manager_id": 3},
                {"id": 8, "name": "Dev 1", "manager_id": 5},
                {"id": 9, "name": "Dev 2", "manager_id": 5},
                {"id": 10, "name": "Dev 3", "manager_id": 6},
                {"id": 11, "name": "Dev 4", "manager_id": 6},
                {"id": 13, "name": "Junior 1", "manager_id": 8},
                {"id": 14, "name": "Junior 2", "manager_id": 8},
            ],
            order_by="id",
        )
    ],
    solution_query="""\
WITH RECURSIVE subordinates AS (
SELECT id, name, manager_id
FROM employees
WHERE manager_id = 3
UNION ALL
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN subordinates s ON e.manager_id = s.id
)
SELECT id, name, manager_id
FROM subordinates
ORDER BY id""",
    max_steps=7,
)
# Expert task: both window functions in the broken query lack PARTITION BY
# region, so the running total never resets and the rank is global.
_TASK_EXPERT_WINDOW = SQLTask(
    id="task_expert_window",
    level="expert",
    title="Fix Broken Window Functions: Running Total and Revenue Rank",
    description="""\
TASK: The query below computes a cumulative running total and a within-region
revenue rank for each quarter, but the results are wrong. Debug and fix it.
SCHEMA:
quarterly_sales(region VARCHAR, quarter INTEGER, revenue DECIMAL)
DATA:
East: Q1=15000, Q2=18000, Q3=12000, Q4=20000
West: Q1=11000, Q2=14000, Q3=16000, Q4=16000 (note: Q3 and Q4 are tied)
BROKEN QUERY:
SELECT region, quarter, revenue,
SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,
RANK() OVER (ORDER BY revenue DESC) AS revenue_rank
FROM quarterly_sales
ORDER BY region, quarter
PROBLEM:
The query returns wrong values for both running_total and revenue_rank.
Compare your output against the expected results carefully.
GOAL: running_total should be a cumulative sum per region (reset each region,
ordered by quarter). revenue_rank should rank revenue within each region
(ordered by revenue DESC), handling ties correctly (tied values must get
the same rank).
Final output: ORDER BY region ASC, quarter ASC.""",
    schema_ddl="""\
CREATE TABLE quarterly_sales (region VARCHAR, quarter INTEGER, revenue DECIMAL);
INSERT INTO quarterly_sales VALUES
('East', 1, 15000),
('East', 2, 18000),
('East', 3, 12000),
('East', 4, 20000),
('West', 1, 11000),
('West', 2, 14000),
('West', 3, 16000),
('West', 4, 16000);
""",
    # Query text assembled line by line; the joined string has no trailing newline.
    broken_query="\n".join(
        [
            "SELECT region, quarter, revenue,",
            "SUM(revenue) OVER (ORDER BY region, quarter) AS running_total,",
            "RANK() OVER (ORDER BY revenue DESC) AS revenue_rank",
            "FROM quarterly_sales",
            "ORDER BY region, quarter",
        ]
    ),
    error_message=" ".join(
        [
            "Query runs but both computed columns are wrong.",
            "running_total does not reset per region.",
            "revenue_rank is a global ranking across all rows instead of per-region.",
        ]
    ),
    hint="Multiple issues exist. Think about partitioning and how tied values should be ranked.",
    test_cases=[
        TestCase(
            description="Per-region running total and within-region revenue rank with ties",
            # Each tuple is zipped against the column names, in column order.
            expected_rows=[
                dict(
                    zip(
                        ("region", "quarter", "revenue", "running_total", "revenue_rank"),
                        values,
                    )
                )
                for values in (
                    ("East", 1, 15000.0, 15000.0, 3),
                    ("East", 2, 18000.0, 33000.0, 2),
                    ("East", 3, 12000.0, 45000.0, 4),
                    ("East", 4, 20000.0, 65000.0, 1),
                    ("West", 1, 11000.0, 11000.0, 4),
                    ("West", 2, 14000.0, 25000.0, 3),
                    ("West", 3, 16000.0, 41000.0, 1),
                    ("West", 4, 16000.0, 57000.0, 1),
                )
            ],
            order_by="region,quarter",
        )
    ],
    solution_query="\n".join(
        [
            "SELECT region, quarter, revenue,",
            "SUM(revenue) OVER (PARTITION BY region ORDER BY quarter) AS running_total,",
            "RANK() OVER (PARTITION BY region ORDER BY revenue DESC) AS revenue_rank",
            "FROM quarterly_sales",
            "ORDER BY region, quarter",
        ]
    ),
    max_steps=6,
)
# ── Task Registry ─────────────────────────────────────────────────────────────
class TaskRegistry:
    """
    Lock-guarded registry of SQL tasks, shared across all environment sessions.

    Tasks whose IDs are listed in ``_BUILTIN_IDS`` (easy / medium / hard and
    the three expert tasks) cannot be removed through unregister().
    Custom tasks can be added via register(), load_from_json(), or POST /tasks.
    """

    _BUILTIN_IDS: frozenset = frozenset([
        "task_easy_syntax", "task_medium_join", "task_hard_cte",
        "task_expert_rank", "task_expert_recursive", "task_expert_window",
    ])

    def __init__(self, initial_tasks: List["SQLTask"]) -> None:
        """Create a registry seeded with *initial_tasks*, keyed by task ID."""
        self._lock = Lock()
        # Insertion-ordered dict preserves cycling order
        self._tasks: Dict[str, "SQLTask"] = {t.id: t for t in initial_tasks}
        self._cycle_index: int = 0

    # ── CRUD ─────────────────────────────────────────────────────────────────

    def register(self, task: "SQLTask") -> None:
        """Add or replace a task. Replaces silently if the ID already exists."""
        with self._lock:
            self._tasks[task.id] = task

    def unregister(self, task_id: str) -> None:
        """
        Remove a custom task.

        Raises ValueError for built-in tasks, KeyError if not found.
        """
        # _BUILTIN_IDS is an immutable class constant, so no lock is needed here.
        if task_id in self._BUILTIN_IDS:
            raise ValueError(f"Built-in task '{task_id}' cannot be removed.")
        with self._lock:
            if task_id not in self._tasks:
                raise KeyError(task_id)
            del self._tasks[task_id]

    def get(self, task_id: str) -> "SQLTask":
        """Return a task by ID. Raises KeyError with available IDs if not found."""
        with self._lock:
            if task_id not in self._tasks:
                available = ", ".join(self._tasks.keys())
                raise KeyError(
                    f"Task '{task_id}' not found. "
                    f"Available: {available}"
                )
            return self._tasks[task_id]

    def list_all(self) -> List["SQLTask"]:
        """Return all registered tasks in insertion order."""
        with self._lock:
            return list(self._tasks.values())

    def ids(self) -> List[str]:
        """Return all task IDs in insertion order."""
        with self._lock:
            return list(self._tasks.keys())

    # ── Cycling ──────────────────────────────────────────────────────────────

    def cycle_next(self) -> "SQLTask":
        """
        Return the next task in round-robin order (wraps at end).

        Raises LookupError if the registry holds no tasks.
        """
        with self._lock:
            if not self._tasks:
                # Without this guard the modulo below divides by zero.
                raise LookupError("Task registry is empty; no task to cycle to.")
            tasks = list(self._tasks.values())
            task = tasks[self._cycle_index % len(tasks)]
            self._cycle_index += 1
            return task

    # ── Bulk loading ─────────────────────────────────────────────────────────

    def load_from_json(self, path: str) -> int:
        """
        Load tasks from a JSON file (list of task spec objects).

        A single top-level object is accepted and treated as a one-item list.
        Every entry is converted before any is registered, so a malformed
        entry leaves the registry unchanged.

        Returns the number of tasks loaded.
        Raises ValueError if the top-level JSON value is neither a list nor an
        object; errors from JSON parsing or task conversion propagate.

        Minimal required fields per task:
            id, schema_ddl, expected_rows

        Example file::

            [
              {
                "id": "my_null_task",
                "level": "medium",
                "title": "Handle NULLs in aggregation",
                "schema_ddl": "CREATE TABLE ...; INSERT ...",
                "broken_query": "SELECT AVG(score) FROM ...",
                "expected_rows": [{"avg_score": 72.5}],
                "hint": "Use COALESCE to handle NULL scores."
              }
            ]
        """
        # Read as UTF-8 explicitly rather than the locale's default encoding.
        raw = json.loads(Path(path).read_text(encoding="utf-8"))
        if isinstance(raw, dict):
            raw = [raw]
        if not isinstance(raw, list):
            raise ValueError(
                "Task file must contain a JSON list of task objects "
                f"or a single task object, got {type(raw).__name__}."
            )
        tasks = [task_from_dict(item) for item in raw]
        for task in tasks:
            self.register(task)
        return len(tasks)

    # ── Helpers ──────────────────────────────────────────────────────────────

    def __len__(self) -> int:
        """Return the number of registered tasks."""
        with self._lock:
            return len(self._tasks)

    def __contains__(self, task_id: str) -> bool:
        """Return True if a task with *task_id* is registered."""
        with self._lock:
            return task_id in self._tasks
# ── Conversion helper ─────────────────────────────────────────────────────────
def task_from_dict(d: Dict[str, Any]) -> "SQLTask":
    """
    Construct an SQLTask from a plain dict (JSON payload or loaded file).

    The resulting task has exactly one TestCase, built from ``expected_rows``,
    ``order_by`` and ``test_description``.

    Required keys : id, schema_ddl, expected_rows
    Optional keys : level, title, description, broken_query, error_message,
                    hint, order_by, solution_query, test_description, max_steps

    Raises KeyError naming every required key that is absent.
    """
    # Check all required keys up front so the caller sees the full list of
    # what is missing instead of only the first absent key.
    missing = [key for key in ("id", "schema_ddl", "expected_rows") if key not in d]
    if missing:
        raise KeyError(
            f"Task spec is missing required key(s): {', '.join(missing)}"
        )

    task_id = d["id"]
    return SQLTask(
        id=task_id,
        level=d.get("level", "custom"),
        title=d.get("title", task_id),  # title falls back to the task ID
        description=d.get("description", ""),
        schema_ddl=d["schema_ddl"],
        broken_query=d.get("broken_query", ""),
        error_message=d.get("error_message", ""),
        hint=d.get("hint", ""),
        test_cases=[
            TestCase(
                description=d.get("test_description", "Custom test case"),
                expected_rows=d["expected_rows"],
                order_by=d.get("order_by"),
            )
        ],
        solution_query=d.get("solution_query", ""),
        max_steps=d.get("max_steps", 5),
    )
# ── Global singleton ──────────────────────────────────────────────────────────
# Backwards-compat: snapshot of all built-in tasks at import time
TASKS: List[SQLTask] = [
    _TASK_EASY,
    _TASK_MEDIUM,
    _TASK_HARD,
    _TASK_EXPERT_RANK,
    _TASK_EXPERT_RECURSIVE,
    _TASK_EXPERT_WINDOW,
]
TASK_BY_ID: Dict[str, SQLTask] = dict((task.id, task) for task in TASKS)

# Shared registry instance, seeded with the same built-in tasks in the same
# order. It receives its own copy of the list, so later changes to TASKS do
# not reach it.
REGISTRY = TaskRegistry(list(TASKS))