# Qingpeng Kong
# Update public UI for Space
# 5d89bc6
# loader/data.py
import csv
import logging
from pathlib import Path
from typing import List, Dict
# -------------------------------------------------------------------
# 1. Mapping from dataset key → CSV filename
# -------------------------------------------------------------------
# Internal dataset key -> CSV filename. Insertion order is the load order.
_DATASET_FILES: Dict[str, str] = {
    "bar_exam": "BarExam_qa.csv",
    "causal_judgment": "bbh_causal_judgement.csv",
    "snarks": "bbh_snarks.csv",
    "bbq_disamb": "BBQ_disamb.csv",
    "cnn_dailymail": "CNN_dailymail.csv",
    "drop": "drop.csv",
    "esnli": "eSNLI.csv",
    "fever": "fever.csv",
    "hotpot_qa": "hotpot_qa.csv",
    "medical_qa": "medical_qa.csv",
}
# -------------------------------------------------------------------
# 1b. Human-readable display names for UI
# -------------------------------------------------------------------
# Internal dataset key -> label shown in the UI.
_DATASET_DISPLAY_NAMES: Dict[str, str] = {
    "bar_exam": "Bar Exam Questions",
    "causal_judgment": "Causal Judgment",
    "snarks": "Snarks",
    "bbq_disamb": "BBQ Disambiguation",
    "cnn_dailymail": "CNN / DailyMail Summaries",
    "drop": "DROP Reading Comprehension",
    "esnli": "e-SNLI Natural Language Inference",
    "fever": "FEVER Fact Checking",
    "hotpot_qa": "HotpotQA Multi-hop Questions",
    "medical_qa": "Medical Questions",
}
# -------------------------------------------------------------------
# 2. Where the CSVs live (loader/../datasets/)
# -------------------------------------------------------------------
def _datasets_dir() -> Path:
    """Return the absolute path of the ``datasets`` folder one level above this file's directory."""
    this_dir = Path(__file__).resolve().parent
    sibling = this_dir.parent / "datasets"
    return sibling.resolve()
# -------------------------------------------------------------------
# 3. Pick the first non-empty column among several candidates
# -------------------------------------------------------------------
def _pick_first_nonempty(raw: Dict[str, str], candidates: List[str]) -> str:
    """Return the value of the first candidate column that is present and not blank.

    Columns are tried in the order given. A value counts as blank when it is
    ``None`` or contains only whitespace. The winning value is returned as a
    string without being stripped; ``""`` is returned when nothing qualifies.
    """
    for column in candidates:
        value = raw.get(column)
        if value is None:
            continue
        text = str(value)
        if text.strip():
            return text
    return ""
# -------------------------------------------------------------------
# 4. Load a single CSV file and normalize it to our schema
# -------------------------------------------------------------------
def _load_one_dataset(
    name: str, filename: str, max_rows: int = 10
) -> List[Dict[str, str]]:
    """
    Read a CSV file and convert its leading rows to our standard format:
    {
        "id": "example_1",
        "context": "...",
        "prompt": "...",
        "answer": "..."   # optional, only set when a non-blank answer exists
    }

    Args:
        name: Dataset key; used only to label log messages.
        filename: CSV filename inside the directory returned by _datasets_dir().
        max_rows: Maximum number of leading data rows to keep (default 10).

    Returns:
        At most ``max_rows`` normalized examples. An empty list is returned
        when the file is missing or any error occurs while reading it; the
        failure is logged instead of raised.
    """
    log = logging.getLogger(__name__)
    path = _datasets_dir() / filename

    # Candidate column names for each normalized field, in priority order.
    context_columns = [
        "Context", "context",
        "passage", "article", "story", "premise",
        "paragraph", "document", "sentence1", "sent1", "background",
    ]
    prompt_columns = [
        "Prompt", "prompt",
        "question", "input", "query",
        "sentence2", "sent2", "hypothesis",
        "qa_question", "title",
    ]
    answer_columns = [
        "Answer", "answer",
        "target", "gold", "label", "output", "reference",
        "highlights",
    ]

    rows: List[Dict[str, str]] = []
    try:
        # errors="replace" avoids Unicode crashes for imperfect CSVs
        with path.open("r", encoding="utf-8", errors="replace", newline="") as f:
            reader = csv.DictReader(f)
            # zip() pulls from the range first and stops as soon as it is
            # exhausted, so no row beyond max_rows is ever read or parsed.
            for i, raw in zip(range(1, max_rows + 1), reader):
                ex_id = (
                    raw.get("id")
                    or raw.get("example_id")
                    or raw.get("uid")
                    or f"example_{i}"
                )
                ex = {
                    "id": str(ex_id),
                    "context": _pick_first_nonempty(raw, context_columns),
                    "prompt": _pick_first_nonempty(raw, prompt_columns),
                }
                answer = _pick_first_nonempty(raw, answer_columns)
                if answer:
                    ex["answer"] = answer
                rows.append(ex)
    except FileNotFoundError:
        log.warning("Dataset %r: CSV file not found at %s", name, path)
        return []
    except Exception:
        # Keep import resilient in constrained environments (e.g., Spaces),
        # but leave a trace so an unexpectedly empty dataset can be diagnosed.
        log.warning("Dataset %r: failed to load %s", name, path, exc_info=True)
        return []
    return rows
