Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """dhara-chat β CPU demo of the dhara-250M tri-mode model. | |
| "Denoiser" terminal aesthetic. Chat lets you pick the decoding mode (AR types | |
| left-to-right; diffusion visibly unmasks βββ blocks). The tri-mode compare tab | |
| streams all three modes so you can watch AR type, block-diffusion denoise, and | |
| self-speculation jump in accepted spans β each with tokens/sec. | |
| """ | |
| import os, re, time, threading | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
# --- model setup (runs once, at import time) ---------------------------------
# Checkpoint id (override with the DHARA_MODEL env var) and optional Hub token.
MODEL_ID = os.environ.get("DHARA_MODEL", "codelion/dhara-250m")
TOKEN = os.environ.get("HF_TOKEN")
# trust_remote_code: the checkpoint ships its own modelling code; presumably that code is what
# accepts the `trimode_bias` kwarg used by the diffusion decoders below (TODO confirm).
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.float32, token=TOKEN).eval()
# INT8 dynamic quant of Linear layers -> ~2x faster on CPU; quality preserved (verified across all 3 modes)
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8).eval()
# End-of-turn token id; passed to generate() as both eos_token_id and pad_token_id.
IM_END = tok.convert_tokens_to_ids("<|im_end|>")
# Placeholder id that the diffusion / self-spec decoders append and then fill in.
MASK = int(model.config.mask_token_id)
# generate() kwargs: sampled decoding for the chat tab's AR mode ...
GEN = dict(do_sample=True, temperature=0.7, top_p=0.9, repetition_penalty=1.2, no_repeat_ngram_size=3)
# ... and greedy decoding for the AR pane of the tri-mode comparison.
GEN_GREEDY = dict(do_sample=False, repetition_penalty=1.3, no_repeat_ngram_size=3)
REP = 1.3 # repetition penalty for diffusion/self-spec unmasking (prevents "capital capital" collapse)
| def _msg_text(c): | |
| # gradio 6 Chatbot stores content as a list of parts; flatten to plain text for the template | |
| if isinstance(c, list): | |
| return "".join((p.get("text") or "") if isinstance(p, dict) else str(p) for p in c) | |
| return c if isinstance(c, str) else str(c) | |
def _enc(messages, max_tok=1024):
    """Template and tokenize a chat history, keeping the prompt within ``max_tok`` tokens.

    Oldest messages are dropped (sliding window) until the templated prompt fits.
    After each drop, leading assistant messages are skipped as well, so the window
    never opens on a reply whose question was cut off. The most recent message is
    always kept, even if it alone exceeds the budget.

    Args:
        messages: list of dicts with "role" (default "user") and "content" keys.
        max_tok: prompt-token budget.

    Returns:
        (input_ids, attention_mask) from the tokenizer with return_tensors="pt"
        (a single sequence, so the length is ``input_ids.shape[1]``).
    """
    msgs = [{"role": m.get("role", "user"), "content": _msg_text(m.get("content", ""))} for m in messages]
    while True:
        p = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        e = tok(p, return_tensors="pt", add_special_tokens=False)
        if e.input_ids.shape[1] <= max_tok or len(msgs) <= 1:
            return e.input_ids, e.attention_mask
        msgs = msgs[1:]
        # don't start the window on an orphaned assistant reply
        while len(msgs) > 1 and msgs[0]["role"] == "assistant":
            msgs = msgs[1:]
def _ntok(text):
    """Return how many token ids ``text`` encodes to, without tokenizer-added special tokens."""
    encoding = tok(text, add_special_tokens=False)
    return len(encoding.input_ids)
