# Spaces: Paused (status banner captured from the Hugging Face Spaces UI at paste time)
import os

# Persist all Hugging Face caches under /data (the Spaces persistent volume).
# These environment variables MUST be set *before* importing transformers /
# huggingface_hub: both libraries read them at import time, so setting them
# after the import (as the original code did) silently has no effect and the
# model is re-downloaded on every restart.
BASE = "/data"
os.makedirs(BASE, exist_ok=True)
os.environ.setdefault("HF_HOME", f"{BASE}/hf_home")
os.environ.setdefault("HF_HUB_CACHE", f"{BASE}/hf_home/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", f"{BASE}/hf_home/transformers")
os.environ.setdefault("XDG_CACHE_HOME", f"{BASE}/hf_home")

import torch
import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Model selection is overridable via env vars for quick experiments;
# REV pins a specific revision/commit hash (None -> latest).
MODEL_ID = os.getenv("MODEL_ID", "Dream-org/Dream-v0-Instruct-7B")
REV = os.getenv("REV", None)
print(f"[INFO] Using MODEL_ID={MODEL_ID} REV={REV or '(latest)'}")

# bf16 on GPU, fp32 fallback on CPU.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

print("[INFO] Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, revision=REV)
print("[INFO] Loading model...")
model = AutoModel.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=dtype, revision=REV
).to(device).eval()
def quick_infer(q: str):
    """Run one greedy diffusion-generation pass for prompt *q*.

    Returns the decoded completion text, or "" for a blank/whitespace prompt.
    """
    if not q.strip():
        return ""
    chat = [{"role": "user", "content": q}]
    encoded = tok.apply_chat_template(
        chat, return_tensors="pt", return_dict=True, add_generation_prompt=True
    )
    ids = encoded.input_ids.to(device)
    # Dream's diffusion_generate expects a *boolean* attention mask.
    mask = encoded.attention_mask.to(device).bool()
    with torch.no_grad():
        result = model.diffusion_generate(
            ids,
            attention_mask=mask,
            max_new_tokens=64,
            steps=64,
            temperature=0.0,
            return_dict_in_generate=True,
        )
    # Strip the prompt tokens, keep only the newly generated tail.
    prompt_len = ids.shape[1]
    completion = result.sequences[0][prompt_len:]
    return tok.decode(completion, skip_special_tokens=True).strip()
def self_check():
    """Smoke-test a single forward pass through the loaded model.

    Never raises: returns an "OK: ..." status string on success, or an
    "ERR: ..." string describing the exception so the Gradio UI can show it.
    """
    conversation = [
        {"role": "system", "content": "只输出一个数字"},
        {"role": "user", "content": "Compute: 1+1"},
    ]
    try:
        batch = tok.apply_chat_template(
            conversation,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=False,
        )
        model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device).bool(),
        )
    except Exception as e:  # UI boundary: surface the failure as text, don't crash
        return f"ERR: {repr(e)}"
    return "OK: forward() 可用(Dream 未必提供 labels->loss,属正常)"
# Minimal two-row Gradio UI: a self-check button on top, a one-shot
# prompt/response pair below. `demo` is the module-level app object that
# Spaces (and the __main__ guard) launches.
with gr.Blocks() as demo:
    gr.Markdown("## Dream Minimal App \n- 先点 Self-check \n- 再试一次推理")
    with gr.Row():
        check_button = gr.Button("Self-check")
        check_result = gr.Textbox(label="Result")
        check_button.click(fn=self_check, inputs=None, outputs=check_result)
    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt", value="Compute: 1+1")
        answer_box = gr.Textbox(label="Output")
        generate_button = gr.Button("Generate")
        generate_button.click(fn=quick_infer, inputs=prompt_box, outputs=answer_box)

if __name__ == "__main__":
    demo.launch()