Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import time | |
| import tarfile | |
| import stat | |
| import threading | |
| import subprocess | |
| from pathlib import Path | |
| from typing import List, Dict, Optional | |
| import requests | |
| import gradio as gr | |
# =============================================================================
# Environment helpers
# =============================================================================
def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int, using *default* when unset."""
    return int(os.environ.get(name, str(default)))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when environment variable *name* (or *default*) is exactly "1"."""
    return os.environ.get(name, default) == "1"


# =============================================================================
# UTF-8 everywhere
# =============================================================================
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
os.environ.setdefault("LANG", "C.UTF-8")
os.environ.setdefault("LC_ALL", "C.UTF-8")

# =============================================================================
# Model on HF (GGUF)
# =============================================================================
HF_REPO = os.environ.get("HF_REPO", "staeiou/bartleby-qwen3-0.6b")
HF_FILE = os.environ.get("HF_FILE", "bartleby-qwen3-0.6b.Q4_K_M.gguf")

# =============================================================================
# llama.cpp pin (do not follow latest)
# =============================================================================
LLAMA_CPP_RELEASE = os.environ.get("LLAMA_CPP_RELEASE", "b8181").strip()
# The default asset name is derived from the pinned release tag.
_DEFAULT_LLAMA_ASSET = f"llama-{LLAMA_CPP_RELEASE}-bin-ubuntu-x64.tar.gz"
LLAMA_CPP_ASSET = os.environ.get("LLAMA_CPP_ASSET", _DEFAULT_LLAMA_ASSET).strip()

# =============================================================================
# llama.cpp server settings
# =============================================================================
HOST = os.environ.get("LLAMA_HOST", "127.0.0.1")
PORT = _env_int("LLAMA_PORT", 8080)
BASE_URL = f"http://{HOST}:{PORT}"
CTX_SIZE = _env_int("LLAMA_CTX", 1024)
N_THREADS = _env_int("LLAMA_THREADS", 2)
# Batch threads follow the generation thread count unless overridden.
N_THREADS_BATCH = _env_int("LLAMA_THREADS_BATCH", N_THREADS)
PARALLEL = _env_int("LLAMA_PARALLEL", 1)
THREADS_HTTP = _env_int("LLAMA_THREADS_HTTP", 1)
BATCH_SIZE = _env_int("LLAMA_BATCH", 128)
UBATCH_SIZE = _env_int("LLAMA_UBATCH", 64)
USE_MLOCK = _env_flag("LLAMA_MLOCK", "1")
USE_CONT_BATCHING = _env_flag("LLAMA_CONT_BATCHING", "0")
SYSTEM_PROMPT_DEFAULT = os.environ.get("SYSTEM_PROMPT", "")

# Prefer /data if present (persistent), else /tmp
_PERSISTENT_DIR = Path("/data")
DATA_DIR = _PERSISTENT_DIR if _PERSISTENT_DIR.exists() else Path("/tmp")
HF_HOME = Path(os.environ.get("HF_HOME", str(DATA_DIR / "hf_home")))
# Export the resolved location so child processes see the same HF_HOME.
os.environ["HF_HOME"] = str(HF_HOME)
LLAMA_DIR = Path(os.environ.get("LLAMA_BIN_DIR", str(DATA_DIR / "llama_cpp_bin")))
LLAMA_DIR.mkdir(parents=True, exist_ok=True)

# =============================================================================
# GoatCounter
# Frontend count.js for browser-side events
# Backend API only for server-side events
# =============================================================================
GOATCOUNTER_CODE = os.environ.get("GOATCOUNTER_CODE", "").strip()
GOATCOUNTER_API_TOKEN = os.environ.get("GOATCOUNTER_API_TOKEN", "").strip()
# Both URLs stay empty (analytics disabled) when no site code is configured.
if GOATCOUNTER_CODE:
    GOATCOUNTER_SITE = f"https://{GOATCOUNTER_CODE}.goatcounter.com"
else:
    GOATCOUNTER_SITE = ""
if GOATCOUNTER_SITE:
    GOATCOUNTER_API_URL = f"{GOATCOUNTER_SITE}/api/v0/count"
else:
    GOATCOUNTER_API_URL = ""

