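"""Evaluate a PEFT (LoRA) adapter with exact-match metrics on a JSONL chat dataset.

Each JSONL row must carry a "messages" list of chat turns; the first
user/assistant pair per row supplies the prompt and gold answer. The script
decodes greedily and reports three EM variants: strict (identical strings),
loose (after full-width/half-width normalization), and numeric (parsed
values agreeing within 1e-6).

Usage (illustrative; the script name and paths are placeholders):
    python eval_em.py --adapter out/lora --data data/eval.jsonl --limit 100
"""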
import argparse
import json
import re

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def _collapse_digit_separators(s: str) -> str:
    # Remove whitespace/commas that appear between two digits
    # (including common narrow and no-break space characters).
    return re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
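
# Sanity-check examples for the helper above (hypothetical inputs):
#   _collapse_digit_separators("1,234.5 6")  -> "1234.56"
#   _collapse_digit_separators("3\u202f141") -> "3141"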


def _postproc_model_text(s: str) -> str:
    # collapse spaces/commas between digits like '330.7 6' -> '330.76', '1,234' -> '1234'
    s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
    return s


def _preprocess_user_text(s: str) -> str:
    # Full-width punctuation -> half-width
    s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(","(").replace(")",")")
    # Remove commas/spaces inside numbers (the decimal point is kept)
    s = re.sub(r'(?<=\d)[,\s]+(?=\d)', '', s)
    # Collapse runs of whitespace
    s = re.sub(r'\s+', ' ', s).strip()
    return s
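
# Example (hypothetical user text): "共 1,234 元。" -> "共 1234 元."
#   The full-width "。" becomes ".", the digit-group comma is dropped,
#   and runs of whitespace collapse to single spaces.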


def messages_to_pairs(messages):
    # Group consecutive user turns into one prompt; each assistant turn closes a pair.
    pairs, buf = [], []
    for m in messages:
        if m.get("role") == "user":
            buf.append(m.get("content", ""))
        elif m.get("role") == "assistant" and buf:
            pairs.append({"prompt": "\n\n".join(buf), "response": m.get("content", "")})
            buf = []
    return pairs
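
# For instance (hypothetical messages), one exchange yields one pair:
#   messages_to_pairs([
#       {"role": "user", "content": "12.3 + 4.5 = ?"},
#       {"role": "assistant", "content": "16.8"},
#   ]) == [{"prompt": "12.3 + 4.5 = ?", "response": "16.8"}]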

def normalize(s: str) -> str:
    # Fold ideographic spaces and common full-width punctuation to ASCII,
    # then collapse whitespace, so the "loose" EM ignores width differences.
    s = s.replace("\u3000", " ").strip()
    trans = str.maketrans(",。:!?【】()%+-×÷=“”‘’", ",.:!?[]()%+-*/=\"\"''")
    s = s.translate(trans)
    s = re.sub(r"\s+", " ", s)
    return s
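
# e.g. normalize("答案:(3,5)。") -> "答案:(3,5)." (hypothetical string),
# so full-width and half-width punctuation compare as equal below.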

def to_num(x):
    # Parse a number out of free-form text; returns None if nothing numeric is found.
    try:
        return float(x)
    except (TypeError, ValueError):
        if not isinstance(x, str):
            x = str(x)
        x = _collapse_digit_separators(x)
        m = re.search(r"[-+]?\d*\.?\d+(?:e[-+]?\d+)?", x, re.I)
        return float(m.group(0)) if m else None
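
# e.g. to_num("答案是 -3.5 元") -> -3.5, to_num("330.7 6") -> 330.76,
# and to_num("n/a") -> None (hypothetical inputs).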

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--adapter", required=True)
    ap.add_argument("--data", required=True)
    ap.add_argument("--max_new", type=int, default=64)
    ap.add_argument("--limit", type=int, default=0, help="只评测前 N 条,0=全部")
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.base, trust_remote_code=True, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base = AutoModelForCausalLM.from_pretrained(args.base, trust_remote_code=True, torch_dtype=torch.float32)
    model = PeftModel.from_pretrained(base, args.adapter)
    model.eval()
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model.to(device)

    # Clear the default sampling settings to avoid HF warnings and guarantee greedy decoding
    gc = model.generation_config
    gc.do_sample = False
    gc.temperature = None
    gc.top_p = None
    gc.top_k = None

    golds, preds = [], []

    with open(args.data, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if args.limit and i >= args.limit: break
            row = json.loads(line)
            pairs = messages_to_pairs(row["messages"])
            if not pairs: continue
            ex = pairs[0]
            gold = ex["response"]

            ctx = [
                # System prompt (Chinese): "Output only the final answer, no explanation.
                # Output a single number; if it is a decimal, give all decimal places in
                # full, without rounding or truncation."
                {"role":"system","content":"请只输出最终答案,不要解释。只输出一个数字;若为小数,完整输出全部小数位,不要四舍五入或截断。"},
                {"role":"user","content": ex["prompt"]}
            ]
            prompt_text = tok.apply_chat_template(ctx, tokenize=False, add_generation_prompt=True)
            inputs = tok(prompt_text, return_tensors="pt").to(device)

            with torch.no_grad():
                out = model.generate(
                    **inputs,
                    max_new_tokens=args.max_new,
                    do_sample=False,
                    temperature=None,
                    eos_token_id=tok.eos_token_id,
                    pad_token_id=tok.eos_token_id
                )
            pred = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

            pred = _postproc_model_text(pred)
            golds.append(gold)
            preds.append(pred)

            print(f"[{i}] GT={repr(gold)} | PRED={repr(pred)}")

    # Compute three exact-match (EM) variants
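    # Illustration with hypothetical gold/pred pairs:
    #   gold "330.76" vs pred "330.76"   -> strict, loose and numeric all count
    #   gold "(42)" vs pred "(42)"     -> loose (via normalize) and numeric count
    #   gold "42"    vs pred "ans: 42.0" -> numeric only (|42 - 42.0| < 1e-6)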
    strict = sum(1 for g,p in zip(golds,preds) if p==g)
    loose  = sum(1 for g,p in zip(golds,preds) if normalize(p)==normalize(g))
    numem  = 0
    for g, p in zip(golds, preds):
        g_num, p_num = to_num(g), to_num(p)
        if g_num is not None and p_num is not None and abs(g_num - p_num) < 1e-6:
            numem += 1

    n = len(golds) if golds else 1
    print(f"\n==> EM strict={strict/n:.3f}  EM loose={loose/n:.3f}  EM numeric={numem/n:.3f}  (N={n})")

if __name__ == "__main__":
    main()