# mumble-cleanup / scripts/03_evaluate.py
# uploaded by adikuma in verified commit fd0b01f
# ("initial upload: cleanup code and 688-pair seed dataset")
# 3-row quality comparison on the held-out test split (row 1 is the untouched input, not a model):
# row 1: raw input (no cleanup) -> aggregate metrics
# row 2: qwen base zero-shot with our system prompt
# row 3: qwen + fine-tuned lora adapter
#
# also includes an ADVERSARIAL question check: the base model's documented
# failure was answering questions instead of cleaning them. we record base vs
# fine-tune output on a small list of question-shaped inputs so we can
# visually confirm fine-tune cleans rather than answers.
#
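# writes <runs-dir>/<run-id>/eval.json (runs/<run-id>/eval.json by default) with every
# row that was evaluated (row 2 is omitted under --skip-base) plus the adversarial outputs.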
import argparse
import json
from pathlib import Path
from cleanup.config import load_train_config
from cleanup.data.download import load_pairs
from cleanup.eval.metrics import (
evaluate_one,
make_qwen_generator,
make_raw_generator,
write_eval,
)
# question-shaped inputs that reproduce the prototype's documented failure:
# the base model replies to them instead of stripping the disfluencies.
# the fine-tuned model should hand back the cleaned question itself (with
# proper punctuation and casing), never a reply. keep the list short but
# representative, and append new cases as further failure modes turn up.
ADVERSARIAL = [
    # leading filler ("um")
    "um whats the capital of france",
    # repeated phrase ("can you can you")
    "can you can you write me a poem about the sea",
    # fillers around the question ("so like", "i mean")
    "so like what is two plus two i mean",
    # leading filler ("uh") on a how-to request
    "uh how do i sort a list in python",
    # casual opener ("hey") on a question about the current time
    "hey what time is it in tokyo right now",
]
def _parse_args() -> argparse.Namespace:
    """Parse and validate the evaluation command-line flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/train.yaml")
    parser.add_argument("--data-dir", default="data/pairs")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument("--run-id", required=True)
    parser.add_argument("--max-rows", type=int, default=None)
    parser.add_argument(
        "--smoke",
        action="store_true",
        help="quick run on 40 test rows (overrides --max-rows)",
    )
    parser.add_argument("--skip-base", action="store_true", help="skip qwen base baseline (saves time)")
    args = parser.parse_args()
    # a zero or negative cap would evaluate nothing (or, if it reaches a slice,
    # silently drop rows); reject it before any model is loaded
    if args.max_rows is not None and args.max_rows < 1:
        parser.error("--max-rows must be a positive integer")
    return args


def _run_adversarial(base_gen, ft_gen) -> list:
    """Run every ADVERSARIAL prompt through the available generators.

    Returns one dict per prompt with keys "raw" (the prompt), "base" (only
    when base_gen is not None) and "fine_tuned", in that insertion order.
    """
    rows = []
    for prompt in ADVERSARIAL:
        row = {"raw": prompt}
        if base_gen is not None:
            row["base"] = base_gen(prompt)
        row["fine_tuned"] = ft_gen(prompt)
        rows.append(row)
    return rows


def _print_summary(report: dict) -> None:
    """Print an aligned metrics table for whichever model rows the report contains."""
    # (metric key, column title, column width). header and rows share these
    # widths so they cannot drift apart; each width fits its title.
    columns = (
        ("disfluency_removal_rate", "disfluency", 10),
        ("punctuation_f1", "punct f1", 8),
        ("faithfulness_mean", "faithful", 8),
        ("pass_rate", "pass rate", 9),
    )
    print(" | ".join([f"{'model':<12}"] + [f"{title:>{width}}" for _, title, width in columns]))
    for name in ("raw", "base", "fine_tuned"):
        if name not in report:
            continue  # e.g. "base" is absent under --skip-base
        metrics = report[name]
        cells = [f"{name:<12}"]
        for key, _, width in columns:
            value = metrics[key]
            # evaluate_one can return None for disfluency_removal_rate; treat
            # every column the same way and show n/a instead of crashing on
            # the float format spec
            cells.append(f"{'n/a':>{width}}" if value is None else f"{value:>{width}.3f}")
        print(" | ".join(cells))


def _print_adversarial(rows: list) -> None:
    """Print raw / base / fine-tuned outputs for each adversarial prompt."""
    print("[eval] adversarial check (look for fine_tuned to CLEAN not ANSWER):")
    for row in rows:
        print(f"  {'raw':<10} : {row['raw']}")
        if "base" in row:
            print(f"  {'base':<10} : {row['base']}")
        print(f"  {'fine_tuned':<10} : {row['fine_tuned']}")
        print()


def main() -> None:
    """Compare raw input, base Qwen and fine-tuned Qwen on the test split.

    Loads the training config (for the base model name), evaluates each row
    with evaluate_one, runs the ADVERSARIAL prompts through the base (unless
    --skip-base) and fine-tuned generators, writes the report with write_eval
    into <runs-dir>/<run-id>, and prints a summary table plus the adversarial
    outputs.

    Raises:
        FileNotFoundError: the run has no adapter directory (train first).
        ValueError: no test rows were loaded, so there is nothing to evaluate.
    """
    args = _parse_args()
    cfg = load_train_config(args.config)
    run_dir = Path(args.runs_dir) / args.run_id
    adapter_dir = run_dir / "model"
    # checked before any model is loaded so a missing run fails fast
    if not adapter_dir.exists():
        raise FileNotFoundError(f"no adapter at {adapter_dir}; train first")

    # --smoke caps the split at 40 rows and takes precedence over --max-rows
    max_rows = 40 if args.smoke else args.max_rows
    test_rows = load_pairs(args.data_dir, "test", max_rows)
    print(f"[eval] {len(test_rows)} test rows")
    if not test_rows:
        raise ValueError(f"no test rows loaded from {args.data_dir}; nothing to evaluate")

    report: dict = {}
    print("[eval] row 1: raw baseline")
    report["raw"] = evaluate_one(test_rows, make_raw_generator())

    base_gen = None
    if not args.skip_base:
        print("[eval] row 2: qwen base zero-shot")
        base_gen = make_qwen_generator(cfg.base_model)
        report["base"] = evaluate_one(test_rows, base_gen)

    print("[eval] row 3: qwen fine-tuned")
    # NOTE(review): base_gen stays referenced until the adversarial pass, so
    # both generators may be resident at once; if memory is tight, consider
    # running the base adversarial prompts before loading the adapter.
    ft_gen = make_qwen_generator(cfg.base_model, adapter_path=str(adapter_dir))
    report["fine_tuned"] = evaluate_one(test_rows, ft_gen)

    # the base model's documented failure is answering questions instead of
    # cleaning them; record both outputs side by side for visual inspection
    print("[eval] adversarial: do questions get cleaned, not answered?")
    adversarial_rows = _run_adversarial(base_gen, ft_gen)
    report["adversarial"] = adversarial_rows

    write_eval(report, run_dir)
    print(f"[eval] wrote {run_dir / 'eval.json'}")
    print()
    _print_summary(report)
    print()
    _print_adversarial(adversarial_rows)
# run the evaluation only when executed as a script, not when imported
if __name__ == "__main__":
    main()