# =============================================================================
# UI strings
# =============================================================================
LOADING_TEXT = "⏳ Loading model, will take about 1 minute…"
ERROR_TEXT = "Sorry — the model failed to respond. Please try again."
# =============================================================================
# GoatCounter frontend + UI guard JS
# - manual app/start
# - manual turn milestones
# - blocks Enter / Send while busy
# - clears the textbox in a separate Gradio step so submit progress does not target it
# =============================================================================
# Settings object merged into window.goatcounter before count.js loads; all
# hits are sent manually by the script below.
_GOAT_SETTINGS_JSON = json.dumps({"no_onload": True, "no_events": True})
# Loader tag for count.js; empty when no GoatCounter site code is configured.
_GOAT_SCRIPT_TAG = (
    f'<script data-goatcounter="{GOATCOUNTER_SITE}/count" async src="https://gc.zgo.at/count.js"></script>'
    if GOATCOUNTER_CODE
    else ""
)
# Injected into <head>. Braces are doubled because this is an f-string; the
# Python-side values are embedded as JSON/JS literals.
HEAD_HTML = f"""
<script>
  window.goatcounter = window.goatcounter || {{}};
  Object.assign(window.goatcounter, {_GOAT_SETTINGS_JSON});
</script>
{_GOAT_SCRIPT_TAG}
<script>
(() => {{
  const GOAT_ENABLED = {str(bool(GOATCOUNTER_CODE)).lower()};
  const LOADING_TEXT = {json.dumps(LOADING_TEXT)};
  const ERROR_TEXT = {json.dumps(ERROR_TEXT)};
  const state = {{
    phase: "idle",
    turnCount: 0,
    sessionStarted: false,
    lastCountedAssistantText: ""
  }};

  function gcReady() {{
    return GOAT_ENABLED && window.goatcounter && typeof window.goatcounter.count === "function";
  }}

  function gcCount(path, title) {{
    if (!gcReady()) return;
    try {{
      window.goatcounter.count({{
        path: path,
        title: title || path,
        event: true
      }});
    }} catch (_e) {{}}
  }}

  function latestAssistantText() {{
    const chatWrap = document.querySelector("#chat_wrap");
    if (!chatWrap) return "";
    const botRows = chatWrap.querySelectorAll(".bot-row");
    if (!botRows.length) return "";
    const row = botRows[botRows.length - 1];
    return (row.innerText || row.textContent || "").trim();
  }}

  function sendButton() {{
    return document.querySelector("#send_btn button");
  }}

  function msgTextarea() {{
    return document.querySelector("#msg_box textarea");
  }}

  function statusInput() {{
    return (
      document.querySelector("#status_box textarea") ||
      document.querySelector("#status_box input")
    );
  }}

  function setPhase(next) {{
    next = (next || "idle").trim().toLowerCase();
    const prev = state.phase;
    state.phase = next;
    const btn = sendButton();
    if (btn) btn.disabled = (next !== "idle");
    if ((prev === "busy" || prev === "cold") && next === "idle") {{
      const txt = latestAssistantText();
      if (
        txt &&
        txt !== LOADING_TEXT &&
        txt !== ERROR_TEXT &&
        txt !== state.lastCountedAssistantText
      ) {{
        state.turnCount += 1;
        state.lastCountedAssistantText = txt;
        gcCount("session/turn/completed", "Completed turn");
        gcCount(`session/turn/${{state.turnCount}}`, `Turn ${{state.turnCount}}`);
      }}
    }}
  }}

  function syncPhaseFromStatus() {{
    const el = statusInput();
    if (!el) return;
    setPhase((el.value || "idle"));
  }}

  function bindGuards() {{
    const btn = sendButton();
    const box = msgTextarea();
    if (btn && !btn.dataset.bartlebyBound) {{
      btn.dataset.bartlebyBound = "1";
      btn.addEventListener("click", (e) => {{
        if (state.phase !== "idle") {{
          e.preventDefault();
          e.stopPropagation();
          if (e.stopImmediatePropagation) e.stopImmediatePropagation();
          return false;
        }}
      }}, true);
    }}
    if (box && !box.dataset.bartlebyBound) {{
      box.dataset.bartlebyBound = "1";
      box.addEventListener("keydown", (e) => {{
        if (e.key !== "Enter" || e.shiftKey) return;
        if (state.phase !== "idle") {{
          e.preventDefault();
          e.stopPropagation();
          if (e.stopImmediatePropagation) e.stopImmediatePropagation();
          return false;
        }}
      }}, true);
    }}
  }}

  function start() {{
    // Only poll for count.js when analytics is configured; otherwise the
    // timer would spin for 15 seconds without ever being able to fire.
    if (GOAT_ENABLED) {{
      const waitForGC = setInterval(() => {{
        if (!gcReady()) return;
        clearInterval(waitForGC);
        if (!state.sessionStarted) {{
          state.sessionStarted = true;
          gcCount("app/start", "App start");
        }}
      }}, 150);
      setTimeout(() => clearInterval(waitForGC), 15000);
    }}
    bindGuards();
    syncPhaseFromStatus();
    setInterval(bindGuards, 1000);
    setInterval(syncPhaseFromStatus, 120);
  }}

  if (document.readyState === "loading") {{
    document.addEventListener("DOMContentLoaded", start, {{ once: true }});
  }} else {{
    start();
  }}
}})();
</script>
"""
| # ============================================================================= | |
| # CSS | |
| # ============================================================================= | |
| CUSTOM_CSS = r""" | |
| footer { visibility: hidden; } | |
| html, body { | |
| height: 100%; | |
| margin: 0; | |
| overflow: hidden !important; | |
| } | |
| .gradio-container { | |
| height: 100dvh !important; | |
| max-height: 100dvh !important; | |
| overflow: hidden !important; | |
| } | |
| #app_root { | |
| position: fixed; | |
| inset: 0; | |
| display: flex; | |
| flex-direction: column; | |
| overflow: hidden !important; | |
| } | |
| #chat_wrap { | |
| flex: 1 1 auto; | |
| min-height: 0; | |
| overflow: hidden !important; | |
| } | |
| #chat_wrap .gradio-chatbot, | |
| #chat_wrap .gr-chatbot, | |
| #chat_wrap [data-testid="chatbot"], | |
| #chat_wrap #chatbot { | |
| height: 100% !important; | |
| max-height: none !important; | |
| padding-top: 50px !important; | |
| box-sizing: border-box !important; | |
| } | |
| #input_row { | |
| flex: 0 0 auto; | |
| padding: 6px 0 6px 0; | |
| } | |
| #msg_box { | |
| border: none !important; | |
| box-shadow: none !important; | |
| } | |
| #msg_box > div { | |
| border-radius: 8px !important; | |
| box-shadow: none !important; | |
| overflow: hidden !important; | |
| } | |
| #msg_box > div > div { | |
| border: none !important; | |
| box-shadow: none !important; | |
| background: transparent !important; | |
| } | |
| #msg_box textarea { | |
| min-height: 2.6em !important; | |
| max-height: 2.6em !important; | |
| height: 2.6em !important; | |
| line-height: 1.25 !important; | |
| overflow: hidden !important; | |
| resize: none !important; | |
| border: none !important; | |
| outline: none !important; | |
| box-shadow: none !important; | |
| background: transparent !important; | |
| } | |
| #send_btn button { | |
| min-height: 2.6em !important; | |
| height: 2.6em !important; | |
| padding-top: 0em !important; | |
| padding-bottom: 0em !important; | |
| } | |
| #params_bar { | |
| flex: 0 0 auto; | |
| } | |
| #params_bar .gr-accordion-content, | |
| #params_bar .accordion-content { | |
| max-height: 45dvh; | |
| overflow: auto; | |
| } | |
| #status_box { | |
| display: none !important; | |
| } | |
| @media (max-width: 768px) { | |
| .gradio-container { padding: 8px !important; } | |
| #send_btn { | |
| flex: 0 0 5.5rem !important; | |
| min-width: 5.5rem !important; | |
| max-width: 5.5rem !important; | |
| } | |
| #send_btn button { | |
| width: 100% !important; | |
| } | |
| } | |
| @media (min-width: 769px) { | |
| .gradio-container { padding: 12px !important; } | |
| } | |
| #chat_wrap .message-row.user-row { justify-content: flex-start !important; } | |
| #chat_wrap .message-row.bot-row { justify-content: flex-end !important; } | |
| @media (prefers-color-scheme: light) { | |
| #msg_box > div { | |
| border: 0.5px solid #FFFFFF !important; | |
| box-shadow: none !important; | |
| } | |
| #msg_box:hover > div { | |
| border-color: #5d5d5d !important; | |
| } | |
| #msg_box:focus-within > div { | |
| border-color: #4d4d4d !important; | |
| box-shadow: none !important; | |
| } | |
| #msg_box:focus-within { | |
| outline: 2px solid rgba(77,77,77,0.18) !important; | |
| outline-offset: 2px !important; | |
| } | |
| #msg_box textarea:focus { | |
| border: none !important; | |
| outline: none !important; | |
| box-shadow: none !important; | |
| } | |
| } | |
| @media (prefers-color-scheme: dark) { | |
| #msg_box > div { | |
| border: 1.5px solid rgba(255,255,255,0.22) !important; | |
| box-shadow: none !important; | |
| } | |
| #msg_box:hover > div { | |
| border-color: rgba(255,255,255,0.32) !important; | |
| } | |
| #msg_box:focus-within > div { | |
| border-color: rgba(255,255,255,0.45) !important; | |
| } | |
| #msg_box:focus-within { | |
| outline: 2px solid rgba(255,255,255,0.14) !important; | |
| outline-offset: 2px !important; | |
| } | |
| #msg_box textarea:focus { | |
| border: none !important; | |
| outline: none !important; | |
| box-shadow: none !important; | |
| } | |
| } | |
| """ | |
# =============================================================================
# Mobile focus guard JS
# =============================================================================
# On mobile user agents, blurs the message textarea whenever it holds focus
# without a touch on it in the last 600 ms (checked on chat DOM changes and on
# focusin). arm() is retried on timers because the elements may not exist yet,
# so every binding is guarded to be installed only once.
FOCUS_GUARD_JS = r"""
() => {
  const isMobile = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
  if (!isMobile) return;

  const inputSel = "#msg_box textarea";
  const chatSel = "#chat_wrap";
  let lastTouch = 0;
  let observer = null;
  let observedChat = null;
  let focusinBound = false;

  const currentInput = () => document.querySelector(inputSel);

  const blurIfUnintended = () => {
    const input = currentInput();
    if (!input) return;
    const recent = (Date.now() - lastTouch) < 600;
    if (!recent && document.activeElement === input) input.blur();
  };

  const arm = () => {
    const input = currentInput();
    const chat = document.querySelector(chatSel);
    if (!input || !chat) return;

    // One touchstart listener per textarea element.
    if (!input.dataset.focusGuardBound) {
      input.dataset.focusGuardBound = "1";
      input.addEventListener("touchstart", () => { lastTouch = Date.now(); }, { passive: true });
    }

    // One observer in total; re-attach only if the chat container was replaced.
    if (observedChat !== chat) {
      if (observer) observer.disconnect();
      observer = new MutationObserver(() => blurIfUnintended());
      observer.observe(chat, { childList: true, subtree: true, characterData: true });
      observedChat = chat;
    }

    // One document-level focusin listener; it looks the textarea up lazily.
    if (!focusinBound) {
      focusinBound = true;
      document.addEventListener("focusin", (e) => {
        if (e.target === currentInput()) blurIfUnintended();
      }, true);
    }
  };

  arm();
  setTimeout(arm, 500);
  setTimeout(arm, 1500);
}
"""
# =============================================================================
# Server lifecycle globals
# =============================================================================
# Held by ensure_server_started() for the whole download/spawn/health sequence.
_server_lock = threading.Lock()
# Handle of the spawned llama-server child; None until one has been started.
_server_proc: Optional[subprocess.Popen] = None
# Path to the llama-server binary, set by ensure_server_started().
LLAMA_SERVER: Optional[Path] = None
# Model id taken from /v1/models (or HF_REPO as fallback); None until the server is healthy.
SERVER_MODEL_ID: Optional[str] = None
# Shared HTTP session used for llama-server, release download and GoatCounter calls.
SESSION = requests.Session()
# =============================================================================
# Types
# =============================================================================
# Chat history as a list of {"role": ..., "content": ...} message dicts.
ChatHistory = List[Dict[str, str]]
# =============================================================================
# Helpers
# =============================================================================
def goat_backend_event(path: str, title: str, request: Optional[gr.Request] = None) -> None:
    """
    Send one server-side event hit to the GoatCounter API (best effort).

    Does nothing unless both the API URL and the API token are configured.
    When a request is given, its User-Agent header and session hash are
    attached to the hit if available. Failures are printed, never raised.
    """
    if not (GOATCOUNTER_API_URL and GOATCOUNTER_API_TOKEN):
        return

    hit = {"path": path, "title": title or path, "event": True}

    if request is not None:
        # Header access is wrapped defensively: a failure falls back to no headers.
        try:
            lowered = {str(k).lower(): str(v) for k, v in dict(request.headers).items()}
        except Exception:
            lowered = {}
        user_agent = (lowered.get("user-agent") or "").strip()
        if user_agent:
            hit["user_agent"] = user_agent
        try:
            session_hash = request.session_hash
            if session_hash:
                hit["session"] = str(session_hash)
        except Exception:
            pass

    payload = {"hits": [hit]}
    # Without a session identifier the hit is sent with session tracking off.
    if "session" not in hit:
        payload["no_sessions"] = True

    auth_headers = {
        "Authorization": f"Bearer {GOATCOUNTER_API_TOKEN}",
        "Content-Type": "application/json; charset=utf-8",
    }
    try:
        r = SESSION.post(GOATCOUNTER_API_URL, headers=auth_headers, json=payload, timeout=5)
        accepted = r.status_code in (200, 202)
        if not accepted:
            print(f"[goatcounter] {r.status_code}: {r.text[:1000]}", flush=True)
    except Exception as e:
        print(f"[goatcounter] failed to send backend event {path!r}: {e}", flush=True)
