Spaces:
Sleeping
Sleeping
File size: 7,658 Bytes
5523185 5b324a8 5523185 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | """
Evaluation / Benchmark Script — DataCleanEnv
=============================================
Run any model across all tasks and report scores.
Usage:
# Start the environment server first:
uvicorn server.app:app --host 0.0.0.0 --port 8000
# Evaluate with default settings:
python eval.py
# Evaluate a specific model:
python eval.py --model "meta-llama/Llama-3.1-8B-Instruct"
# Evaluate with seed variation (multiple runs):
python eval.py --seeds 5 --tasks customer_contacts sales_records
# JSON output for CI/programmatic use:
python eval.py --json
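Environment variables:
API_BASE_URL LLM API endpoint (default: https://router.huggingface.co/v1)
MODEL_NAME Model identifier (default for --model)
API_KEY API key; takes precedence over HF_TOKEN
HF_TOKEN API key, used when API_KEY is unset
ENV_URL Environment server URL (default: http://localhost:8000; overridden by --env-url)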
"""
import argparse
import json
import os
import re
import statistics
import sys
import textwrap
from typing import Any, Dict, List, Optional
import requests
from openai import OpenAI
# ---------------------------------------------------------------------------
# Config — each value below can be overridden through the environment.
# ---------------------------------------------------------------------------
# OpenAI-compatible endpoint used for the planning call.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# A non-empty API_KEY wins; otherwise HF_TOKEN, otherwise the empty string.
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "")
# Empty means "let --model fall back to its built-in default".
MODEL_NAME = os.getenv("MODEL_NAME", "")
# Base URL of the environment server; main() replaces it with --env-url.
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")

# Default value of --tasks.
ALL_TASKS = [
    "customer_contacts",
    "sales_records",
    "employee_records",
    "financial_transactions",
]

# System prompt for the single planning call: the model must answer with a
# bare JSON array of fix/delete actions.
PLANNING_PROMPT = textwrap.dedent("""\
    You are an expert data quality analyst. Analyze the dataset and produce
    a COMPLETE fix plan as a JSON array. Output ONLY the JSON array.
    Format: [{"action": "fix", "row": N, "column": "col", "value": "val"}, {"action": "delete", "row": N}, ...]
    Rules: Emails must be user@domain.tld. Dates must be YYYY-MM-DD. Numbers must be positive.
    Use exact canonical forms from the task description. Delete duplicates (highest index first).
    List fixes first, then deletes. Only fix cells with actual issues.
""")
def env_reset(task_id: str, seed: Optional[int] = None) -> Dict[str, Any]:
payload: Dict[str, Any] = {"task_id": task_id}
if seed is not None:
payload["seed"] = seed
resp = requests.post(f"{ENV_URL}/reset", json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
return data.get("observation", data)
def env_step(command: str) -> Dict[str, Any]:
resp = requests.post(f"{ENV_URL}/step", json={"action": {"command": command}}, timeout=30)
resp.raise_for_status()
data = resp.json()
return data.get("observation", data)
def extract_json_plan(text: str) -> Optional[List[Dict]]:
    """Recover a JSON array of plan actions from a raw model reply.

    A leading ```` ``` ```` / ```` ```json ```` fence and a trailing ```` ``` ````
    are removed first. The remaining text is parsed as JSON. If that fails or
    is not a list, the widest ``[...]`` span in it is tried instead.

    Returns:
        The parsed list (possibly empty), or None when neither attempt yields
        a JSON list.
    """
    def as_list(fragment: str) -> Optional[List[Dict]]:
        # Parse *fragment*; keep the result only if it is a JSON array.
        try:
            parsed = json.loads(fragment)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, list) else None

    body = re.sub(r"^```(?:json)?\s*\n?", "", text.strip())
    body = re.sub(r"\n?```\s*$", "", body.strip())

    whole = as_list(body)
    if whole is not None:
        return whole

    # Fall back to the first '[' .. last ']' span, e.g. when the model wraps
    # the array in explanatory prose.
    span = re.search(r"\[[\s\S]*\]", body)
    return as_list(span.group()) if span else None
def _column_names(column_info: str) -> List[str]:
    """Return the non-empty names from ``column_info`` lines shaped like ``name: ...``."""
    heads = (line.split(":", 1)[0].strip() for line in column_info.splitlines() if ":" in line)
    return [name for name in heads if name]


def _score_of(obs: Dict[str, Any]) -> float:
    """Return the observation's ``current_score``, treating a missing or null value as 0.0."""
    score = obs.get("current_score")
    return 0.0 if score is None else float(score)


def _build_context(latest: Dict[str, Any], initial: Dict[str, Any], inspections: List[str]) -> str:
    """Assemble the planner's user message.

    Task description, column info and data preview come from ``latest``. When
    ``latest`` leaves one of them empty, the reset observation ``initial`` is
    used instead; step observations may omit these fields, which is unverified.
    The step and issue counters are always taken from ``latest``.
    """
    def field(key: str) -> Any:
        return latest.get(key) or initial.get(key, "")

    joined = "\n\n".join(inspections)
    return (
        f"Task: {field('task_description')}\n"
        f"Columns:\n{field('column_info')}\n"
        f"Data:\n{field('data_preview')}\n\n"
        f"Inspections:\n{joined}\n\n"
        f"Remaining steps: {latest.get('actions_remaining', 0)}. Issues: {latest.get('total_issues', 0)}."
    )


def _request_plan(client: "OpenAI", model: str, context: str) -> Optional[List[Dict]]:
    """Ask ``model`` once for a fix plan.

    Returns the parsed plan, or None on any API failure or unparseable reply.
    Failures are reported on stderr instead of being silently swallowed. Stderr
    keeps ``--json`` stdout clean.
    """
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": PLANNING_PROMPT},
                {"role": "user", "content": context},
            ],
            temperature=0.0,
            max_tokens=2000,
        )
        reply = completion.choices[0].message.content or ""
    except Exception as exc:  # best effort: a failed call scores the untouched dataset
        print(f"[warn] planning call to {model} failed: {exc!r}", file=sys.stderr)
        return None
    plan = extract_json_plan(reply)
    if plan is None:
        print(f"[warn] no JSON plan found in reply from {model}", file=sys.stderr)
    return plan


def _action_to_command(action: Any) -> Optional[str]:
    """Translate one plan entry into an environment command.

    Returns None for entries that should be skipped: non-dicts, unknown
    ``action`` types, or entries missing a key their type needs. Skipping them
    keeps a malformed item from aborting the whole run.
    """
    if not isinstance(action, dict):
        return None
    kind = action.get("action", "")
    try:
        if kind == "fix":
            return f'fix({action["row"]}, "{action["column"]}", "{action["value"]}")'
        if kind == "delete":
            return f'delete({action["row"]})'
    except KeyError:
        return None
    return None


def run_task(client: OpenAI, model: str, task_id: str, seed: Optional[int] = None) -> float:
    """Run a single task and return the score.

    The episode runs in three phases:

    1. ``inspect`` every column listed in the reset observation's ``column_info``.
    2. Ask the model once for a complete JSON fix plan.
    3. Replay the plan, then ``submit()``.

    Whenever an observation reports ``done``, the score seen so far is
    returned immediately.

    Args:
        client: OpenAI-compatible client used for the planning call.
        model: Model identifier passed to ``chat.completions.create``.
        task_id: Environment task to reset into.
        seed: Optional seed forwarded to ``/reset``.

    Returns:
        The final observation's ``current_score`` (0.0 if missing or null).

    Raises:
        requests.HTTPError: if the environment server rejects a request.
    """
    initial = env_reset(task_id, seed=seed)
    if initial.get("done", False):
        return _score_of(initial)
    obs = initial

    # Phase 1: inspect every column so the planner sees per-column findings.
    inspections = []
    for col in _column_names(initial.get("column_info") or ""):
        obs = env_step(f'inspect("{col}")')
        if obs.get("done", False):
            return _score_of(obs)
        inspections.append(f"[{col}]: {obs.get('feedback', '')}")

    # Phase 2: a single planning call (None on failure -> go straight to submit).
    plan = _request_plan(client, model, _build_context(obs, initial, inspections))

    # Phase 3: replay the plan. Stop while one step is left, presumably so
    # submit() still fits in the budget.
    for action in plan or []:
        remaining = obs.get("actions_remaining")
        # A missing counter means "unknown", not "exhausted".
        if obs.get("done", False) or (remaining is not None and remaining <= 1):
            break
        command = _action_to_command(action)
        if command is not None:
            obs = env_step(command)

    if not obs.get("done", False):
        obs = env_step("submit()")
    return _score_of(obs)
def _summary(scores: List[float]) -> Dict[str, Any]:
    """Return raw scores, mean, and sample stdev (0.0 when there is a single score)."""
    return {
        "scores": scores,
        "mean": statistics.mean(scores),
        "stdev": statistics.stdev(scores) if len(scores) > 1 else 0.0,
    }


def main():
    """CLI entry point: benchmark a model on the selected tasks and print a report.

    Prints a human-readable table, or a single JSON document with ``--json``.
    """
    # Must precede every use of ENV_URL in this function. Reading the name
    # before its ``global`` declaration is a compile-time SyntaxError.
    global ENV_URL

    parser = argparse.ArgumentParser(description="Benchmark models on DataCleanEnv")
    parser.add_argument("--model", default=MODEL_NAME or "meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument("--tasks", nargs="*", default=ALL_TASKS)
    parser.add_argument("--seeds", type=int, default=1, help="Number of seeds per task (1 = no seed)")
    parser.add_argument("--env-url", default=ENV_URL)
    parser.add_argument("--json", action="store_true", help="Output JSON")
    args = parser.parse_args()
    if not args.tasks:
        # `--tasks` with no values yields []; statistics.mean() would then crash.
        parser.error("--tasks requires at least one task id")

    ENV_URL = args.env_url
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # --seeds <= 1 means one unseeded run per task; otherwise seeds 1..N.
    seeds = [None] if args.seeds <= 1 else list(range(1, args.seeds + 1))

    results: Dict[str, List[float]] = {}
    for task_id in args.tasks:
        scores = []
        for seed in seeds:
            if not args.json:
                seed_str = f" (seed={seed})" if seed is not None else ""
                print(f" Running {task_id}{seed_str}...", end=" ", flush=True)
            score = run_task(client, args.model, task_id, seed=seed)
            scores.append(score)
            if not args.json:
                print(f"{score:.4f}")
        results[task_id] = scores

    all_scores = [s for scores in results.values() for s in scores]

    if args.json:
        report = {
            "model": args.model,
            "env_url": args.env_url,
            "results": {task: _summary(scores) for task, scores in results.items()},
            "average": statistics.mean(all_scores),
        }
        print(json.dumps(report, indent=2))
        return

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"BENCHMARK RESULTS — {args.model}")
    print(rule)
    for task_id, scores in results.items():
        stats = _summary(scores)
        if len(scores) > 1:
            print(f" {task_id:30s} {stats['mean']:.4f} ± {stats['stdev']:.4f} (n={len(scores)})")
        else:
            print(f" {task_id:30s} {stats['mean']:.4f}")
    print(f" {'AVERAGE':30s} {statistics.mean(all_scores):.4f}")
if __name__ == "__main__":
    # Run the benchmark CLI only when executed as a script, not on import.
    main()
|