import re, gradio as gr, torch
from transformers import AutoModel, AutoTokenizer  # Key: AutoModel (not AutoModelForCausalLM)
# Use the fine-tuned Qwen model directly
MODEL_ID = "kudui/qwen-s1k-overfit10"
REV = None
print(f"[INFO] Using MODEL_ID={MODEL_ID}")
VER = "v4-dream"
# System prompt (left in Chinese to preserve the app's behavior). Translation:
# "Output only the final answer, which must be a single number; no explanations
# or units. Never append an extra trailing digit. If the question contains
# decimals, use the longest decimal precision found in the question
# (truncate / zero-pad, do not round)."
SYS = (
    "只输出最终答案且必须是一个数字,不要任何解释或单位。"
    "禁止在末尾多打一位。若题目中出现小数,请按题面中最长的小数位数输出(截断/补零,不四舍五入)。"
)
def collapse_digit_seps(s: str) -> str:
    # Normalize fullwidth punctuation, drop thousands separators / thin spaces
    # between digits, and collapse remaining whitespace.
    s = s.replace(",", ",").replace("。", ".").replace(":", ":").replace("(", "(").replace(")", ")")
    s = re.sub(r'(?<=\d)[,\s\u00A0\u202F\u2009\u2007\u2060]+(?=\d)', '', s)
    return re.sub(r'\s+', ' ', s).strip()
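# Illustrative self-check (not in the original app): the fullwidth "。" becomes
# "." and the ", " between digits is dropped.
assert collapse_digit_seps("1, 234。5") == "1234.5"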
def max_decimals_in_text(s: str) -> int:
    # Longest fractional part among all decimal literals in the text (0 if none).
    decs = [len(m) for m in re.findall(r'\d+\.(\d+)', s)]
    return max(decs) if decs else 0
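# Illustrative self-check (not in the original app): "330.76" has two decimal
# places and "0.1" has one, so the max is 2.
assert max_decimals_in_text("Compute: 330.76 + 0.1") == 2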
def extract_number(s: str):
    # First integer / float / scientific-notation literal in the string, or None.
    m = re.search(r'[-+]?\d*\.?\d+(?:e[-+]?\d+)?', s, re.I)
    return m.group(0) if m else None
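# Illustrative self-checks (not in the original app):
assert extract_number("The answer is -3.5e2 meters") == "-3.5e2"
assert extract_number("no digits here") is None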
print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
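# trust_remote_code executes custom modeling code from the model repo; it is
# needed here because the Dream diffusion architecture is not built into
# transformers.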
print(f"Loading Dream model from {MODEL_ID}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    revision=REV,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device).eval()
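# Note: bfloat16 assumes an Ampere-or-newer GPU; on older CUDA hardware,
# torch.float16 may be the safer choice (a hardware assumption, not something
# the original app checks for).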
def predict(q: str):
    if not q.strip():
        return ""
    qn = collapse_digit_seps(q)
    want_dec = max_decimals_in_text(qn)
    messages = [{"role": "system", "content": SYS}, {"role": "user", "content": qn}]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    with torch.no_grad():
        out = model.diffusion_generate(  # Key: diffusion_generate (not generate)
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=24,  # short outputs are more stable and cheaper
            steps=32,           # 32/64/128; more steps, more stable
            temperature=0.0,
            top_p=None,
            return_dict_in_generate=True
        )
    raw = tokenizer.decode(out.sequences[0][input_ids.shape[1]:].tolist(), skip_special_tokens=True).strip()
    raw = collapse_digit_seps(raw)
    num = extract_number(raw)
    if num is None:
        return raw
    if 'e' not in num.lower():
        if want_dec == 0:
            # No decimals in the question: return a plain integer when possible.
            try:
                num = str(int(float(num)))
            except Exception:
                pass
        else:
            # Match the question's decimal precision: truncate / zero-pad, never round.
            sgn = '-' if num.startswith('-') else ''
            body = num.lstrip('+-')
            if '.' in body:
                intp, frac = body.split('.', 1)
                frac = (frac + '0' * want_dec)[:want_dec]
            else:
                intp, frac = body, '0' * want_dec
            num = sgn + intp + '.' + frac
    return num
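# Example call (mirrors the UI placeholder below; actual output depends on the
# model): predict("Compute: 330.76 + 0.00") should return "330.76", formatted
# to the two decimal places that appear in the question.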
with gr.Blocks() as demo:
    gr.Markdown(f"# Dream S1K Demo\n**Model:** `{MODEL_ID}` • **Version:** `{VER}`")
    inp = gr.Textbox(label="Question", lines=4, placeholder="e.g. Compute: 330.76 + 0.00")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Answer")
    btn.click(fn=predict, inputs=inp, outputs=out)
if __name__ == "__main__":
    demo.launch()