wenjiao's picture
refactor repo code
b15b21e
import json
import logging
import os
import re
from datetime import datetime, timezone
import pandas as pd
from src.display.formatting import make_clickable_model
from src.display.utils import eval_queue_cols
from src.leaderboard.read_auto_pipeline_results import get_auto_pipeline_results_df as _get_auto_pipeline_results_df
from src.queue_eta import format_eta
logger = logging.getLogger(__name__)
def _normalize_queue_entry(data: dict) -> dict:
"""Normalize queue JSON keys to match EvalQueueColumn/QuantQueueColumn field names.
Handles mismatches between:
- JSON 'weight_dtype' → column 'weight_type'
- Quant entries with 'quant_precision' but no 'precision'
"""
if "weight_type" not in data and "weight_dtype" in data:
data["weight_type"] = data["weight_dtype"]
if "precision" not in data and "quant_precision" in data:
data["precision"] = data["quant_precision"]
return data
# ── Status inference from results (for quant/eval requests whose status was never written back) ──
def _normalize_scheme(scheme: str) -> str:
"""Normalize 'INT4 (W4A16)' and 'W4A16' to the same canonical key."""
m = re.search(r"\(([^)]+)\)", scheme)
return (m.group(1) if m else scheme).strip().upper()
def _has_zero_accuracy_task(accuracy: dict) -> bool:
"""Return True if any task in accuracy has acc == 0 (evaluation failure)."""
tasks = accuracy.get("tasks")
if not isinstance(tasks, dict):
return False
for task_val in tasks.values():
acc_value = task_val if not isinstance(task_val, dict) else task_val.get("accuracy")
try:
if acc_value is not None and float(acc_value) == 0.0:
return True
except (TypeError, ValueError):
pass
return False
def _derive_status_from_aggregate(agg: dict) -> str:
"""Derive a pipeline status from an auto_pipeline aggregate result dict."""
qs = agg.get("quant_summary") or {}
acc = agg.get("accuracy") or {}
qs_status = qs.get("status", "missing")
acc_status = acc.get("status", "missing")
if qs_status == "failed":
return "Quant Failed"
if acc_status == "failed":
return "Eval Failed"
if _has_zero_accuracy_task(acc):
return "Eval Failed"
if qs_status == "success" and acc_status == "success":
return "Finished"
if acc_status == "success":
return "Finished"
if qs_status == "success":
return "Quantized"
return "Partial"
def _build_result_index(results_path: str) -> dict:
"""Scan results/ for auto_pipeline aggregate files, build a lookup index.
Returns: ``{(model_id_lower, scheme_normalized): [aggregate_dict, ...]}``
where the list is sorted by ``generated_at`` ascending. Multiple
aggregates per key are preserved so that a re-submission of a previously
failed (model, scheme) does not cause the older queue entry to inherit the
newer run's status.
"""
index: dict = {}
for root, _, files in os.walk(results_path):
for f in files:
if not (f.startswith("results_") and f.endswith(".json")):
continue
fp = os.path.join(root, f)
try:
with open(fp) as fh:
data = json.load(fh)
except (json.JSONDecodeError, OSError):
continue
# Only auto_pipeline aggregate files have these keys
if "run_dir" not in data or "copied_files" not in data:
continue
model_id = data.get("model_id", "")
qs = data.get("quant_summary") or {}
scheme = _normalize_scheme(qs.get("scheme", ""))
key = (model_id.lower().strip(), scheme)
index.setdefault(key, []).append(data)
for key, aggs in index.items():
aggs.sort(key=lambda a: a.get("generated_at", ""))
return index
def _pick_aggregate_in_window(
aggs: list,
submitted_time: str,
next_submitted_time: str | None,
) -> dict | None:
"""Return the latest aggregate generated within ``[submitted, next_submitted)``.
The window represents the lifetime of a single queue entry: the entry was
submitted at ``submitted_time``, and any subsequent re-submission for the
same (model, scheme) starts a new window at ``next_submitted_time``.
Aggregates outside the window do not belong to this entry.
"""
if not aggs:
return None
chosen: dict | None = None
for agg in aggs:
gen = agg.get("generated_at", "")
if submitted_time and gen <= submitted_time:
continue
if next_submitted_time and gen >= next_submitted_time:
break
chosen = agg
return chosen
def _parse_iso_utc(value: str) -> datetime | None:
"""Parse the queue/result UTC timestamp format used in JSON files."""
if not value:
return None
try:
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
except ValueError:
return None
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
def _result_is_newer_than_submission(entry: dict, agg: dict) -> bool:
    """Tell whether *agg* was generated strictly after *entry* was submitted.

    A failed model can be re-submitted under the same model/scheme key. Old
    aggregate result files then still match the new request, but they must
    not override its fresh ``Pending`` status in ``status/``.

    If either timestamp is missing or cannot be parsed, no comparison is
    possible and ``True`` is returned (the legacy behaviour).
    """
    submitted = _parse_iso_utc(entry.get("submitted_time", ""))
    generated = _parse_iso_utc(agg.get("generated_at", ""))
    if submitted is not None and generated is not None:
        return generated > submitted
    # Without two comparable timestamps, fall back to the legacy answer.
    return True