| def _make_executable(path: Path) -> None: | |
| st = os.stat(path) | |
| os.chmod(path, st.st_mode | stat.S_IEXEC) | |
| def _safe_extract_tar(tf: tarfile.TarFile, out_dir: Path) -> None: | |
| try: | |
| tf.extractall(path=out_dir, filter="data") # py3.12+ | |
| except TypeError: | |
| tf.extractall(path=out_dir) | |
| def _truncate(s: str, n: int) -> str: | |
| s = s if isinstance(s, str) else str(s) | |
| return s if len(s) <= n else s[:n] | |
def _server_process_alive() -> bool:
    """Report whether a llama-server child has been spawned and has not exited."""
    proc = _server_proc
    if proc is None:
        return False
    return proc.poll() is None
# =============================================================================
# llama.cpp bootstrap
# =============================================================================
def _find_llama_server(root: Path) -> Optional[Path]:
    """Return the first regular file named ``llama-server`` anywhere under *root*, or None."""
    matches = sorted(p for p in root.rglob("llama-server") if p.is_file())
    return matches[0] if matches else None


def _download_llama_cpp_release() -> Path:
    """Return the path to the pinned ``llama-server`` binary, downloading it if needed.

    The release directory is searched recursively first, because the binary is
    located with the same recursive search after extraction and may live in a
    subdirectory; a previous extraction is therefore reused instead of being
    downloaded again. Otherwise the pinned release asset is downloaded from
    GitHub, extracted into the release directory and the binary is made
    executable.

    Raises:
        requests.HTTPError: if the asset download returns an error status.
        RuntimeError: if the extracted archive contains no ``llama-server`` file.
    """
    release_dir = LLAMA_DIR / LLAMA_CPP_RELEASE
    release_dir.mkdir(parents=True, exist_ok=True)

    cached = _find_llama_server(release_dir)
    if cached is not None:
        _make_executable(cached)
        return cached

    asset_url = (
        f"https://github.com/ggml-org/llama.cpp/releases/download/"
        f"{LLAMA_CPP_RELEASE}/{LLAMA_CPP_ASSET}"
    )
    tar_path = release_dir / LLAMA_CPP_ASSET
    print(f"[app] Downloading pinned llama.cpp release: {asset_url}", flush=True)
    with SESSION.get(asset_url, stream=True, timeout=180) as r:
        r.raise_for_status()
        with open(tar_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)

    print("[app] Extracting llama.cpp tarball...", flush=True)
    with tarfile.open(tar_path, "r:gz") as tf:
        _safe_extract_tar(tf, release_dir)

    server_bin = _find_llama_server(release_dir)
    if server_bin is None:
        raise RuntimeError(
            f"Downloaded {LLAMA_CPP_ASSET} from {LLAMA_CPP_RELEASE} but could not find llama-server."
        )

    # The archive is no longer needed once extracted; removal is best effort.
    try:
        tar_path.unlink()
    except OSError:
        pass

    _make_executable(server_bin)
    print(f"[app] llama-server path: {server_bin}", flush=True)
    return server_bin
