File size: 4,645 Bytes
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
039839b
a8a3c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Data models for the QueryForge SQL environment.

SQLAction      — the agent's submitted SQL query.
SQLObservation — task description + grading feedback returned after reset() and each step.
TaskSpec       — payload for registering a custom task via POST /tasks.
"""

from typing import Any, Dict, List, Optional

from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field


class SQLAction(Action):
    """Agent action: a single SQL query handed to the environment for grading."""

    # Required field (default=... marks it mandatory); holds the raw query text.
    sql: str = Field(
        default=...,
        description="The SQL query to submit for grading",
    )


class SQLObservation(Observation):
    """Observation returned after reset() or step().

    Fields are grouped into three sections:

    * task context      -- the active task and its full description;
    * grading signals   -- outcome of the most recently submitted query;
    * episode progress  -- number of submissions and best score so far.

    Every field has a default, so an observation may be built with only the
    subset of fields the caller has available.
    """

    # ── Task context ─────────────────────────────────────────────────────────
    task_id: str = Field(default="", description="Active task identifier")
    # "custom" is listed because TaskSpec.level defaults to it for
    # user-registered tasks, so it can legitimately appear here.
    task_level: str = Field(
        default="", description="Difficulty: easy | medium | hard | expert | custom"
    )
    task_title: str = Field(default="", description="Human-readable task title")
    task_description: str = Field(
        default="",
        description=(
            "Full task description: schema, broken query, error message, and goal"
        ),
    )

    # ── Per-step grading signals ──────────────────────────────────────────────
    syntax_valid: bool = Field(
        default=False, description="True if the submitted query parsed without error"
    )
    execution_success: bool = Field(
        default=False, description="True if the query ran to completion in DuckDB"
    )
    # None (rather than "") signals that no runtime error occurred.
    execution_error: Optional[str] = Field(
        default=None, description="Runtime error message, if any"
    )
    rows_returned: int = Field(
        default=0, description="Number of rows the query returned"
    )
    feedback: str = Field(
        default="",
        description="Detailed grading feedback from DuckDB + AI judge",
    )
    hint: str = Field(
        default="", description="Actionable hint for the next attempt"
    )

    # ── Episode progress ──────────────────────────────────────────────────────
    attempt: int = Field(
        default=0, description="Number of queries submitted this episode"
    )
    best_score: float = Field(
        default=0.0, description="Highest score achieved so far this episode"
    )


class TaskSpec(BaseModel):
    """
    Payload for registering a custom SQL task via POST /tasks
    or directly via REGISTRY.register(task_from_dict(spec.model_dump())).

    Required: id, title, schema_ddl, expected_rows
    Everything else has sensible defaults.
    """

    # Field name mirrors the wire format; it shadows the builtin only inside
    # the class body, which is harmless here.
    id: str = Field(
        ..., description="Unique task identifier, e.g. 'null_handling_task'"
    )
    # Free-form label (not validated); the listed values are the conventional
    # ones, with "custom" as the default for user-registered tasks.
    level: str = Field(
        default="custom",
        description="Difficulty label: easy | medium | hard | expert | custom",
    )
    title: str = Field(..., description="Human-readable task title")
    description: str = Field(
        default="",
        description="Full task description shown to the agent (schema, goal, etc.)",
    )
    schema_ddl: str = Field(
        ...,
        description="CREATE TABLE + INSERT statements to seed the DuckDB test DB",
    )
    broken_query: str = Field(
        default="",
        description="The broken or slow query the agent must fix",
    )
    error_message: str = Field(
        default="",
        description="Error or performance warning shown to the agent alongside the task",
    )
    hint: str = Field(
        default="",
        description="Actionable hint surfaced in the observation after each wrong attempt",
    )
    # Ground truth for row-match scoring; each dict is one expected row.
    expected_rows: List[Dict[str, Any]] = Field(
        ...,
        description=(
            "Exact rows the correct query must return. "
            "Used for deterministic row-match scoring."
        ),
    )
    # None means no sort key is specified for the comparison.
    order_by: Optional[str] = Field(
        default=None,
        description="Comma-separated column names used to sort rows before comparison",
    )
    solution_query: str = Field(
        default="",
        description="Reference solution shown to the AI judge for quality scoring",
    )
    test_description: str = Field(
        default="Custom test case",
        description="One-line description of what the test case checks",
    )
    # Validated to the inclusive range [1, 20].
    max_steps: int = Field(
        default=5, ge=1, le=20,
        description="Maximum number of step() calls allowed per episode",
    )