Spaces:
Paused
Paused
况兑
eval: greedy decode + numeric strict; system: force full decimals; regressions: A/B/C/noisy
e45d7fc
| import json, re, argparse, torch | |
| import re | |
| def _collapse_digit_separators(s: str) -> str: | |
| # 去掉出现在“数字 与 数字”之间的空白/逗号(含常见窄空格) | |
| return re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s) | |
| def _postproc_model_text(s: str) -> str: | |
| # collapse spaces/commas between digits like '330.7 6' -> '330.76', '1,234' -> '1234' | |
| s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s) | |
| return s | |
| def _preprocess_user_text(s: str) -> str: | |
| # 全角标点 -> 半角 | |
| s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(","(").replace(")",")") | |
| # 去掉数字内部的逗号/空格(保留小数点) | |
| s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s) | |
| # 压缩空白 | |
| s = re.sub(r'\s+', ' ', s).strip() | |
| return s | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
def messages_to_pairs(messages):
    """Fold a chat transcript into {prompt, response} pairs.

    Consecutive user turns are joined with blank lines to form one
    prompt; the next assistant turn closes the pair.  Assistant turns
    with no preceding user content are skipped.
    """
    pairs = []
    pending_user = []
    for msg in messages:
        role = msg.get("role")
        if role == "user":
            pending_user.append(msg.get("content", ""))
        elif role == "assistant" and pending_user:
            pairs.append({
                "prompt": "\n\n".join(pending_user),
                "response": msg.get("content", ""),
            })
            pending_user = []
    return pairs
def normalize(s: str) -> str:
    """Canonicalize text for loose comparison.

    Ideographic space -> ASCII space, outer whitespace trimmed,
    full-width punctuation mapped to half-width, then whitespace runs
    collapsed to single spaces.
    """
    table = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
    cleaned = s.replace("\u3000", " ").strip().translate(table)
    return re.sub(r"\s+", " ", cleaned)
def to_num(x):
    """Best-effort conversion of *x* to float.

    Tries float() directly; on conversion failure, strips digit
    separators and extracts the first number-looking token (optional
    sign, decimal point, scientific notation).  Returns None when
    nothing numeric is found.
    """
    try:
        return float(x)
    # Was a bare `except:` — that also swallowed KeyboardInterrupt and
    # SystemExit.  Keep only conversion failures (OverflowError covers
    # float() on a huge int, which the bare except previously absorbed).
    except (TypeError, ValueError, OverflowError):
        pass
    if not isinstance(x, str):
        x = str(x)
    # Same separator class as _collapse_digit_separators, inlined so this
    # parser stands alone.
    x = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', x)
    m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", x, re.I)
    return float(m.group(0)) if m else None
def main():
    """Evaluate a PEFT adapter on a JSONL chat dataset with greedy decoding.

    For each line of --data (a JSON object with a "messages" list), the
    first (prompt, response) pair is scored.  Prints per-example output
    and three exact-match rates: strict (raw string equality), loose
    (normalize()-equality) and numeric (parsed floats within 1e-6).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--adapter", required=True)
    ap.add_argument("--data", required=True)
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--limit", type=int, default=0, help="只评测前 N 条,0=全部")
    args = ap.parse_args()
    tok = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True, use_fast=True)
    if tok.pad_token is None:
        # Decoder-only models often ship without a pad token; reuse EOS.
        tok.pad_token = tok.eos_token
    base = AutoModelForCausalLM.from_pretrained(args.base, trust_remote_code=True, torch_dtype=torch.float32)
    model = PeftModel.from_pretrained(base, args.adapter)
    model.eval()
    # NOTE(review): only MPS vs CPU is considered here — presumably this
    # script targets Apple-silicon machines; CUDA is never selected.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)
    # Clear the sampling defaults in generation_config to avoid warnings
    # and guarantee greedy decoding.
    gc = model.generation_config
    gc.do_sample = False
    gc.temperature = None
    gc.top_p = None
    gc.top_k = None
    golds, preds = [], []
    with open(args.data, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if args.limit and i >= args.limit: break
            row = json.loads(line)
            pairs = messages_to_pairs(row["messages"])
            if not pairs: continue
            # Only the first (prompt, response) pair of each conversation
            # is evaluated.
            ex = pairs[0]
            gold = ex["response"]
            ctx = [
                {"role":"system","content":"请只输出最终答案,不要解释。只输出一个数字;若为小数,完整输出全部小数位,不要四舍五入或截断。"},
                {"role":"user","content": ex["prompt"]}
            ]
            prompt_text = tok.apply_chat_template(ctx, tokenize=False, add_generation_prompt=True)
            inputs = tok(prompt_text, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new,
                    do_sample=False,
                    temperature=None,
                    eos_token_id=tok.eos_token_id,
                    pad_token_id=tok.eos_token_id
                )
            # Decode only the newly generated tail, skipping the prompt tokens.
            pred = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            pred = _postproc_model_text(pred)
            golds.append(gold)
            preds.append(pred)
            print(f"[{i}] GT={repr(gold)} | PRED={repr(pred)}")
    # Compute the three exact-match variants.
    strict = sum(1 for g,p in zip(golds,preds) if p==g)
    loose = sum(1 for g,p in zip(golds,preds) if normalize(p)==normalize(g))
    numem = 0
    for g,p in zip(golds,preds):
        # `np` here is a local float, not numpy (numpy is not imported).
        ng, np = to_num(g), to_num(p)
        if ng is not None and np is not None and abs(ng-np)<1e-6:
            numem += 1
    n = len(golds) if golds else 1  # guard against division by zero on empty data
    print(f"\n==> EM strict={strict/n:.3f} EM loose={loose/n:.3f} EM numeric={numem/n:.3f} (N={n})")


if __name__ == "__main__":
    main()