def _wait_for_health(timeout_s: int = 180) -> None:
    """Block until ``GET {BASE_URL}/health`` answers 200.

    Polls roughly every 0.35 s. If the tracked llama-server child has already
    exited, fails immediately instead of waiting out the whole timeout. The
    deadline uses a monotonic clock so wall-clock adjustments cannot stretch
    or shorten it.

    Raises:
        RuntimeError: if the child exits first, or the timeout elapses without
            a healthy response (the last observed error is included).
    """
    deadline = time.monotonic() + timeout_s
    last_err = None
    while time.monotonic() < deadline:
        proc = _server_proc
        if proc is not None:
            exit_code = proc.poll()
            if exit_code is not None:
                raise RuntimeError(
                    f"llama-server exited with code {exit_code} before becoming healthy. "
                    f"Last error: {last_err}"
                )
        try:
            r = SESSION.get(f"{BASE_URL}/health", timeout=2)
            if r.status_code == 200:
                return
            last_err = f"health status {r.status_code}"
        except Exception as e:
            last_err = str(e)
        time.sleep(0.35)
    raise RuntimeError(f"llama-server not healthy in time. Last error: {last_err}")
def _warmup() -> None:
    """Send one tiny non-streaming chat completion to the local server.

    Best effort: any failure or non-200 answer is logged and otherwise
    ignored, so warmup can never prevent the app from serving requests.
    """
    payload = {
        "model": SERVER_MODEL_ID or HF_REPO,
        "messages": [{"role": "user", "content": "hi"}],
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": 4,
        "stream": False,
    }
    try:
        resp = SESSION.post(f"{BASE_URL}/v1/chat/completions", json=payload, timeout=60)
        if resp.status_code != 200:
            print(f"[app] Warmup returned status {resp.status_code} (ignored)", flush=True)
    except Exception as e:
        print(f"[app] Warmup request failed (ignored): {e}", flush=True)
