dhara-chat / app.py
codelion's picture
Multi-turn UX: new-chat button; example click resets conversation; 1k-token sliding-window history; 4th example (sleep)
a284b48 verified
#!/usr/bin/env python3
"""dhara-chat — CPU demo of the dhara-250M tri-mode model.
"Denoiser" terminal aesthetic. Chat lets you pick the decoding mode (AR types
left-to-right; diffusion visibly unmasks ▓▒░ blocks). The tri-mode compare tab
streams all three modes so you can watch AR type, block-diffusion denoise, and
self-speculation jump in accepted spans — each with tokens/sec.
"""
import html
import os
import re
import threading
import time

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Checkpoint to serve; the DHARA_MODEL environment variable overrides the default.
MODEL_ID = os.environ.get("DHARA_MODEL", "codelion/dhara-250m")
# Optional Hub access token; None when HF_TOKEN is not set.
TOKEN = os.environ.get("HF_TOKEN")

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    token=TOKEN,
).eval()
# INT8 dynamic quant of Linear layers -> ~2x faster on CPU; quality preserved (verified across all 3 modes)
model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8).eval()

# Token ids the decoders test against: end-of-turn marker and the diffusion mask token.
IM_END = tok.convert_tokens_to_ids("<|im_end|>")
MASK = int(model.config.mask_token_id)

# Keyword arguments forwarded to model.generate for sampled and greedy AR decoding.
GEN = {
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "no_repeat_ngram_size": 3,
}
GEN_GREEDY = {
    "do_sample": False,
    "repetition_penalty": 1.3,
    "no_repeat_ngram_size": 3,
}
REP = 1.3  # repetition penalty for diffusion/self-spec unmasking (prevents "capital capital" collapse)
def _msg_text(c):
# gradio 6 Chatbot stores content as a list of parts; flatten to plain text for the template
if isinstance(c, list):
return "".join((p.get("text") or "") if isinstance(p, dict) else str(p) for p in c)
return c if isinstance(c, str) else str(c)
def _enc(messages, max_tok=1024):
    """Render ``messages`` through the chat template and tokenize the prompt.

    Args:
        messages: Sequence of dict-like chat messages; ``role`` defaults to
            ``"user"`` and ``content`` is flattened with ``_msg_text``.
        max_tok: Prompt budget in tokens.

    Returns:
        ``(input_ids, attention_mask)`` as returned by the tokenizer with
        ``return_tensors="pt"``. If a single remaining message still exceeds
        the budget it is returned as-is.
    """
    turns = []
    for m in messages:
        turns.append({"role": m.get("role", "user"), "content": _msg_text(m.get("content", ""))})

    first = 0
    # sliding window: drop the oldest message, one at a time, until the prompt fits the budget
    while True:
        window = turns[first:]
        prompt = tok.apply_chat_template(window, tokenize=False, add_generation_prompt=True)
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False)
        fits = enc.input_ids.shape[1] <= max_tok
        if fits or len(window) <= 1:
            return enc.input_ids, enc.attention_mask
        first += 1
def _ntok(text):
    """Return the number of token ids the tokenizer yields for ``text`` (no special tokens added)."""
    encoded = tok(text, add_special_tokens=False)
    return len(encoded.input_ids)
