"""Minimal Gradio demo for the Dream diffusion language model.

Loads the tokenizer and model once at import time (with Hugging Face caches
redirected under /data so downloads persist across container restarts) and
exposes two UI actions: a forward-pass self-check and a short generation.
"""
import os

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

# Persist Hugging Face caches under /data so model downloads survive restarts.
BASE = "/data"
os.makedirs(BASE, exist_ok=True)
os.environ.setdefault("HF_HOME", f"{BASE}/hf_home")
os.environ.setdefault("HF_HUB_CACHE", f"{BASE}/hf_home/hub")
# NOTE: TRANSFORMERS_CACHE is deprecated in recent transformers releases but is
# kept for backward compatibility; HF_HOME above is the modern setting.
os.environ.setdefault("TRANSFORMERS_CACHE", f"{BASE}/hf_home/transformers")
os.environ.setdefault("XDG_CACHE_HOME", f"{BASE}/hf_home")

MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
REV = os.getenv("REV")  # optional git revision/tag; None means "latest"

print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

# bfloat16 on GPU; float32 on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
print("[INFO] Loading model...")
model = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
).to(device).eval()


def quick_infer(q: str) -> str:
    """Run one short diffusion generation for prompt *q* and return the text.

    Returns "" for a blank prompt. Uses greedy decoding (temperature=0.0)
    with at most 64 new tokens over 64 diffusion steps.
    """
    if not q.strip():
        return ""
    messages = [{"role": "user", "content": q}]
    enc = tok.apply_chat_template(
        messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
    )
    input_ids = enc["input_ids"].to(device)
    # Key detail: Dream's diffusion_generate expects a *boolean* attention mask.
    attention_mask = enc["attention_mask"].to(device).bool()
    with torch.no_grad():
        out = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=64,
            steps=64,
            temperature=0.0,
            return_dict_in_generate=True,
        )
    # Decode only the newly generated suffix, skipping the prompt tokens.
    return tok.decode(
        out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True
    ).strip()


def self_check() -> str:
    """Smoke-test a plain forward() pass; return an "OK: ..." or "ERR: ..." string."""
    try:
        msgs = [
            {"role": "system", "content": "只输出一个数字"},
            {"role": "user", "content": "Compute: 1+1"},
        ]
        enc = tok.apply_chat_template(
            msgs, return_tensors="pt", return_dict=True, add_generation_prompt=False
        )
        _ = model(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device).bool(),
        )
        return "OK: forward() 可用(Dream 未必提供 labels->loss,属正常)"
    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the app.
        return f"ERR: {repr(e)}"


with gr.Blocks() as demo:
    gr.Markdown("## Dream Minimal App \n- 先点 Self-check \n- 再试一次推理")
    with gr.Row():
        btn = gr.Button("Self-check")
        out = gr.Textbox(label="Result")
    btn.click(fn=self_check, inputs=None, outputs=out)
    with gr.Row():
        q = gr.Textbox(label="Prompt", value="Compute: 1+1")
        a = gr.Textbox(label="Output")
    go = gr.Button("Generate")
    go.click(fn=quick_infer, inputs=q, outputs=a)

if __name__ == "__main__":
    demo.launch()