def _stop_server_process() -> None:
    """Terminate the tracked llama-server child (if still running) and forget it.

    Must be called with ``_server_lock`` held. Also clears ``SERVER_MODEL_ID``
    so the server is reported as not ready.
    """
    global _server_proc, SERVER_MODEL_ID
    proc = _server_proc
    _server_proc = None
    SERVER_MODEL_ID = None
    if proc is None or proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def ensure_server_started(request: Optional[gr.Request] = None) -> None:
    """Make sure a healthy llama-server child is running, starting one if needed.

    Returns immediately when the tracked process is alive and its model id is
    known. Otherwise, under ``_server_lock``: stops any leftover child, fetches
    the pinned binary, spawns the server, waits for ``/health``, reads the
    model id from ``/v1/models`` (falling back to ``HF_REPO``) and sends a
    warmup request. If the health wait fails, the spawned child is terminated
    before the error propagates, so a later retry does not start a second
    process on the same port.

    Args:
        request: optional Gradio request, forwarded to the cold-start
            analytics event.

    Raises:
        RuntimeError: if the server does not become healthy.
        Exception: whatever the binary download or process spawn raises.
    """
    global _server_proc, LLAMA_SERVER, SERVER_MODEL_ID
    cold_start = False
    with _server_lock:
        if _server_process_alive() and SERVER_MODEL_ID is not None:
            return
        cold_start = not _server_process_alive()
        # A live child without a model id is a leftover from an incomplete
        # start; stop it so the new process can bind the port.
        _stop_server_process()

        LLAMA_SERVER = _download_llama_cpp_release()
        HF_HOME.mkdir(parents=True, exist_ok=True)
        cmd = [
            str(LLAMA_SERVER),
            "--host", HOST,
            "--port", str(PORT),
            "--no-webui",
            "--jinja",
            "--ctx-size", str(CTX_SIZE),
            "--threads", str(N_THREADS),
            "--threads-batch", str(N_THREADS_BATCH),
            "--threads-http", str(THREADS_HTTP),
            "--parallel", str(PARALLEL),
            "--batch-size", str(BATCH_SIZE),
            "--ubatch-size", str(UBATCH_SIZE),
            "-hf", HF_REPO,
            "--hf-file", HF_FILE,
        ]
        if USE_MLOCK:
            cmd.append("--mlock")
        if USE_CONT_BATCHING:
            cmd.append("--cont-batching")
        print("[app] Starting llama-server with:", flush=True)
        print(" " + " ".join(cmd), flush=True)

        env = os.environ.copy()
        env["PYTHONIOENCODING"] = "utf-8"
        env.setdefault("LANG", "C.UTF-8")
        env.setdefault("LC_ALL", "C.UTF-8")

        _server_proc = subprocess.Popen(cmd, stdout=None, stderr=None, env=env)
        try:
            _wait_for_health(timeout_s=180)
        except Exception:
            # Do not leave a half-started child behind.
            _stop_server_process()
            raise

        try:
            j = SESSION.get(f"{BASE_URL}/v1/models", timeout=5).json()
            SERVER_MODEL_ID = j["data"][0]["id"]
        except Exception:
            SERVER_MODEL_ID = HF_REPO
        print(f"[app] llama-server healthy. model_id={SERVER_MODEL_ID}", flush=True)
        _warmup()

    # Sent after releasing the lock so the analytics call does not hold up other callers.
    if cold_start:
        goat_backend_event("server/cold-start", "llama-server started", request=request)