def _block_mask(S, bl, dev, dt):
idx = torch.arange(S, device=dev)
allowed = (idx // bl).unsqueeze(0) <= (idx // bl).unsqueeze(1)
return torch.zeros((S, S), device=dev, dtype=dt).masked_fill(~allowed, float("-inf"))[None, None]
def _rep_pen(logits, seen, penalty=REP):
    """Apply a repetition penalty to the logits of already-seen token ids.

    Args:
        logits: 2-D tensor indexed as ``[row, token_id]``.
        seen: Tensor of token ids that have already appeared.
        penalty: Penalty factor; ``1.0`` disables the penalty.

    Returns:
        A tensor of penalised logits. For each seen id, positive logits are
        divided by ``penalty`` and the rest are multiplied by it. The input is
        returned unchanged when there is nothing to do; otherwise a new tensor
        is returned and ``logits`` is left untouched, so callers may safely
        pass a slice view of a larger logits tensor.
    """
    if penalty == 1.0 or seen.numel() == 0:
        return logits
    ids = torch.unique(seen)
    # Work on a copy: writing into `logits` would also modify the tensor the
    # caller sliced it from.
    out = logits.clone()
    picked = out[:, ids]
    out[:, ids] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return out
def _clip(text):
"""Trim a trailing incomplete sentence so responses never end mid-word."""
t = text.rstrip()
ends = [m.start() for m in re.finditer(r"[.!?](?=\s|$)", t)]
if ends and ends[-1] >= 16:
t = t[:ends[-1] + 1]
return re.sub(r"\s+\d+\.$", "", t).rstrip()
def _render(row, prompt_len):
    """Render the generated part of a partially denoised sequence as display text.

    Args:
        row: List of token ids for the whole sequence (prompt + generation).
        prompt_len: Number of leading prompt tokens to skip.

    Returns:
        Text for ``row[prompt_len:]`` up to, but excluding, the first
        ``IM_END``. Still-masked positions are shown as ``▒``.
    """
    out = []
    run = []  # consecutive non-mask token ids awaiting a joint decode

    def flush():
        # Decode the run in one call: decoding token-by-token breaks
        # characters that span several tokens and can lose inter-token spacing.
        if run:
            out.append(tok.decode(run, skip_special_tokens=True))
            run.clear()

    for t in row[prompt_len:]:
        if t == IM_END:
            break
        if t == MASK:
            flush()
            out.append("▒")
        else:
            run.append(t)
    flush()
    return "".join(out)
def _ar_stream(ids, am, max_new, gen=GEN):
    """Stream autoregressive generation, yielding the cumulative text so far.

    ``model.generate`` runs in a background thread and feeds a
    ``TextIteratorStreamer``; this generator consumes the streamer.

    Args:
        ids: Prompt token ids.
        am: Attention mask matching ``ids``.
        max_new: Maximum number of new tokens.
        gen: Extra keyword arguments forwarded to ``model.generate``.

    Yields:
        The decoded completion accumulated so far (prompt and special tokens
        skipped).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been drained.
    """
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    failure = []

    def _worker():
        try:
            model.generate(
                input_ids=ids, attention_mask=am, max_new_tokens=max_new, eos_token_id=IM_END,
                pad_token_id=IM_END, streamer=streamer, **gen)
        except Exception as exc:  # thread boundary: hand the error to the consumer
            failure.append(exc)
            # Without this the consumer below would wait on the streamer forever.
            streamer.end()

    # daemon=True so a still-running generation cannot keep the process alive at exit.
    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()

    out = ""
    for piece in streamer:
        out += piece
        yield out

    worker.join()
    if failure:
        raise failure[0]
@torch.no_grad()
def _diffusion_stream(ids, block_len=32, threshold=0.5, max_new=64):
    """Stream block-diffusion decoding, yielding a rendering after every denoising step.

    Each round appends a run of ``MASK`` tokens and repeatedly runs the model
    with a block-causal bias, committing every masked position whose top
    probability reaches ``threshold`` (or the single most confident one when
    none does), until the run is fully unmasked.

    Args:
        ids: Prompt token ids, indexed as ``[0, position]``.
        block_len: Number of masks appended per round and block size of the
            attention bias.
        threshold: Minimum softmax confidence for committing a position.
        max_new: Maximum number of generated tokens. The final round is
            shortened so the total never exceeds this value.

    Yields:
        ``_render`` output for the tokens generated so far.
    """
    dev = ids.device
    dt = next(model.parameters()).dtype
    prompt_len = ids.shape[1]
    cur = ids
    gen = 0
    while gen < max_new:
        # Clamp the last round so max_new is respected even when it is not a
        # multiple of block_len.
        step = min(block_len, max_new - gen)
        seq = torch.cat([cur, torch.full((1, step), MASK, device=dev)], 1)
        bias = _block_mask(seq.shape[1], block_len, dev, dt)
        # At least one position is committed per pass, so `step` passes suffice.
        for _ in range(step):
            mp = (seq[0] == MASK).nonzero(as_tuple=True)[0]
            if mp.numel() == 0:
                break
            lg = model(input_ids=seq, trimode_bias=bias).logits[0].float()
            lgm = _rep_pen(lg[mp], seq[0][seq[0] != MASK])
            conf, pred = F.softmax(lgm, -1).max(-1)
            take = conf >= threshold
            if not take.any():
                take[conf.argmax()] = True  # always make progress
            seq[0, mp[take]] = pred[take]
            yield _render(seq[0].tolist(), prompt_len)
        cur = seq
        gen += step
        if (cur[0, -step:] == IM_END).any():
            break
@torch.no_grad()
def _selfspec_stream(ids, k=8, block_len=32, max_new=48):
    """Stream self-speculative decoding, yielding the decoded text after each round.

    Each round drafts ``k`` tokens in one pass over ``k`` appended ``MASK``
    tokens (with the block-causal bias), then verifies them with one plain
    forward pass. The longest draft prefix that matches the verifier's greedy
    choice is accepted, followed by one verifier token: the correction at the
    first mismatch, or a bonus token when the whole draft matched.

    Args:
        ids: Prompt token ids, indexed as ``[0, position]``.
        k: Draft length per round.
        block_len: Block size of the attention bias used for drafting.
        max_new: Maximum number of generated tokens (never exceeded).

    Yields:
        The decoded completion so far, special tokens skipped.
    """
    dev = ids.device
    dt = next(model.parameters()).dtype
    prompt_len = ids.shape[1]
    cur = ids
    gen = 0
    while gen < max_new:
        n = cur.shape[1]
        seen = cur[0]

        # Draft: predict all k masked positions in a single pass.
        seq = torch.cat([cur, torch.full((1, k), MASK, device=dev)], 1)
        bias = _block_mask(seq.shape[1], block_len, dev, dt)
        dl = model(input_ids=seq, trimode_bias=bias).logits[0].float()
        draft = _rep_pen(dl[n:n + k], seen).argmax(-1)

        # Verify: row i of `verify` is the greedy next token after position
        # n - 1 + i, so rows 0..k-1 check the draft and row k is the bonus
        # token. All k + 1 rows get the same repetition penalty.
        cand = torch.cat([cur, draft.unsqueeze(0)], 1)
        al = model(input_ids=cand).logits[0].float()
        verify = _rep_pen(al[n - 1:n + k], seen).argmax(-1)

        mismatch = (draft != verify[:k]).nonzero(as_tuple=True)[0]
        m = int(mismatch[0]) if mismatch.numel() else k
        new = torch.cat([draft[:m], verify[m:m + 1]])

        # Respect the token budget, and drop anything drafted past end-of-turn
        # so it cannot leak into the decoded text.
        new = new[:max_new - gen]
        stop = (new == IM_END).nonzero(as_tuple=True)[0]
        if stop.numel():
            new = new[:int(stop[0]) + 1]

        cur = torch.cat([cur, new.unsqueeze(0)], 1)
        gen += new.numel()
        yield tok.decode(cur[0, prompt_len:], skip_special_tokens=True)
        if stop.numel():
            break
def _pane(label, speed, text):
sp = f" <span class='dh-s'>{speed}</span>" if speed else ""
return f"<div class='dh-pane'><span class='dh-k'>{label}</span>{sp}<div class='dh-t'>{text}</div></div>"
def _pending(label):
return f"<div class='dh-pane dh-run'><span class='dh-k'>{label}</span> <span class='dh-mask'>β–“β–’β–‘ waiting…</span></div>"
def _tok_rate(text, started):
    """Return a ``"<n> tok/s"`` label for ``text`` produced since ``started`` (a ``perf_counter`` reading)."""
    elapsed = max(1e-9, time.perf_counter() - started)
    return f"{_ntok(text) / elapsed:.1f} tok/s"


def compare(prompt):
    """Run one prompt through AR, block-diffusion and self-spec decoding in turn.

    Args:
        prompt: The user prompt (sent as a single user message).

    Yields:
        A 3-tuple of HTML fragments ``(AR pane, block-diffusion pane,
        self-spec pane)`` after every streaming update. Modes that have not
        started show a pending pane; finished modes show their clipped text
        and throughput.
    """
    ids, am = _enc([{"role": "user", "content": prompt}])
    a = d = s = ""

    # Throughput is measured on the raw output *before* clipping: the timer
    # covers everything that was generated, so counting only the text that
    # survives _clip would under-report the speed.
    t0 = time.perf_counter()
    for a in _ar_stream(ids, am, 48, GEN_GREEDY):
        yield _pane("AR", "⟳", a), _pending("BLOCK-DIFFUSION"), _pending("SELF-SPEC")
    a_s = _tok_rate(a, t0)
    a = _clip(a)

    t0 = time.perf_counter()
    for d in _diffusion_stream(ids, max_new=48):
        yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", "⟳", d), _pending("SELF-SPEC")
    d_s = _tok_rate(d, t0)
    d = _clip(d)

    t0 = time.perf_counter()
    for s in _selfspec_stream(ids, max_new=48):
        yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", d_s, d), _pane("SELF-SPEC", "⟳", s)
    s_s = _tok_rate(s, t0) + " · AR-quality"
    s = _clip(s)

    yield _pane("AR", a_s, a), _pane("BLOCK-DIFFUSION", d_s, d), _pane("SELF-SPEC", s_s, s)
# Stylesheet for the app: dark background with amber accents in IBM Plex Mono.
# The .dhara-* rules style the HEADER banner, the .dh-* rules style the
# compare-tab panes and the chat controls, and the last rule hides the footer.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500;600&display=swap');
body, .gradio-container, gradio-app { background:#0b0b0c !important; color:#e9e2d4 !important;
font-family:'IBM Plex Mono', ui-monospace, monospace !important; }
.gradio-container { max-width:920px !important; margin:0 auto !important; }
.dhara-hero { text-align:center; padding:30px 12px 14px; border-bottom:1px solid #26261b; }
.dhara-glyphs { color:#ffb000; opacity:.3; letter-spacing:.55em; font-size:13px; animation:dh-diffuse 3.4s ease-in-out infinite; }
.dhara-title { font-size:56px; font-weight:600; letter-spacing:.2em; color:#ffb000; text-shadow:0 0 22px rgba(255,176,0,.4); margin:4px 0 2px; }
.dhara-tag { color:#8f8a78; font-size:12.5px; letter-spacing:.14em; text-transform:uppercase; }
@keyframes dh-diffuse { 0%,100%{opacity:.14; filter:blur(.7px)} 50%{opacity:.46; filter:blur(0)} }
button.primary, button[variant="primary"], .primary { background:#1c1609 !important; color:#ffb000 !important;
border:1px solid #ffb000 !important; box-shadow:none !important; text-transform:uppercase; letter-spacing:.08em; }
.dh-pane { border:1px solid #26261b; border-left:2px solid #ffb000; border-radius:6px; padding:12px 15px; background:#0f0f10; margin-bottom:10px; }
.dh-k { color:#ffb000; font-weight:600; letter-spacing:.1em; }
.dh-s { color:#8f8a78; font-size:12px; }
.dh-t { white-space:pre-wrap; margin-top:8px; color:#e9e2d4; }
.dh-mask { color:#7a7568; animation:dh-diffuse 1.1s ease-in-out infinite; }
.dh-chat { height:400px !important; max-height:400px !important; overflow-y:auto !important; }
.dh-ctl { gap:14px; align-items:flex-end; margin-bottom:6px; }
.dhara-link { margin-top:8px; font-size:12px; }
.dhara-link a { color:#ffb000; text-decoration:none; border-bottom:1px dotted #ffb000; }
footer { display:none !important; }
"""
# Banner HTML shown at the top of the page (styled by the .dhara-* CSS rules).
# The block glyphs, middle dots and arrow were mis-encoded (UTF-8 read as a
# single-byte code page) and are restored to the intended characters here.
HEADER = """
<div class="dhara-hero">
<div class="dhara-glyphs">▓▒░ ▒▓░ ░▒▓ ▒░▓ ▓░▒ ░▓▒</div>
<div class="dhara-title">dhara</div>
<div class="dhara-tag">tri-mode · 250M · ar / block-diffusion / self-speculation</div>
<div class="dhara-link">model: <a href="https://huggingface.co/codelion/dhara-250m" target="_blank">codelion/dhara-250m ↗</a></div>
</div>
"""
# Page layout and event wiring: a chat tab with selectable decoding mode and a
# tab that runs the same prompt through all three modes.
with gr.Blocks(title="dhara-chat") as demo:
    gr.HTML(HEADER)

    with gr.Tab("chat"):
        with gr.Row(elem_classes="dh-ctl"):
            mode = gr.Radio(["AR", "diffusion", "self-spec"], value="AR",
                            label="mode (diffusion shows ▓▒░ denoising; self-spec = AR-quality)", scale=3)
            max_new = gr.Number(value=80, label="max tokens", precision=0, minimum=16, maximum=256, scale=1)
            temp = gr.Number(value=0.7, label="temp (AR)", minimum=0.1, maximum=1.5, scale=1)
        chatbot = gr.Chatbot(height=400, show_label=False, elem_classes="dh-chat")
        with gr.Row():
            msg = gr.Textbox(placeholder="Message dhara…", show_label=False, container=False, scale=7)
            send = gr.Button("send", variant="primary", scale=1)
            clear = gr.Button("new chat", scale=1)
        # Clicking an example fills the textbox and empties the chatbot (fn returns []).
        gr.Examples(["Write a short paragraph about the ocean.",
                     "Explain what a neural network is in simple terms.",
                     "What is exercise good for?",
                     "Why is sleep important?"],
                    inputs=msg, outputs=chatbot, fn=lambda _q: [], run_on_click=True,
                    cache_examples=False, label="try a prompt (starts a new chat)")

        def _user(message, history):
            """Append the submitted message as a user turn and clear the textbox.

            Blank submissions are ignored: the history is returned unchanged.
            """
            history = history or []
            if not (message or "").strip():
                return "", history
            return "", history + [{"role": "user", "content": message}]

        def _bot(history, mode, max_new, temperature):
            """Stream the assistant reply for the last user turn in the selected mode.

            Yields the history with a growing assistant message, then once more
            with the reply clipped to its last full sentence. Yields the history
            unchanged when there is no pending user turn to answer.
            """
            history = history or []
            # Nothing to answer: empty chat, or a blank submission that _user ignored.
            if not history or history[-1].get("role") != "user":
                yield history
                return
            # A cleared number box arrives as None; fall back to the UI defaults.
            budget = int(max_new) if max_new is not None else 80
            heat = float(temperature) if temperature is not None else GEN["temperature"]

            ids, am = _enc(history)
            if mode == "diffusion":
                stream = _diffusion_stream(ids, max_new=budget)
            elif mode == "self-spec":
                stream = _selfspec_stream(ids, max_new=budget)
            else:
                stream = _ar_stream(ids, am, budget, dict(GEN, temperature=heat))

            last = ""
            for last in stream:
                yield history + [{"role": "assistant", "content": last}]
            yield history + [{"role": "assistant", "content": _clip(last)}]

        msg.submit(_user, [msg, chatbot], [msg, chatbot]).then(_bot, [chatbot, mode, max_new, temp], chatbot)
        send.click(_user, [msg, chatbot], [msg, chatbot]).then(_bot, [chatbot, mode, max_new, temp], chatbot)
        clear.click(lambda: ([], ""), None, [chatbot, msg])

    with gr.Tab("tri-mode compare"):
        gr.Markdown("Same prompt, three decoding modes — watch **AR** type, **block-diffusion** denoise `▓▒░`, and **self-spec** jump in accepted spans. \n<span style='color:#8f8a78'>config: greedy · repetition-penalty 1.3 · 48 tokens/mode · clipped to last full sentence</span>")
        inp = gr.Textbox(label="prompt", value="Write a short paragraph about the ocean.")
        gr.Examples(["Write a short paragraph about the ocean.",
                     "Explain what a neural network is in simple terms.",
                     "What is exercise good for?"],
                    inputs=inp)
        btn = gr.Button("run all 3 modes", variant="primary")
        o1 = gr.HTML()
        o2 = gr.HTML()
        o3 = gr.HTML()
        btn.click(compare, inp, [o1, o2, o3])
if __name__ == "__main__":
    # Enable the request queue, then serve on all interfaces at port 7860 with
    # the monospace theme and the custom stylesheet.
    app = demo.queue()
    mono_theme = gr.themes.Base(font=gr.themes.GoogleFont("IBM Plex Mono"))
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=mono_theme,
        css=CSS,
    )