def _infer_quant_status(entry: dict, result_index: dict) -> str:
"""Infer the true status of a quant request from results/ aggregates.
The caller is expected to have populated ``entry["_matched_aggregate"]``
with the aggregate that belongs to *this* submission window (see
``_attach_matched_aggregates``). ``result_index`` is kept in the signature
for backward compatibility but no longer used directly.
Falls back to the original status field if no matching result is found.
"""
original_status = entry.get("status", "Pending")
# Already a terminal status — trust it. ``Failed`` is the value written by
# the CI dispatcher when retries are exhausted; map it to ``Quant Failed``
# so the UI groups it with the rest of the failed quantizations instead of
# dropping it entirely.
if original_status in ("Finished", "Quant Failed", "Eval Failed"):
return original_status
if original_status == "Failed":
return "Quant Failed"
agg = entry.get("_matched_aggregate")
if agg is None:
return original_status
return _derive_status_from_aggregate(agg)
def _infer_eval_status(entry: dict, result_index: dict) -> str:
"""Infer the true status of an eval request from results/ aggregates.
Uses the per-entry aggregate stored in ``entry["_matched_aggregate"]``
(populated by ``_attach_matched_aggregates``) so that an older failed
submission is not overwritten by a newer success for the same model.
"""
original_status = entry.get("status", "Pending")
if original_status in ("Finished", "Eval Failed"):
return original_status
if original_status == "Failed":
return "Eval Failed"
agg = entry.get("_matched_aggregate")
if agg is None:
return original_status
acc = agg.get("accuracy") or {}
acc_status = acc.get("status", "missing")
if acc_status == "failed" or _has_zero_accuracy_task(acc):
return "Eval Failed"
if acc_status == "success":
return "Finished"
return original_status
def _quant_match_key(entry: dict) -> tuple[str, str]:
    """Build the ``(model_id_lower, scheme_normalized)`` index key for an entry.

    Prefers the raw ``_model_id`` (saved before the model name is turned into
    HTML) and falls back to ``model``. ``quant_scheme`` is normalized the same
    way as the scheme of indexed aggregates.
    """
    raw_model = entry["_model_id"] if "_model_id" in entry else entry.get("model", "")
    scheme_key = _normalize_scheme(entry.get("quant_scheme", ""))
    return raw_model.lower().strip(), scheme_key
def _eval_match_keys(entry: dict, result_index: dict) -> list[tuple[str, str]]:
"""Return all aggregate keys whose model matches this eval entry.
Eval entries are submitted against an already-quantized model and don't
carry a separate ``quant_scheme`` field for matching, so we collect every
indexed key whose model_id matches.
"""
model = entry.get("_model_id", entry.get("model", "")).lower().strip()
return [k for k in result_index if k[0] == model]
def _attach_matched_aggregates(
    entries: list[dict],
    result_index: dict,
    request_type: str | None,
) -> None:
    """Store each entry's own aggregate under ``entry["_matched_aggregate"]``.

    Entries are grouped by match key. Eval entries join every indexed key for
    their model; quant entries, and entries of any other request type, use
    their single quant key. Within a group, entries are ordered by
    ``submitted_time``. The i-th submission may only claim aggregates
    generated after its own ``submitted_time`` and before the (i+1)-th
    submission's. This stops an older failed queue entry from inheriting a
    newer successful run's status when a model is re-submitted.

    When an entry belongs to several groups (eval entries with several
    schemes), the pick with the earliest ``generated_at`` is kept. Ties go to
    the group processed first. Entries are mutated in place; nothing is
    returned.
    """
    def _keys_for(entry: dict) -> list:
        if request_type == "eval":
            return _eval_match_keys(entry, result_index)
        # "quant" and every other request type share the quant key.
        return [_quant_match_key(entry)]

    grouped: dict = {}
    for entry in entries:
        for key in _keys_for(entry):
            grouped.setdefault(key, []).append(entry)

    for key, members in grouped.items():
        aggs = result_index.get(key, [])
        if not aggs:
            continue
        ordered = sorted(members, key=lambda e: e.get("submitted_time", ""))
        # Each submission's window ends where the next one starts; the last is open.
        upper_bounds = [m.get("submitted_time", "") for m in ordered[1:]] + [None]
        for member, upper in zip(ordered, upper_bounds):
            picked = _pick_aggregate_in_window(aggs, member.get("submitted_time", ""), upper)
            if picked is None:
                continue
            current = member.get("_matched_aggregate")
            if current is None or picked.get("generated_at", "") < current.get("generated_at", ""):
                member["_matched_aggregate"] = picked