def stream_chat(
    messages,
    temperature: float,
    top_p: float,
    max_tokens: int,
    request: Optional[gr.Request] = None,
):
    """Stream a chat completion from the local llama-server, yielding text pieces.

    Posts to ``/v1/chat/completions`` with ``stream=True`` and parses the
    server-sent ``data:`` lines. Connection errors and timeouts are retried up
    to 10 times (re-running ``ensure_server_started`` between attempts), but
    only while nothing has been yielded yet: once part of the answer has been
    delivered, a retry would replay the reply from the start and duplicate
    text on the caller's side, so the error is re-raised instead.

    Args:
        messages: list of chat message dicts sent as the request's ``messages``.
        temperature, top_p, max_tokens: sampling settings, coerced to float/int.
        request: optional Gradio request forwarded to ``ensure_server_started``.

    Yields:
        Non-empty ``delta.content`` strings in arrival order.

    Raises:
        requests.exceptions.HTTPError: on a non-200 response.
        requests.exceptions.ConnectionError / Timeout: when all attempts fail,
            or when the stream breaks after output was already produced.
    """
    payload = {
        "model": SERVER_MODEL_ID or HF_REPO,
        "messages": messages,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_tokens": int(max_tokens),
        "stream": True,
    }
    headers = {
        "Accept": "text/event-stream",
        "Content-Type": "application/json; charset=utf-8",
        "Connection": "keep-alive",
    }
    last_err: Optional[Exception] = None
    emitted = False
    for _attempt in range(10):
        try:
            with SESSION.post(
                f"{BASE_URL}/v1/chat/completions",
                json=payload,
                stream=True,
                timeout=600,
                headers=headers,
            ) as r:
                if r.status_code != 200:
                    body = r.text[:2000]
                    raise requests.exceptions.HTTPError(
                        f"{r.status_code} from llama-server: {body}",
                        response=r,
                    )
                for raw in r.iter_lines(decode_unicode=False, chunk_size=8192):
                    if not raw:
                        continue
                    line = raw.decode("utf-8", errors="replace")
                    # Accept both "data: {...}" and "data:{...}".
                    if not line.startswith("data:"):
                        continue
                    data = line[5:].strip()
                    if data == "[DONE]":
                        return
                    try:
                        obj = json.loads(data)
                    except ValueError:
                        continue
                    # Skip chunks that carry no choices instead of raising on them.
                    choices = obj.get("choices") if isinstance(obj, dict) else None
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    tok = delta.get("content")
                    if tok:
                        emitted = True
                        yield tok
                return
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            if emitted:
                # Part of the answer is already out; retrying would replay it.
                raise
            last_err = e
            time.sleep(0.35)
            try:
                ensure_server_started(request=request)
            except Exception as restart_err:
                print(f"[app] Server restart during retry failed: {restart_err}", flush=True)
    if last_err:
        raise last_err
