# dream-s1k-demo / app_min.py
# 况兑
# stabilize: cache to /data + bool attn_mask + minimal app
# b4c6154
import os, torch, gradio as gr
from transformers import AutoModel, AutoTokenizer
# Persist the Hugging Face caches under /data.
# Each variable is only given a value when the environment has not already
# defined it (setdefault), so externally supplied settings win.
# NOTE(review): this runs after `transformers` has been imported above;
# libraries that read these variables at import time will not see the
# values set here -- confirm the ordering is intended.
BASE = "/data"
os.makedirs(BASE, exist_ok=True)
_hf_root = f"{BASE}/hf_home"
for _env_name, _env_value in (
    ("HF_HOME", _hf_root),
    ("HF_HUB_CACHE", f"{_hf_root}/hub"),
    ("TRANSFORMERS_CACHE", f"{_hf_root}/transformers"),
    ("XDG_CACHE_HOME", _hf_root),
):
    os.environ.setdefault(_env_name, _env_value)
# Model selection, overridable from the environment.
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
# An empty REV is treated the same as an unset one, so that an empty string
# is never forwarded as the `revision` argument.
REV = os.getenv("REV") or None
print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

# The cache-related environment variables are assigned after `transformers`
# was imported, and huggingface_hub resolves its cache locations at import
# time. Passing the directory explicitly makes sure the downloads really
# land in the persistent cache.
CACHE_DIR = os.environ.get("HF_HUB_CACHE", f"{BASE}/hf_home/hub")

# bfloat16 on GPU, float32 on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    revision=REV,
    cache_dir=CACHE_DIR,
)

print("[INFO] Loading model...")
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=dtype,
    revision=REV,
    cache_dir=CACHE_DIR,
).to(device).eval()
def quick_infer(
    q: str,
    max_new_tokens: int = 64,
    steps: int = 64,
    temperature: float = 0.0,
):
    """Generate a reply to a single user prompt with the loaded model.

    Args:
        q: The user prompt. ``None`` or a blank string yields ``""`` without
            calling the model.
        max_new_tokens: Passed through to ``model.diffusion_generate``.
        steps: Passed through to ``model.diffusion_generate``.
        temperature: Passed through to ``model.diffusion_generate``.

    Returns:
        The decoded text generated after the prompt, with special tokens
        removed and surrounding whitespace stripped.
    """
    if not q or not q.strip():
        return ""

    messages = [{"role": "user", "content": q}]
    enc = tok.apply_chat_template(
        messages,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    input_ids = enc.input_ids.to(device)
    # Key step: cast the attention mask to bool before handing it to the model.
    attention_mask = enc.attention_mask.to(device).bool()

    with torch.no_grad():
        out = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            steps=steps,
            temperature=temperature,
            return_dict_in_generate=True,
        )

    # Keep only the tokens produced after the prompt.
    new_tokens = out.sequences[0][input_ids.shape[1]:]

    # Cut at the first end-of-sequence token: skip_special_tokens only drops
    # the EOS marker itself and would keep anything generated after it.
    eos_id = tok.eos_token_id
    if eos_id is not None:
        eos_positions = (new_tokens == eos_id).nonzero()
        if eos_positions.numel() > 0:
            new_tokens = new_tokens[: eos_positions[0, 0].item()]

    text = tok.decode(new_tokens, skip_special_tokens=True).strip()
    return text
def self_check():
    """Run one forward pass through the model and report whether it worked.

    Returns:
        A status string starting with ``"OK:"`` when the forward call
        completes, or ``"ERR: <repr of the exception>"`` when anything in the
        check raises. Exceptions are reported as text rather than propagated
        because the result is displayed in the UI.
    """
    try:
        msgs = [{"role":"system","content":"只输出一个数字"},{"role":"user","content":"Compute: 1+1"}]
        enc = tok.apply_chat_template(
            msgs,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=False,
        )
        # This is only a smoke test, so gradients are not needed; disabling
        # autograd avoids retaining activations for a backward pass and
        # matches how quick_infer calls the model.
        with torch.no_grad():
            _ = model(
                input_ids=enc["input_ids"].to(device),
                attention_mask=enc["attention_mask"].to(device).bool(),
            )
        return "OK: forward() 可用(Dream 未必提供 labels->loss,属正常)"
    except Exception as e:
        return f"ERR: {repr(e)}"
# Minimal Gradio UI: a self-check button with its result box, and a prompt
# box with an output box plus a generate button.
# NOTE(review): the original indentation was lost; the nesting of components
# inside the rows below is reconstructed -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Dream Minimal App \n- 先点 Self-check \n- 再试一次推理")

    # Self-check section.
    with gr.Row():
        check_button = gr.Button("Self-check")
        check_result = gr.Textbox(label="Result")
    check_button.click(fn=self_check, inputs=None, outputs=check_result)

    # Generation section.
    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt", value="Compute: 1+1")
        answer_box = gr.Textbox(label="Output")
    generate_button = gr.Button("Generate")
    generate_button.click(fn=quick_infer, inputs=prompt_box, outputs=answer_box)

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()