import re

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer  # Important: AutoModel, not AutoModelForCausalLM

# Use the Qwen model directly
MODEL_ID = "kudui/qwen-s1k-overfit10"
REV = None
print(f"[INFO] Using MODEL_ID={MODEL_ID}")

VER = "v4-dream"

# System prompt (kept in Chinese, as the model was tuned on it). English gloss:
# "Output only the final answer, which must be a single number; no explanations or
# units. Do not append an extra digit at the end. If the question contains decimals,
# output with the longest decimal precision that appears in the question
# (truncate/zero-pad, do not round)."
SYS = (
    "只输出最终答案且必须是一个数字,不要任何解释或单位。"
    "禁止在末尾多打一位。若题目中出现小数,请按题面中最长的小数位数输出(截断/补零,不四舍五入)。"
)

def collapse_digit_seps(s: str) -> str:
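    """Normalize fullwidth punctuation to ASCII and strip thousands separators between digits.

    Illustrative example: collapse_digit_seps("1,234。5") -> "1234.5"
    """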
    s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(", "(").replace(")", ")")
    s = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
    return re.sub(r'\s+', ' ', s).strip()

def max_decimals_in_text(s: str) -> int:
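    """Return the longest decimal-fraction length among numbers in the text (0 if none).

    Illustrative example: max_decimals_in_text("330.76 + 0.1") -> 2
    """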
    decs = [len(m) for m in re.findall(r'\d+\.(\d+)', s)]
    return max(decs) if decs else 0

def extract_number(s: str):
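    """Return the first numeric literal (integer, decimal, or scientific) in s, or None.

    Illustrative example: extract_number("answer: -3.5e2 m") -> "-3.5e2"
    """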
    m = re.search(r'[-+]?\d*\.?\d+(?:e[-+]?\d+)?', s, re.I)
    return m.group(0) if m else None

print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)

print(f"Loading Dream model from {MODEL_ID}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    revision=REV,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device).eval()
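
# Note: diffusion_generate() used below is provided by the model's remote modeling
# code (Dream is a diffusion LM, not an autoregressive one), which is why
# trust_remote_code=True is passed to both the tokenizer and the model above.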

def predict(q: str):
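    """Answer a math question with a single number, matched to the question's decimal precision."""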
    if not q.strip():
        return ""
    qn = collapse_digit_seps(q)
    want_dec = max_decimals_in_text(qn)

    messages = [{"role": "system", "content": SYS}, {"role": "user", "content": qn}]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    with torch.no_grad():
        out = model.diffusion_generate(              # Important: diffusion_generate, not generate
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=24,                       # shorter output: more stable and cheaper
            steps=32,                                # 32/64/128; more steps are more stable
            temperature=0.0,
            top_p=None,
            return_dict_in_generate=True
        )
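    # Decode only the newly generated tokens: slice off the prompt portion of the sequence.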
    raw = tokenizer.decode(out.sequences[0][input_ids.shape[1]:].tolist(), skip_special_tokens=True).strip()
    raw = collapse_digit_seps(raw)
    num = extract_number(raw)
    if num is None:
        return raw

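    # Enforce the question's decimal precision on plain (non-scientific) outputs:
    # e.g. with want_dec=2, "330.7" -> "330.70" and "330.765" -> "330.76" (truncate, never round).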
    if 'e' not in num.lower():
        if want_dec == 0:
            try:
                num = str(int(float(num)))
            except Exception:
                pass
        else:
            sgn = '-' if num.startswith('-') else ''
            body = num.lstrip('+-')
            if '.' in body:
                intp, frac = body.split('.', 1)
                frac = (frac + '0'*want_dec)[:want_dec]
            else:
                intp, frac = body, '0'*want_dec
            num = sgn + intp + '.' + frac  # want_dec > 0 is guaranteed in this branch
    return num
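
# Illustrative sanity check (actual output depends on the model):
#   predict("Compute: 330.76 + 0.00") should return "330.76" (two decimals preserved).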

with gr.Blocks() as demo:
    gr.Markdown(f"# Dream S1K Demo  \n**Model:** `{MODEL_ID}` • **Version:** `{VER}`")
    inp = gr.Textbox(label="Question", lines=4, placeholder="e.g. Compute: 330.76 + 0.00")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Answer")
    btn.click(fn=predict, inputs=inp, outputs=out)

if __name__ == "__main__":
    demo.launch()