# =============================================================================
# Gradio handlers
# - never target the textbox from the main submit step
# - hidden status_box drives frontend busy guard + Send button disabled state
# =============================================================================
def on_user_submit(user_text: str, history: ChatHistory, busy: bool):
    """Append the user's turn plus an assistant placeholder to the chat history.

    Returns ``(history, status, busy)``:
    - already busy: history unchanged, ``"busy"``, ``True``
    - blank input: history unchanged, ``"idle"``, ``False``
    - otherwise: extended history, ``"cold"`` (with the loading text as the
      placeholder) when the server is not ready or ``"busy"`` (empty
      placeholder) when it is, and ``True``

    The user text is stripped and cut to at most 2000 characters.
    """
    history = history or []
    if busy:
        return history, "busy", True

    cleaned = (user_text or "").strip()
    if not cleaned:
        return history, "idle", False
    cleaned = _truncate(cleaned, 2000)

    server_ready = _server_process_alive() and SERVER_MODEL_ID is not None
    if server_ready:
        status, placeholder = "busy", ""
    else:
        status, placeholder = "cold", LOADING_TEXT

    user_turn = {"role": "user", "content": cleaned}
    assistant_turn = {"role": "assistant", "content": placeholder}
    return history + [user_turn, assistant_turn], status, True
