Spaces:
Sleeping
Sleeping
File size: 4,645 Bytes
a8a3c90 039839b a8a3c90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 | """
Data models for the QueryForge SQL environment.
SQLAction — the agent's submitted SQL query.
SQLObservation — task description + grading feedback returned after each step.
TaskSpec — payload for registering a custom task via POST /tasks.
"""
from typing import Any, Dict, List, Optional
from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field
class SQLAction(Action):
    """Agent action that carries one SQL query to be evaluated."""

    # Required: there is no default, so an action without SQL is rejected.
    sql: str = Field(
        ...,
        description="The SQL query to submit for grading",
    )
class SQLObservation(Observation):
    """
    Observation handed back to the agent after reset() or step().

    Fields fall into three groups: the context of the active task, the
    grading signals for the most recent submission, and running progress
    for the current episode. Every field declared here has a default.
    """

    # -- Task context ------------------------------------------------------
    task_id: str = Field(
        default="",
        description="Active task identifier",
    )
    task_level: str = Field(
        default="",
        description="Difficulty: easy | medium | hard | expert",
    )
    task_title: str = Field(
        default="",
        description="Human-readable task title",
    )
    task_description: str = Field(
        default="",
        description="Full task description: schema, broken query, error message, and goal",
    )

    # -- Per-step grading signals ------------------------------------------
    syntax_valid: bool = Field(
        default=False,
        description="True if the submitted query parsed without error",
    )
    execution_success: bool = Field(
        default=False,
        description="True if the query ran to completion in DuckDB",
    )
    # None means no runtime error message was recorded.
    execution_error: Optional[str] = Field(
        default=None,
        description="Runtime error message, if any",
    )
    rows_returned: int = Field(
        default=0,
        description="Number of rows the query returned",
    )
    feedback: str = Field(
        default="",
        description="Detailed grading feedback from DuckDB + AI judge",
    )
    hint: str = Field(
        default="",
        description="Actionable hint for the next attempt",
    )

    # -- Episode progress --------------------------------------------------
    attempt: int = Field(
        default=0,
        description="Number of queries submitted this episode",
    )
    best_score: float = Field(
        default=0.0,
        description="Highest score achieved so far this episode",
    )
class TaskSpec(BaseModel):
    """
    Request payload describing a custom SQL task.

    Intended for registration via POST /tasks, or directly via
    REGISTRY.register(task_from_dict(spec.model_dump())).

    Required fields (declared without a default): id, title, schema_ddl,
    expected_rows. All remaining fields fall back to the defaults below.
    """

    # -- Identity and presentation -----------------------------------------
    id: str = Field(
        ...,
        description="Unique task identifier, e.g. 'null_handling_task'",
    )
    level: str = Field(
        default="custom",
        description="Difficulty label: easy | medium | hard | custom",
    )
    title: str = Field(
        ...,
        description="Human-readable task title",
    )
    description: str = Field(
        default="",
        description="Full task description shown to the agent (schema, goal, etc.)",
    )

    # -- Database fixture and the problem shown to the agent ---------------
    schema_ddl: str = Field(
        ...,
        description="CREATE TABLE + INSERT statements to seed the DuckDB test DB",
    )
    broken_query: str = Field(
        default="",
        description="The broken or slow query the agent must fix",
    )
    error_message: str = Field(
        default="",
        description="Error or performance warning shown to the agent alongside the task",
    )
    hint: str = Field(
        default="",
        description="Actionable hint surfaced in the observation after each wrong attempt",
    )

    # -- Grading inputs ----------------------------------------------------
    # Each row is a mapping of column name to value.
    expected_rows: List[Dict[str, Any]] = Field(
        ...,
        description=(
            "Exact rows the correct query must return. "
            "Used for deterministic row-match scoring."
        ),
    )
    # None leaves the sort columns unspecified.
    order_by: Optional[str] = Field(
        default=None,
        description="Comma-separated column names used to sort rows before comparison",
    )
    solution_query: str = Field(
        default="",
        description="Reference solution shown to the AI judge for quality scoring",
    )
    test_description: str = Field(
        default="Custom test case",
        description="One-line description of what the test case checks",
    )

    # -- Episode limits ----------------------------------------------------
    # Validated to lie in the inclusive range 1..20.
    max_steps: int = Field(
        default=5,
        ge=1,
        le=20,
        description="Maximum number of step() calls allowed per episode",
    )
|