import json, re, argparse, torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


def _collapse_digit_separators(s: str) -> str:
    # Remove whitespace/commas between two digits (including common narrow-space characters).
    return re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)


def _postproc_model_text(s: str) -> str:
    # Collapse spaces/commas between digits, e.g. '330.7 6' -> '330.76', '1,234' -> '1234'.
    s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
    return s


def _preprocess_user_text(s: str) -> str:
    # Normalize raw user text; defined for optional prompt cleaning (not called in main).
    # Full-width punctuation -> half-width.
    s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(", "(").replace(")", ")")
    # Strip commas/spaces inside numbers (decimal points are kept).
    s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
    # Collapse whitespace.
    s = re.sub(r'\s+', ' ', s).strip()
    return s


def messages_to_pairs(messages):
    # Group consecutive user turns into one prompt; pair it with the next assistant reply.
    pairs, buf = [], []
    for m in messages:
        if m.get("role") == "user":
            buf.append(m.get("content", ""))
        elif m.get("role") == "assistant" and buf:
            pairs.append({"prompt": "\n\n".join(buf), "response": m.get("content", "")})
            buf = []
    return pairs


def normalize(s: str) -> str:
    # Map full-width punctuation to half-width and collapse whitespace for loose matching.
    s = s.replace("\u3000", " ").strip()
    trans = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
    s = s.translate(trans)
    s = re.sub(r"\s+", " ", s)
    return s


def to_num(x):
    # Parse a number out of free-form text; returns None if nothing numeric is found.
    try:
        return float(x)
    except (TypeError, ValueError):
        if not isinstance(x, str):
            x = str(x)
        x = _collapse_digit_separators(x)
        m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", x, re.I)
        return float(m.group(0)) if m else None


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--adapter", required=True)
    ap.add_argument("--data", required=True)
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--limit", type=int, default=0, help="evaluate only the first N examples; 0 = all")
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(args.base, trust_remote_code=True, torch_dtype=torch.float32)
    model = PeftModel.from_pretrained(base, args.adapter)
    model.eval()
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)

    # Clear the default sampling settings to avoid warnings and guarantee greedy decoding.
    gen_cfg = model.generation_config
    gen_cfg.do_sample = False
    gen_cfg.temperature = None
    gen_cfg.top_p = None
    gen_cfg.top_k = None

    golds, preds = [], []
    with open(args.data, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if args.limit and i >= args.limit:
                break
            row = json.loads(line)
            pairs = messages_to_pairs(row["messages"])
            if not pairs:
                continue
            ex = pairs[0]
            gold = ex["response"]
            # System prompt (Chinese): "Output only the final answer, no explanation.
            # Output a single number; if it is a decimal, give all decimal places in
            # full, without rounding or truncating."
            ctx = [
                {"role": "system", "content": "请只输出最终答案,不要解释。只输出一个数字;若为小数,完整输出全部小数位,不要四舍五入或截断。"},
                {"role": "user", "content": ex["prompt"]}
            ]
            prompt_text = tok.apply_chat_template(ctx, tokenize=False, add_generation_prompt=True)
            inputs = tok(prompt_text, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new,
                    do_sample=False,
                    temperature=None,
                    eos_token_id=tok.eos_token_id,
                    pad_token_id=tok.eos_token_id
                )
            pred = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            pred = _postproc_model_text(pred)
            golds.append(gold)
            preds.append(pred)
            print(f"[{i}] GT={repr(gold)} | PRED={repr(pred)}")

    # Compute three exact-match (EM) scores: strict string match, normalized
    # match, and numeric match within an absolute tolerance of 1e-6.
    strict = sum(1 for g, p in zip(golds, preds) if p == g)
    loose = sum(1 for g, p in zip(golds, preds) if normalize(p) == normalize(g))
    numem = 0
    for g, p in zip(golds, preds):
        g_num, p_num = to_num(g), to_num(p)
        if g_num is not None and p_num is not None and abs(g_num - p_num) < 1e-6:
            numem += 1
    n = len(golds) if golds else 1
    print(f"\n==> EM strict={strict/n:.3f} EM loose={loose/n:.3f} EM numeric={numem/n:.3f} (N={n})")


if __name__ == "__main__":
    main()
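
# --- Usage sketch ---
# A minimal invocation, assuming this file is saved as eval_em.py (the filename,
# adapter path, and data path below are illustrative, not from the original):
#
#   python eval_em.py \
#       --adapter out/qwen2.5-0.5b-lora \
#       --data data/eval.jsonl \
#       --limit 100
#
# Each line of --data is a JSON object with a "messages" list of
# {"role": ..., "content": ...} turns; only the first user/assistant pair
# per line is evaluated (see messages_to_pairs and pairs[0] above).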