# -------------------------------------------------------------------
# 5. Load all datasets ONCE when the module is imported
# -------------------------------------------------------------------
def _load_all_datasets() -> Dict[str, List[Dict[str, str]]]:
    """Load every CSV listed in _DATASET_FILES and return the examples keyed by dataset name."""
    loaded: Dict[str, List[Dict[str, str]]] = {}
    for dataset_key, csv_name in _DATASET_FILES.items():
        loaded[dataset_key] = _load_one_dataset(dataset_key, csv_name)
    return loaded
# Module-level cache: dataset key -> list of normalized examples.
# Built once, as a side effect of importing this module.
_DATA: Dict[str, List[Dict[str, str]]] = _load_all_datasets() # ← cached
# -------------------------------------------------------------------
# 6. Public Functions — these are used by the app
# -------------------------------------------------------------------
def list_datasets() -> List[str]:
    """Return every dataset key in alphabetical order."""
    keys = list(_DATA)
    keys.sort()
    return keys
def get_dataset_display_name(dataset: str) -> str:
    """Return the UI label for a dataset key, or the key itself when it has no label."""
    try:
        return _DATASET_DISPLAY_NAMES[dataset]
    except KeyError:
        return dataset
def get_dataset_key_from_display_name(display_name: str) -> str:
    """Map a UI display name back to its internal dataset key.

    When no label matches, ``display_name`` is returned unchanged on the
    assumption that the caller already passed a key.
    """
    matching_keys = (
        key
        for key, label in _DATASET_DISPLAY_NAMES.items()
        if label == display_name
    )
    return next(matching_keys, display_name)
def list_datasets_with_display_names() -> List[tuple[str, str]]:
    """Return (key, display_name) tuples for all loaded datasets, ordered by display name."""
    pairs = []
    for key in _DATA:
        label = _DATASET_DISPLAY_NAMES.get(key, key)
        pairs.append((key, label))
    pairs.sort(key=lambda pair: pair[1])
    return pairs
def list_dataset_display_names() -> List[str]:
    """Return the display name of every loaded dataset, sorted alphabetically.

    A dataset without a display name is listed under its key.
    """
    return sorted(_DATASET_DISPLAY_NAMES.get(key, key) for key in _DATA)
def get_examples(dataset: str, n: int = 10) -> List[Dict[str, str]]:
    """Return up to n examples for a dataset.

    Args:
        dataset: Internal dataset key.
        n: Maximum number of examples to return. Zero or a negative value
            yields an empty list; ``None`` returns every cached example.

    Returns:
        A new list holding the first ``n`` cached examples.

    Raises:
        KeyError: If ``dataset`` is not a loaded dataset key.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    # A negative slice bound would drop rows from the END of the list rather
    # than return none, so clamp it to zero.
    if n is not None and n < 0:
        n = 0
    return _DATA[dataset][:n]
def get_example_by_id(dataset: str, ex_id: str) -> Dict[str, str]:
    """Return the first example in ``dataset`` whose "id" equals ``ex_id``.

    Raises:
        KeyError: If the dataset is unknown, or no example has that id.
    """
    if dataset not in _DATA:
        raise KeyError(f"Unknown dataset: {dataset}")
    found = next(
        (example for example in _DATA[dataset] if example["id"] == ex_id),
        None,
    )
    if found is None:
        raise KeyError(f"Example id '{ex_id}' not found in dataset '{dataset}'")
    return found