| def _block_mask(S, bl, dev, dt): | |
| idx = torch.arange(S, device=dev) | |
| allowed = (idx // bl).unsqueeze(0) <= (idx // bl).unsqueeze(1) | |
| return torch.zeros((S, S), device=dev, dtype=dt).masked_fill(~allowed, float("-inf"))[None, None] | |
def _rep_pen(logits, seen, penalty=REP):
    """Penalize every token id that already appears in ``seen``.

    For each seen id, a positive logit is divided by ``penalty`` and a non-positive one
    is multiplied by it, so (for penalty > 1) seen tokens always become less likely.

    Args:
        logits: 2-D tensor, one row per position being decoded, one column per vocab id.
        seen: 1-D tensor of token ids already in the sequence (duplicates are fine).
        penalty: penalty strength; 1.0 disables it.

    Returns:
        A penalized copy of ``logits``, or ``logits`` itself when there is nothing to do.
        The input is never modified. Callers pass views of the model output
        (e.g. ``dl[n:n + k]``), and in-place edits would silently alter that output.
    """
    if penalty == 1.0 or seen.numel() == 0:
        return logits
    ids = torch.unique(seen)
    out = logits.clone()
    cols = out[:, ids]
    out[:, ids] = torch.where(cols > 0, cols / penalty, cols * penalty)
    return out
| def _clip(text): | |
| """Trim a trailing incomplete sentence so responses never end mid-word.""" | |
| t = text.rstrip() | |
| ends = [m.start() for m in re.finditer(r"[.!?](?=\s|$)", t)] | |
| if ends and ends[-1] >= 16: | |
| t = t[:ends[-1] + 1] | |
| return re.sub(r"\s+\d+\.$", "", t).rstrip() | |
def _render(row, prompt_len):
    """Render the generated part of a token row for display.

    Ids after ``prompt_len`` are shown up to (not including) the first IM_END. Each
    MASK id is drawn as the mask glyph. Consecutive non-mask ids are decoded together
    rather than one id at a time. A character split across several tokens (possible
    with byte-level BPE vocabularies) therefore decodes correctly instead of turning
    into replacement characters.

    Args:
        row: list of token ids (prompt followed by the generated region).
        prompt_len: number of leading prompt ids to skip.

    Returns:
        The display string.
    """
    pieces, run = [], []

    def _flush():
        # decode the pending run of real tokens as one unit
        if run:
            pieces.append(tok.decode(run))
            run.clear()

    for t in row[prompt_len:]:
        if t == MASK:
            _flush()
            pieces.append("β")
        elif t == IM_END:
            break
        else:
            run.append(t)
    _flush()
    return "".join(pieces)
def _ar_stream(ids, am, max_new, gen=GEN):
    """Autoregressive decoding with ``model.generate``, streamed as it is produced.

    Generation runs on a background daemon thread that feeds a TextIteratorStreamer.
    This generator yields the accumulated reply text after every decoded chunk.

    Args:
        ids: prompt input_ids.
        am: prompt attention_mask.
        max_new: max_new_tokens for generate().
        gen: extra generate() kwargs (sampling settings); defaults to GEN.

    Raises:
        Whatever ``model.generate`` raised, re-raised here once the stream is closed.
        Previously a failing generate() never ended the streamer, so this loop
        blocked forever.
    """
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    failure = []

    def _generate():
        try:
            model.generate(input_ids=ids, attention_mask=am, max_new_tokens=max_new,
                           eos_token_id=IM_END, pad_token_id=IM_END, streamer=streamer, **gen)
        except Exception as exc:  # thread boundary: hand the error to the consumer
            failure.append(exc)
            streamer.end()  # push the stop signal so the loop below terminates

    threading.Thread(target=_generate, daemon=True).start()
    out = ""
    for piece in streamer:
        out += piece
        yield out
    if failure:
        raise failure[0]
def _diffusion_stream(ids, block_len=32, threshold=0.5, max_new=64):
    """Block-diffusion decoding: append a block of MASK ids and iteratively unmask it.

    Each outer step appends ``min(block_len, remaining budget)`` MASK ids and runs up
    to that many refinement passes under the block-causal bias. Each pass commits
    every masked position whose top probability (after the repetition penalty) is at
    least ``threshold``. If none qualifies, the single most confident position is
    committed, so every pass makes progress. Decoding stops after ``max_new`` tokens
    or once a finished block contains IM_END.

    Args:
        ids: prompt token ids, shape (1, prompt_len).
        block_len: masked tokens per block; also the block size given to _block_mask.
        threshold: confidence needed to commit a token within a pass.
        max_new: upper bound on generated tokens. The last block is shortened so it
            is never exceeded; previously 48 requested tokens produced 64.

    Yields:
        The rendered continuation after every pass (masked slots shown as the glyph).
    """
    dev = ids.device
    dt = next(model.parameters()).dtype
    cur, gen = ids, 0
    while gen < max_new:
        blen = min(block_len, max_new - gen)
        seq = torch.cat([cur, torch.full((1, blen), MASK, device=dev)], 1)
        bias = _block_mask(seq.shape[1], block_len, dev, dt)
        for _ in range(blen):
            mp = (seq[0] == MASK).nonzero(as_tuple=True)[0]
            if mp.numel() == 0:
                break
            with torch.no_grad():  # inference only: don't build an autograd graph
                lg = model(input_ids=seq, trimode_bias=bias).logits[0].float()
            lgm = _rep_pen(lg[mp], seq[0][seq[0] != MASK])
            conf, pred = F.softmax(lgm, -1).max(-1)
            take = conf >= threshold
            if not take.any():
                take[conf.argmax()] = True
            seq[0, mp[take]] = pred[take]
            yield _render(seq[0].tolist(), ids.shape[1])
        cur = seq
        gen += blen
        if (cur[0, -blen:] == IM_END).any():
            break
def _selfspec_stream(ids, k=8, block_len=32, max_new=48):
    """Self-speculative decoding: draft k tokens with one masked pass, verify them with a second pass.

    Per round:
      1. Append k MASK ids and take the penalized argmax of a ``trimode_bias`` forward
         as the draft.
      2. Run a forward over prompt+draft without ``trimode_bias`` (presumably the
         model's causal/AR mode - TODO confirm) and take its penalized argmax at the
         k+1 positions that predict draft[0..k-1] and the token after the draft.
      3. Keep the longest draft prefix matching those predictions and append one
         verified token after it. That token is the correction at the first mismatch,
         or a bonus token when the whole draft is accepted. The bonus token now gets
         the same repetition penalty as the rest.

    The accepted span is cut after the first IM_END, so drafted text past the end of
    turn no longer leaks into the reply. It is also capped so at most ``max_new``
    tokens are produced.

    Args:
        ids: prompt token ids, shape (1, prompt_len).
        k: draft length per round.
        block_len: block size for the draft pass's block-causal bias.
        max_new: upper bound on generated tokens.

    Yields:
        The decoded reply so far (special tokens skipped) after every round.
    """
    dev = ids.device
    dt = next(model.parameters()).dtype
    cur, gen = ids, 0
    while gen < max_new:
        n = cur.shape[1]
        seq = torch.cat([cur, torch.full((1, k), MASK, device=dev)], 1)
        seen = cur[0]
        with torch.no_grad():  # inference only
            dl = model(input_ids=seq, trimode_bias=_block_mask(seq.shape[1], block_len, dev, dt)).logits[0].float()
        draft = _rep_pen(dl[n:n + k], seen).argmax(-1)
        cand = torch.cat([cur, draft.unsqueeze(0)], 1)
        with torch.no_grad():
            al = model(input_ids=cand).logits[0].float()
        # rows n-1 .. n+k-2 verify the draft; row n+k-1 is the bonus prediction
        ar_pred = _rep_pen(al[n - 1:n + k], seen).argmax(-1)
        mismatch = (draft != ar_pred[:k]).nonzero(as_tuple=True)[0]
        m = int(mismatch[0]) if mismatch.numel() else k
        new = torch.cat([draft[:m], ar_pred[m:m + 1]])
        ends = (new == IM_END).nonzero(as_tuple=True)[0]
        if ends.numel():
            new = new[:int(ends[0]) + 1]
        new = new[:max_new - gen]
        cur = torch.cat([cur, new.unsqueeze(0)], 1)
        gen += new.numel()
        yield tok.decode(cur[0, ids.shape[1]:], skip_special_tokens=True)
        if (new == IM_END).any():
            break
| def _pane(label, speed, text): | |
| sp = f" <span class='dh-s'>{speed}</span>" if speed else "" | |
| return f"<div class='dh-pane'><span class='dh-k'>{label}</span>{sp}<div class='dh-t'>{text}</div></div>" | |
| def _pending(label): | |
| return f"<div class='dh-pane dh-run'><span class='dh-k'>{label}</span> <span class='dh-mask'>βββ waitingβ¦</span></div>" | |
def compare(prompt):
    """Stream one prompt through AR, block-diffusion and self-spec decoding in turn.

    Yields (AR, BLOCK-DIFFUSION, SELF-SPEC) HTML pane triples for the three gr.HTML
    outputs. A mode that has not started shows a pending pane, and the running mode
    shows a running marker. A finished mode shows its text clipped to the last full
    sentence plus a tokens/sec figure (tokens of the clipped text / elapsed seconds).
    Each mode gets a 48-token budget.
    """
    def _rate(text, start):
        # perf_counter is the clock meant for measuring durations; time.time() is
        # wall-clock and can jump. 1e-9 guards against a zero interval.
        return f"{_ntok(text) / max(1e-9, time.perf_counter() - start):.1f} tok/s"

    ids, am = _enc([{"role": "user", "content": prompt}])
    a = d = s = ""  # keep the names bound even if a stream yields nothing

    t0 = time.perf_counter()
    for a in _ar_stream(ids, am, 48, GEN_GREEDY):
        yield _pane("AR", "β³", a), _pending("BLOCK-DIFFUSION"), _pending("SELF-SPEC")
    a = _clip(a)
    a_s = _rate(a, t0)

    t0 = time.perf_counter()
    for d in _diffusion_stream(ids, max_new=48):
        yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", "β³", d), _pending("SELF-SPEC")
    d = _clip(d)
    d_s = _rate(d, t0)

    t0 = time.perf_counter()
    for s in _selfspec_stream(ids, max_new=48):
        yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", d_s, d), _pane("SELF-SPEC", "β³", s)
    s = _clip(s)
    s_s = _rate(s, t0) + " Β· AR-quality"
    yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", d_s, d), _pane("SELF-SPEC", s_s, s)
# Stylesheet passed to launch(css=...): dark "terminal" look in IBM Plex Mono with amber
# accents. .dhara-* classes style HEADER; .dh-pane/.dh-k/.dh-s/.dh-t/.dh-mask style the panes
# built by _pane/_pending; .dh-chat and .dh-ctl style the chat box and control row.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500;600&display=swap');
body, .gradio-container, gradio-app { background:#0b0b0c !important; color:#e9e2d4 !important;
font-family:'IBM Plex Mono', ui-monospace, monospace !important; }
.gradio-container { max-width:920px !important; margin:0 auto !important; }
.dhara-hero { text-align:center; padding:30px 12px 14px; border-bottom:1px solid #26261b; }
.dhara-glyphs { color:#ffb000; opacity:.3; letter-spacing:.55em; font-size:13px; animation:dh-diffuse 3.4s ease-in-out infinite; }
.dhara-title { font-size:56px; font-weight:600; letter-spacing:.2em; color:#ffb000; text-shadow:0 0 22px rgba(255,176,0,.4); margin:4px 0 2px; }
.dhara-tag { color:#8f8a78; font-size:12.5px; letter-spacing:.14em; text-transform:uppercase; }
@keyframes dh-diffuse { 0%,100%{opacity:.14; filter:blur(.7px)} 50%{opacity:.46; filter:blur(0)} }
button.primary, button[variant="primary"], .primary { background:#1c1609 !important; color:#ffb000 !important;
border:1px solid #ffb000 !important; box-shadow:none !important; text-transform:uppercase; letter-spacing:.08em; }
.dh-pane { border:1px solid #26261b; border-left:2px solid #ffb000; border-radius:6px; padding:12px 15px; background:#0f0f10; margin-bottom:10px; }
.dh-k { color:#ffb000; font-weight:600; letter-spacing:.1em; }
.dh-s { color:#8f8a78; font-size:12px; }
.dh-t { white-space:pre-wrap; margin-top:8px; color:#e9e2d4; }
.dh-mask { color:#7a7568; animation:dh-diffuse 1.1s ease-in-out infinite; }
.dh-chat { height:400px !important; max-height:400px !important; overflow-y:auto !important; }
.dh-ctl { gap:14px; align-items:flex-end; margin-bottom:6px; }
.dhara-link { margin-top:8px; font-size:12px; }
.dhara-link a { color:#ffb000; text-decoration:none; border-bottom:1px dotted #ffb000; }
footer { display:none !important; }
"""
# Hero banner rendered at the top of the page via gr.HTML(HEADER); styled by the .dhara-* rules in CSS.
HEADER = """
<div class="dhara-hero">
<div class="dhara-glyphs">βββ βββ βββ βββ βββ βββ</div>
<div class="dhara-title">dhara</div>
<div class="dhara-tag">tri-mode Β· 250M Β· ar / block-diffusion / self-speculation</div>
<div class="dhara-link">model: <a href="https://huggingface.co/codelion/dhara-250m" target="_blank">codelion/dhara-250m β</a></div>
</div>
"""
with gr.Blocks(title="dhara-chat") as demo:
    gr.HTML(HEADER)

    with gr.Tab("chat"):
        with gr.Row(elem_classes="dh-ctl"):
            mode = gr.Radio(["AR", "diffusion", "self-spec"], value="AR", label="mode (diffusion shows βββ denoising; self-spec = AR-quality)", scale=3)
            max_new = gr.Number(value=80, label="max tokens", precision=0, minimum=16, maximum=256, scale=1)
            temp = gr.Number(value=0.7, label="temp (AR)", minimum=0.1, maximum=1.5, scale=1)
        chatbot = gr.Chatbot(height=400, show_label=False, elem_classes="dh-chat")
        with gr.Row():
            msg = gr.Textbox(placeholder="Message dharaβ¦", show_label=False, container=False, scale=7)
            send = gr.Button("send", variant="primary", scale=1)
            clear = gr.Button("new chat", scale=1)
        # Clicking an example fills the textbox and, via fn, resets the chat history to [].
        gr.Examples(["Write a short paragraph about the ocean.",
                     "Explain what a neural network is in simple terms.",
                     "What is exercise good for?",
                     "Why is sleep important?"],
                    inputs=msg, outputs=chatbot, fn=lambda _q: [], run_on_click=True,
                    cache_examples=False, label="try a prompt (starts a new chat)")

        def _user(message, history):
            """Append the typed message as a user turn and clear the textbox.

            Blank or whitespace-only input is ignored (textbox and history unchanged), so
            the chained _bot call finds no new user turn and generates nothing.
            """
            history = history or []
            if not (message or "").strip():
                return message, history
            return "", history + [{"role": "user", "content": message}]

        def _bot(history, mode, max_new, temperature):
            """Stream the assistant reply to the last user turn in the selected decoding mode.

            Yields the history plus a growing assistant message; the final yield clips the
            reply to its last full sentence. Yields the history unchanged when the last
            turn is not a user turn (e.g. an ignored empty submit).
            """
            history = history or []
            if not history or history[-1].get("role") != "user":
                yield history
                return
            # Defensive: fall back to the UI defaults if a gr.Number field arrives empty (None).
            n_new = int(max_new) if max_new else 80
            temperature = float(temperature) if temperature else 0.7
            ids, am = _enc(history)
            if mode == "diffusion":
                stream = _diffusion_stream(ids, max_new=n_new)
            elif mode == "self-spec":
                stream = _selfspec_stream(ids, max_new=n_new)
            else:
                stream = _ar_stream(ids, am, n_new, dict(GEN, temperature=temperature))
            last = ""
            for last in stream:
                yield history + [{"role": "assistant", "content": last}]
            yield history + [{"role": "assistant", "content": _clip(last)}]

        msg.submit(_user, [msg, chatbot], [msg, chatbot]).then(_bot, [chatbot, mode, max_new, temp], chatbot)
        send.click(_user, [msg, chatbot], [msg, chatbot]).then(_bot, [chatbot, mode, max_new, temp], chatbot)
        clear.click(lambda: ([], ""), None, [chatbot, msg])

    with gr.Tab("tri-mode compare"):
        gr.Markdown("Same prompt, three decoding modes β watch **AR** type, **block-diffusion** denoise `βββ`, and **self-spec** jump in accepted spans. \n<span style='color:#8f8a78'>config: greedy Β· repetition-penalty 1.3 Β· 48 tokens/mode Β· clipped to last full sentence</span>")
        inp = gr.Textbox(label="prompt", value="Write a short paragraph about the ocean.")
        gr.Examples(["Write a short paragraph about the ocean.",
                     "Explain what a neural network is in simple terms.",
                     "What is exercise good for?"],
                    inputs=inp)
        btn = gr.Button("run all 3 modes", variant="primary")
        o1 = gr.HTML()
        o2 = gr.HTML()
        o3 = gr.HTML()
        btn.click(compare, inp, [o1, o2, o3])
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860 with the
    # monospace theme and the custom stylesheet.
    app = demo.queue()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Base(font=gr.themes.GoogleFont("IBM Plex Mono")),
        css=CSS,
    )