def clear_message_box():
    """Return the empty string that is written into the message textbox to clear it."""
    cleared = ""
    return cleared
def on_bot_respond(
    history: ChatHistory,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    busy: bool,
    request: gr.Request,
):
    """Generate the assistant reply for the pending turn, streaming it into the chat.

    Generator yielding ``(history, status, busy)`` tuples. If there is nothing
    to answer (not busy, fewer than two messages, or the last message is not
    an assistant entry) it yields the history once as idle. Otherwise it makes
    sure the server is running, sends the optional system message plus the
    latest user message only, and rewrites the last history entry with the
    accumulated text after every token. An empty stream or any exception
    replaces the reply with the error text; exceptions also send a backend
    analytics event.
    """
    history = history or []
    if not (busy and len(history) >= 2 and history[-1].get("role") == "assistant"):
        yield history, "idle", False
        return

    prompt = history[-2].get("content", "")
    try:
        ensure_server_started(request=request)

        system_text = (system_message or "").strip()
        msgs = [{"role": "system", "content": system_text}] if system_text else []
        msgs.append({"role": "user", "content": prompt})

        reply = ""
        streamed_any = False
        token_stream = stream_chat(
            msgs,
            temperature=float(temperature),
            top_p=float(top_p),
            max_tokens=int(max_tokens),
            request=request,
        )
        for tok in token_stream:
            streamed_any = True
            reply += tok
            history[-1]["content"] = reply
            yield history, "busy", True

        # A stream that produced nothing is shown to the user as a failure.
        if not streamed_any:
            history[-1]["content"] = ERROR_TEXT
        yield history, "idle", False
    except Exception as e:
        print(f"[app] Generation failed: {e}", flush=True)
        history[-1]["content"] = ERROR_TEXT
        goat_backend_event("chat/response-error", "Generation error", request=request)
        yield history, "idle", False
# =============================================================================
# UI
# =============================================================================
# NOTE(review): the nesting of the layout below was reconstructed from the
# element ids and the CSS flex rules — confirm against the deployed layout.
with gr.Blocks(title="BartlebyGPT", fill_height=True, head=HEAD_HTML) as demo:
    busy_state = gr.State(False)

    with gr.Column(elem_id="app_root"):
        with gr.Column(elem_id="chat_wrap"):
            chatbot = gr.Chatbot(
                value=[],
                show_label=False,
                autoscroll=True,
                height="100%",
                elem_id="chatbot",
            )
        with gr.Row(elem_id="input_row"):
            msg = gr.Textbox(
                placeholder="What do you want?",
                show_label=False,
                lines=1,
                max_lines=1,
                autofocus=False,
                elem_id="msg_box",
                scale=10,
            )
            send = gr.Button("Send", variant="primary", elem_id="send_btn", scale=1)
        with gr.Accordion("Params", open=False, elem_id="params_bar"):
            system_box = gr.Textbox(value=SYSTEM_PROMPT_DEFAULT, label="System message", lines=2)
            with gr.Row():
                max_tokens = gr.Slider(1, 512, value=256, step=1, label="Max new tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
        # Hidden via CSS; the frontend script polls its value to track the phase.
        status_box = gr.Textbox(value="idle", show_label=False, elem_id="status_box")

    # Pressing Enter and clicking Send run the same three-step chain:
    # record the turn, clear the textbox, then stream the reply.
    submit_inputs = [msg, chatbot, busy_state]
    chat_outputs = [chatbot, status_box, busy_state]
    respond_inputs = [chatbot, system_box, max_tokens, temperature, top_p, busy_state]
    for trigger in (msg.submit, send.click):
        trigger(
            on_user_submit,
            submit_inputs,
            chat_outputs,
            queue=False,
        ).then(
            clear_message_box,
            None,
            [msg],
            queue=False,
        ).then(
            on_bot_respond,
            respond_inputs,
            chat_outputs,
        )

demo.queue(default_concurrency_limit=1, max_size=128)

if __name__ == "__main__":
    listen_port = int(os.environ.get("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=listen_port,
        css=CUSTOM_CSS,
        js=FOCUS_GUARD_JS,
    )