def _load_queue_entries(save_path: str, request_type: str | None = None) -> list[dict]:
    """Load queue JSON entries from *save_path* and its direct subdirectories.

    Only ``*.json`` files located directly in *save_path* or exactly one
    directory level below it are read. Names starting with ``.`` are ignored.

    Args:
        save_path: Directory containing queue JSON files. A missing directory
            yields ``[]``.
        request_type: Optional filter. ``"eval"`` skips files whose name
            contains ``_quant_request_``, and ``"quant"`` skips files whose name
            contains ``_eval_request_``. Any other value, including ``None``,
            keeps all files (backward-compatible).

    Returns:
        List of loaded and normalized dicts. Each gains ``_model_id`` (the raw
        model id, saved before the HTML transform), the clickable model link
        and the revision (default ``"main"``) under the queue column names.

    Files that cannot be read, are not valid UTF-8 JSON, are not a JSON
    object, or lack a ``model`` key are skipped with a warning.
    """
    if not os.path.isdir(save_path):
        return []
    entries = [e for e in os.listdir(save_path) if not e.startswith(".")]
    all_evals: list[dict] = []

    def _process_file(file_path: str) -> None:
        # Load one queue file and append it to ``all_evals`` if usable.
        fname = os.path.basename(file_path)
        if request_type == "eval" and "_quant_request_" in fname:
            return
        if request_type == "quant" and "_eval_request_" in fname:
            return
        try:
            with open(file_path, encoding="utf-8") as fp:
                data = json.load(fp)
        except (ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.warning("Skipping malformed queue file %s: %s", file_path, e)
            return
        if not isinstance(data, dict):
            logger.warning("Skipping queue file that is not a JSON object: %s", file_path)
            return
        if "model" not in data:
            logger.warning("Skipping queue file without 'model' key: %s", file_path)
            return
        data["_model_id"] = data["model"]  # preserve raw ID before HTML transform
        data[eval_queue_cols.model.name] = make_clickable_model(data["model"])
        data[eval_queue_cols.revision.name] = data.get("revision", "main")
        _normalize_queue_entry(data)
        all_evals.append(data)

    for entry in entries:
        full_path = os.path.join(save_path, entry)
        if entry.endswith(".json"):
            _process_file(full_path)
        elif os.path.isdir(full_path):
            # Only one level of nesting is scanned.
            for sub_entry in os.listdir(full_path):
                if not sub_entry.startswith(".") and sub_entry.endswith(".json"):
                    _process_file(os.path.join(full_path, sub_entry))
    return all_evals
def _split_by_status(all_evals: list[dict]) -> tuple[list, list, list, list]:
"""Split entries into (pending, running, finished, failed) lists by status."""
pending = [e for e in all_evals if e.get("status") in ("Pending", "Rerun", "Waiting", "Quantized")]
running = [e for e in all_evals if e.get("status") in ("Running", "Triggered")]
finished = [e for e in all_evals
if e.get("status", "").startswith("Finished")
or e.get("status") == "PENDING_NEW_EVAL"]
# ``Failed`` is written by the CI dispatcher when retries are exhausted.
# Include it (and its case variants) here so retry-exhausted entries are
# not silently dropped from the UI.
failed = [e for e in all_evals
if e.get("status") in ("Quant Failed", "Eval Failed", "Partial",
"Failed", "failed")]
return pending, running, finished, failed
def _build_queue_dfs(pending: list, running: list, finished: list, failed: list,
cols: list) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""Build DataFrames from split lists, keeping only the requested *cols*."""
# Determine which requested cols actually exist across ALL records so we
# never drop a column just because the first entry happens to be old/missing it.
all_records = pending + running + finished + failed
if all_records:
present_keys: set[str] = set()
for rec in all_records:
present_keys.update(rec.keys())
existing_cols = [c for c in cols if c in present_keys]
else:
existing_cols = cols
def _to_df(records):
if not records:
return pd.DataFrame(columns=existing_cols)
return pd.DataFrame.from_records(records, columns=existing_cols)
df_finished = _to_df(finished)
df_running = _to_df(running)
df_pending = _to_df(pending)
df_failed = _to_df(failed)
return df_finished, df_running, df_pending, df_failed
def _inject_eta(pending: list[dict], running: list[dict], concurrency: int = 2):
    """Annotate each pending entry with a human-readable ``eta`` string.

    Slot-simulation estimate::

        ETA = running_remaining + ceil(queue_pos / concurrency) * task_hours

    - ``running_remaining`` is the mean ``estimate_task_hours`` of the running
      entries, or 0 when nothing is running.
    - ``queue_pos`` is the 1-based position among pending ``auto_quant`` and
      ``auto_eval`` entries only.
    - ``task_hours`` is the entry's own estimate.

    Entries with any other ``script`` get ``""``. *pending* is sorted in place
    by ``submitted_time`` (oldest first) before positions are assigned, and
    every pending entry's ``eta`` key is overwritten. *concurrency* must be
    non-zero when any auto entry is pending (it is used as a divisor).
    """
    import math
    from src.queue_eta import estimate_task_hours, format_eta, _get_params

    pending.sort(key=lambda item: item.get("submitted_time", ""))

    running_remaining = 0.0
    if running:
        estimates = [estimate_task_hours(_get_params(job)) for job in running]
        running_remaining = sum(estimates) / len(estimates)

    queue_pos = 0
    for item in pending:
        if item.get("script", "") not in ("auto_quant", "auto_eval"):
            item["eta"] = ""
            continue
        queue_pos += 1
        hours = estimate_task_hours(_get_params(item))
        slots_ahead = math.ceil(queue_pos / concurrency)
        item["eta"] = format_eta(running_remaining + slots_ahead * hours)
def get_evaluation_queue_df(save_path: str, cols: list,
                            request_type: str = None,
                            results_path: str = None) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the request queue and return ``(finished, running, pending, failed)`` DataFrames.

    Args:
        save_path: Directory containing queue JSON files.
        cols: Column names to keep in the output DataFrames.
        request_type: ``"eval"``, ``"quant"``, or ``None`` (all).
        results_path: Path to the results directory. When given, each entry is
            matched against auto_pipeline aggregate results. For ``"quant"``
            and ``"eval"`` requests its status is then re-inferred, so that
            completed jobs whose status was never written back are classified
            correctly.

    Pending entries get an ``eta`` string. Entries in the other buckets get
    ``eta=""`` if they have none.
    """
    all_evals = _load_queue_entries(save_path, request_type=request_type)

    if results_path:
        result_index = _build_result_index(results_path)
        _attach_matched_aggregates(all_evals, result_index, request_type)
        if request_type == "quant":
            infer = _infer_quant_status
        elif request_type == "eval":
            infer = _infer_eval_status
        else:
            infer = None
        if infer is not None:
            for entry in all_evals:
                entry["status"] = infer(entry, result_index)

    pending, running, finished, failed = _split_by_status(all_evals)
    _inject_eta(pending, running, concurrency=2)
    # Keep an ``eta`` column present (empty) for the non-pending buckets too.
    for entry in (*running, *finished, *failed):
        entry.setdefault("eta", "")
    return _build_queue_dfs(pending, running, finished, failed, cols)
def get_auto_pipeline_results_df(results_path: str) -> pd.DataFrame:
    """Return the auto_pipeline results DataFrame for *results_path*.

    Thin pass-through to
    ``src.leaderboard.read_auto_pipeline_results.get_auto_pipeline_results_df``;
    *results_path* is forwarded unchanged and the result returned as-is.
    """
    results_df = _get_auto_pipeline_results_df(results_path)
    return results_df