#!/usr/bin/env python3
"""
BioGRPO Post-Training Evaluation Script
Evaluates a GRPO-trained model against:
1. Held-out GeneLab questions (LOMO: Leave-One-Mission-Out)
2. Calibration metrics (ECE, Brier, overconfidence rate)
3. Per-verifier reward scores
4. Baseline comparison (an SFT checkpoint passed via --sft-baseline)
Usage:
python scripts/evaluate_grpo.py \
--model ./biogrpo_mve_model \
--sft-baseline ./kmp_sft_model_final \
--hold-out-tissues eye \
--output results/grpo_mve_eval.json
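
Quick smoke test (same paths as above, capped sample count, stacking the SFT
adapter under the GRPO adapter):
    python scripts/evaluate_grpo.py \
        --model ./biogrpo_mve_model \
        --sft-adapter ./kmp_sft_model_final \
        --max-samples 20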
"""
import argparse
import json
import torch
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from tqdm import tqdm
from biorlhf.data.grpo_dataset import build_grpo_dataset, get_dataset_stats
from biorlhf.verifiers.composer import VerifierComposer
from biorlhf.verifiers.uncertainty import _extract_confidence_simple, SimpleConfidence
from biorlhf.evaluation.calibration import compute_calibration_metrics
def load_model(
model_path: str,
base_model: str = "mistralai/Mistral-7B-v0.3",
use_4bit: bool = True,
sft_adapter: Optional[str] = None,
):
"""Load a fine-tuned model with LoRA adapters.
For GRPO checkpoints trained on an SFT-merged base, pass sft_adapter
to first merge the SFT adapter before applying the GRPO adapter.
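
    Example (paths are illustrative):
        model, tokenizer = load_model(
            "./biogrpo_mve_model",
            sft_adapter="./kmp_sft_model_final",
        )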
"""
print(f" Base model: {base_model}")
if sft_adapter:
print(f" SFT adapter (merge first): {sft_adapter}")
print(f" Adapter: {model_path}")
bnb_config = None
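    # NF4 4-bit weights with double quantization and bf16 compute: roughly a
    # 4x cut in weight memory vs fp16, enough to evaluate a 7B model on one GPU.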
if use_4bit:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
# If GRPO was trained on SFT-merged base, merge SFT first
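    # Merging into a 4-bit quantized base dequantizes the merged weights and
    # can introduce small rounding error; if merge_and_unload() errors on an
    # older peft version, rerun with --no-4bit.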
if sft_adapter:
print(" Merging SFT adapter...")
model = PeftModel.from_pretrained(model, sft_adapter)
model = model.merge_and_unload()
model = PeftModel.from_pretrained(model, model_path)
    # Always load the tokenizer from the base model (adapter dirs don't ship tokenizer files)
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer
def generate_response(
model,
tokenizer,
prompt: str,
max_new_tokens: int = 512,
temperature: float = 0.1,
) -> str:
"""Generate a response from the model."""
formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=0.9,
do_sample=temperature > 0,
pad_token_id=tokenizer.pad_token_id,
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "### Response:" in response:
response = response.split("### Response:")[-1].strip()
return response
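
# Example (assumes a model/tokenizer pair from load_model; prompt is illustrative):
#   answer = generate_response(model, tokenizer, "Which pathways shift in eye tissue?")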
def evaluate_with_verifiers(
model,
tokenizer,
eval_dataset,
composer: VerifierComposer,
max_samples: Optional[int] = None,
) -> Dict:
"""Evaluate model using the verifier stack.
Returns per-sample results and aggregated metrics.
"""
results = []
n = len(eval_dataset)
    if max_samples is not None:
n = min(n, max_samples)
for i in tqdm(range(n), desc="Evaluating"):
sample = eval_dataset[i]
prompt = sample["prompt"]
gt = sample["ground_truth"]
qtype = sample["question_type"]
applicable = sample["applicable_verifiers"]
response = generate_response(model, tokenizer, prompt)
reward = composer.compute_reward(
prompt=prompt,
completion=response,
ground_truth=gt,
question_type=qtype,
applicable_verifiers=applicable,
)
# Extract confidence for calibration (match V4's extraction method)
try:
from bioeval.scoring.calibration import extract_confidence
conf_extraction = extract_confidence(response)
conf = SimpleConfidence(
stated=conf_extraction.stated_confidence or "medium",
numeric=conf_extraction.confidence_score,
source="bioeval",
)
except ImportError:
conf = _extract_confidence_simple(response)
results.append({
"prompt": prompt[:100],
"response": response[:300],
"total_reward": reward.total_reward,
"verifier_scores": reward.verifier_scores,
"question_type": qtype,
"source": sample.get("source", "unknown"),
"tissue": sample.get("tissue", "unknown"),
"confidence": conf.numeric,
"confidence_stated": conf.stated,
})
# Aggregate metrics
total_rewards = [r["total_reward"] for r in results]
per_verifier: Dict[str, List[float]] = {}
for r in results:
for v, s in r["verifier_scores"].items():
per_verifier.setdefault(v, []).append(s)
verifier_means = {v: sum(s) / len(s) for v, s in per_verifier.items()}
# Per question type
by_type: Dict[str, List[float]] = {}
for r in results:
by_type.setdefault(r["question_type"], []).append(r["total_reward"])
type_means = {t: sum(s) / len(s) for t, s in by_type.items()}
return {
"n_samples": len(results),
"mean_reward": sum(total_rewards) / len(total_rewards) if total_rewards else 0,
"verifier_means": verifier_means,
"by_question_type": type_means,
"per_sample": results,
}
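
# Example (objects as produced above):
#   metrics = evaluate_with_verifiers(model, tokenizer, eval_dataset, composer,
#                                     max_samples=50)
#   print(metrics["mean_reward"], metrics["verifier_means"])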
def evaluate_calibration(results: List[Dict]) -> Dict:
"""Compute calibration metrics from evaluation results."""
confidences = [r["confidence"] for r in results]
# Correctness: reward > 0.5 considered "correct"
correctnesses = [r["total_reward"] > 0.5 for r in results]
metrics = compute_calibration_metrics(
confidences=confidences,
correctnesses=correctnesses,
)
return {
"ece": metrics.ece,
"mce": metrics.mce,
"brier_score": metrics.brier_score,
"overconfidence_rate": metrics.overconfidence_rate,
"underconfidence_rate": metrics.underconfidence_rate,
"mean_confidence": metrics.mean_confidence,
"mean_accuracy": metrics.mean_accuracy,
"n_samples": metrics.n_samples,
"reliability_bins": metrics.reliability_bins,
}
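
# For reference: ECE, as conventionally defined, bins samples by stated
# confidence and averages the gap between per-bin accuracy and confidence,
#   ECE = sum_b (n_b / N) * |acc_b - conf_b|,
# so lower is better. The success gate below requires ECE < 0.15.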
def main():
parser = argparse.ArgumentParser(
description="Evaluate a BioGRPO-trained model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--model", type=str, required=True,
help="Path to the GRPO-trained model (LoRA adapter directory)",
)
parser.add_argument(
"--base-model", type=str, default="mistralai/Mistral-7B-v0.3",
help="Base model name",
)
parser.add_argument(
"--sft-baseline", type=str, default=None,
help="Path to SFT baseline model for comparison",
)
parser.add_argument(
"--hold-out-tissues", type=str, nargs="+", default=["eye"],
help="Tissues held out for evaluation",
)
parser.add_argument(
"--pathway-db", type=str, default="hallmark",
help="Pathway database",
)
parser.add_argument(
"--active-verifiers", type=str, nargs="+", default=None,
help="Active verifiers (default: all)",
)
parser.add_argument(
"--max-samples", type=int, default=None,
help="Max samples to evaluate (for quick testing)",
)
parser.add_argument(
"--output", type=str, default=None,
help="Output path for results JSON",
)
parser.add_argument(
"--no-4bit", action="store_true",
help="Disable 4-bit quantization",
)
parser.add_argument(
"--sft-adapter", type=str, default=None,
help="Path to SFT LoRA adapter to merge before applying GRPO adapter (for GRPO checkpoints trained on SFT-merged base)",
)
args = parser.parse_args()
print("=" * 60)
print("BioGRPO Evaluation")
print("=" * 60)
print(f" Model: {args.model}")
print(f" Base: {args.base_model}")
print(f" Hold-out: {args.hold_out_tissues}")
print(f" SFT baseline: {args.sft_baseline or 'None'}")
print(f" Time: {datetime.now().isoformat()}")
print("=" * 60)
# Build eval dataset
print("\n[1/4] Building evaluation dataset...")
_, eval_dataset = build_grpo_dataset(
db=args.pathway_db,
hold_out_tissues=args.hold_out_tissues,
)
eval_stats = get_dataset_stats(eval_dataset)
print(f" Eval samples: {eval_stats['total']}")
print(f" By source: {eval_stats['by_source']}")
print(f" By type: {eval_stats['by_question_type']}")
# Create verifier composer
composer = VerifierComposer(active_verifiers=args.active_verifiers)
# Evaluate GRPO model
print(f"\n[2/4] Evaluating GRPO model: {args.model}")
model, tokenizer = load_model(
args.model, args.base_model, use_4bit=not args.no_4bit,
sft_adapter=args.sft_adapter,
)
grpo_results = evaluate_with_verifiers(
model, tokenizer, eval_dataset, composer,
max_samples=args.max_samples,
)
grpo_calibration = evaluate_calibration(grpo_results["per_sample"])
# Free GPU memory
del model
torch.cuda.empty_cache()
# Evaluate baseline if provided
baseline_results = None
baseline_calibration = None
if args.sft_baseline:
print(f"\n[3/4] Evaluating SFT baseline: {args.sft_baseline}")
baseline_model, baseline_tokenizer = load_model(
args.sft_baseline, args.base_model, use_4bit=not args.no_4bit,
)
baseline_results = evaluate_with_verifiers(
baseline_model, baseline_tokenizer, eval_dataset, composer,
max_samples=args.max_samples,
)
baseline_calibration = evaluate_calibration(baseline_results["per_sample"])
del baseline_model
torch.cuda.empty_cache()
else:
print("\n[3/4] Skipping baseline (not provided)")
# Print summary
print("\n[4/4] Results Summary")
print("=" * 60)
print(f"GRPO Model: {args.model}")
print(f" Mean reward: {grpo_results['mean_reward']:.3f}")
print(f" Per verifier: {grpo_results['verifier_means']}")
print(f" ECE: {grpo_calibration['ece']:.3f}")
print(f" Brier: {grpo_calibration['brier_score']:.3f}")
print(f" Overconfidence: {grpo_calibration['overconfidence_rate']:.3f}")
print(f" By type: {grpo_results['by_question_type']}")
comparison = {}
if baseline_results:
print(f"\nSFT Baseline: {args.sft_baseline}")
print(f" Mean reward: {baseline_results['mean_reward']:.3f}")
print(f" ECE: {baseline_calibration['ece']:.3f}")
print(f" Brier: {baseline_calibration['brier_score']:.3f}")
delta_reward = grpo_results["mean_reward"] - baseline_results["mean_reward"]
delta_ece = grpo_calibration["ece"] - baseline_calibration["ece"]
print(f"\n Delta reward: {delta_reward:+.3f}")
print(f" Delta ECE: {delta_ece:+.3f} (negative = better)")
comparison = {
"sft_mean_reward": baseline_results["mean_reward"],
"sft_ece": baseline_calibration["ece"],
"delta_reward": delta_reward,
"delta_ece": delta_ece,
}
# Success criteria
criteria = {
"reward_above_05": grpo_results["mean_reward"] > 0.5,
"ece_below_015": grpo_calibration["ece"] < 0.15,
}
if baseline_results:
criteria["reward_above_baseline"] = delta_reward > 0
criteria["overall_pass"] = all(criteria.values())
print(f"\nSuccess criteria: {criteria}")
# Save results
output_path = args.output or f"results/grpo_eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
output_data = {
"model_path": args.model,
"base_model": args.base_model,
"evaluation_date": datetime.now().isoformat(),
"hold_out_tissues": args.hold_out_tissues,
"eval_dataset_stats": eval_stats,
"grpo": {
"mean_reward": grpo_results["mean_reward"],
"verifier_means": grpo_results["verifier_means"],
"by_question_type": grpo_results["by_question_type"],
"n_samples": grpo_results["n_samples"],
},
"calibration": grpo_calibration,
"baseline_comparison": comparison,
"success_criteria": criteria,
"per_sample": grpo_results["per_sample"],
}
with open(output_path, "w") as f:
json.dump(output_data, f, indent=2)
print(f"\nResults saved to: {output_path}")
print("=" * 60)
if __name__ == "__main__":
main()