# Source: Hugging Face Space file "app.py" uploaded by JoeStrout via
# huggingface_hub (commit b59c36e). The lines above were web-viewer residue,
# not code, and have been converted to this comment.
#!/usr/bin/env python3
"""Gradio Space: MiniScript Code Helper (LoRA + RAG).
Loads the fine-tuned Qwen2.5-Coder-7B-Instruct LoRA adapter and a ChromaDB
vector index built from MiniScript documentation, then serves a chat interface.
"""
import os
import re

# Must run before `transformers` is imported: USE_TF=0 tells transformers to
# use the PyTorch backend and skip importing TensorFlow entirely.
os.environ.setdefault("USE_TF", "0")

import chromadb
import gradio as gr
import torch
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"  # backbone model the LoRA adapter was trained on
ADAPTER_REPO = "JoeStrout/miniscript-code-helper-lora"  # HF repo with LoRA weights + tokenizer
RAG_DIR = "./RAG_sources"  # folder of .md/.txt documentation files to index
DB_DIR = "./chroma_db"  # on-disk location of the persistent ChromaDB store
COLLECTION = "miniscript_docs"  # ChromaDB collection name
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model used for embeddings
TOP_K = 5  # number of chunks retrieved per user query
MAX_NEW_TOKENS = 1024  # cap on generated tokens per reply
MAX_CHUNK_CHARS = 1500  # soft upper bound on indexed chunk size (see split_long_chunk)
BASE_SYSTEM_PROMPT = "You are a helpful assistant specializing in MiniScript programming."
# ---------------------------------------------------------------------------
# RAG index builder (inline so app is self-contained)
# ---------------------------------------------------------------------------
def strip_leanpub(text: str) -> str:
    """Remove Leanpub-specific markup from *text*, keeping plain prose.

    Directive lines such as ``{width: ...}`` or ``{pagebreak}`` are dropped,
    except that a ``{caption: "..."}`` directive is preserved as a bracketed
    label. Stand-alone Markdown image lines are dropped, and Q>/A>/D>/X>
    block prefixes are stripped from the start of a line.
    """
    kept = []
    for raw in text.splitlines():
        # Leanpub directive line: keep only an embedded caption, if any.
        if re.match(r'^\s*\{(chapterHead|width|i:|caption|pagebreak|startingPageNum)', raw):
            cap = re.search(r'\{caption:\s*"([^"]+)"\}', raw)
            if cap:
                kept.append(f"[{cap.group(1)}]")
            continue
        # Stand-alone image embeds carry no searchable text.
        if re.match(r'^\s*!\[.*\]\(.*\)\s*$', raw):
            continue
        kept.append(re.sub(r'^([QADX])>\s?', '', raw))
    return '\n'.join(kept)
def split_long_chunk(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list:
    """Break *text* into pieces of at most roughly *max_chars* characters.

    Splitting happens only at blank-line (paragraph) boundaries, so a single
    paragraph longer than *max_chars* is returned unsplit.
    """
    if len(text) <= max_chars:
        return [text]
    pieces = []
    buffer = ""
    for paragraph in re.split(r'\n\n+', text):
        # +2 accounts for the "\n\n" separator we would insert.
        if buffer and len(buffer) + len(paragraph) + 2 > max_chars:
            pieces.append(buffer.strip())
            buffer = paragraph
        elif buffer:
            buffer = buffer + "\n\n" + paragraph
        else:
            buffer = paragraph
    if buffer.strip():
        pieces.append(buffer.strip())
    return pieces
def chunk_document(text: str, filename: str) -> list:
    """Split one documentation file into section-tagged chunks.

    Markdown-style headings (``#`` through ``####``) start a new section;
    heading lines themselves are recorded only as the section name, not as
    chunk text. Returns a list of dicts with "text", "source", "section".
    """
    if filename.endswith('.txt'):
        # .txt sources are Leanpub manuscripts; clean their markup first.
        text = strip_leanpub(text)

    chunks = []
    section = filename  # default section until the first heading appears
    buffer = []

    def emit():
        # Flush the accumulated lines as one or more chunks.
        body = '\n'.join(buffer).strip()
        if not body:
            return
        for piece in split_long_chunk(body):
            if piece.strip():
                chunks.append({"text": piece, "source": filename, "section": section})

    # Both the .txt and .md heading patterns in the original reduce to this
    # single form: 1-4 hashes, whitespace, then a non-empty title.
    for line in text.splitlines():
        m = re.match(r'^(#{1,4})\s+(.*)', line)
        title = m.group(2).strip() if m else ""
        if title:
            emit()
            section = title
            buffer = []
        else:
            buffer.append(line)
    emit()
    return chunks
def build_rag_index():
    """Create (or reuse) the persistent ChromaDB collection of doc chunks.

    Reads every .md/.txt file in RAG_DIR, chunks it by section, and indexes
    the chunks with sentence-transformers embeddings. If the on-disk
    collection already holds chunks, it is reused without re-reading files.

    Returns:
        The ChromaDB collection object.
    """
    print(f"Building ChromaDB index from {RAG_DIR}/ ...")
    embedding_fn = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=DB_DIR)
    # Use get_or_create_collection instead of probing list_collections():
    # list_collections() returned Collection objects before chromadb 0.6 but
    # returns plain name strings from 0.6 on, so `[c.name for c in ...]`
    # breaks on newer clients. get_or_create works on both.
    col = client.get_or_create_collection(
        name=COLLECTION,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )
    if col.count() > 0:
        print(f" Reusing existing collection ({col.count()} chunks)")
        return col

    source_files = sorted(f for f in os.listdir(RAG_DIR) if f.endswith(('.md', '.txt')))
    all_chunks = []
    for fname in source_files:
        with open(os.path.join(RAG_DIR, fname), encoding='utf-8') as f:
            text = f.read()
        chunks = chunk_document(text, fname)
        print(f" {fname}: {len(chunks)} chunks")
        all_chunks.extend(chunks)

    BATCH = 100  # add() in bounded batches to keep memory/request size sane
    for start in range(0, len(all_chunks), BATCH):
        batch = all_chunks[start:start + BATCH]
        col.add(
            # start + j is the global chunk index, so ids are unique
            # across batches.
            ids=[f"chunk_{start + j}" for j in range(len(batch))],
            documents=[c["text"] for c in batch],
            metadatas=[{"source": c["source"], "section": c["section"]} for c in batch],
        )
    print(f" Indexed {col.count()} chunks total.")
    return col
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model():
    """Load the tokenizer, the 4-bit quantized base model, and the LoRA adapter.

    Returns:
        A ``(tokenizer, model)`` pair, with the model placed via
        ``device_map="auto"`` and switched to eval mode.
    """
    print(f"Loading tokenizer from {ADAPTER_REPO} ...")
    tok = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print(f"Loading base model {BASE_MODEL} in 4-bit ...")
    # NF4 quantization with bf16 compute keeps the 7B model within Space VRAM.
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_cfg,
        device_map="auto",
    )

    print(f"Loading LoRA adapter from {ADAPTER_REPO} ...")
    lora_model = PeftModel.from_pretrained(backbone, ADAPTER_REPO)
    lora_model.eval()
    print("Model ready!")
    return tok, lora_model
# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------
# Build (or reuse) the vector index and load the model once at import time,
# so everything is ready before Gradio starts serving requests.
collection = build_rag_index()
tokenizer, model = load_model()
# ---------------------------------------------------------------------------
# Chat logic
# ---------------------------------------------------------------------------
def build_system_prompt(results: dict) -> str:
    """Return the system prompt, augmented with retrieved documentation.

    *results* is a ChromaDB query result dict. When no documents were
    retrieved (or *results* is falsy), the bare base prompt is returned.
    """
    if not results or not results["documents"] or not results["documents"][0]:
        return BASE_SYSTEM_PROMPT
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    # Label each chunk with its provenance so the model can cite sources.
    context = "\n\n".join(
        f"[Source: {meta['source']}, Section: {meta['section']}]\n{doc}"
        for doc, meta in zip(docs, metas)
    )
    return (
        f"{BASE_SYSTEM_PROMPT}\n\n"
        f"Use the following reference material to help answer the user's question:\n\n"
        f"{context}"
    )
def chat(message: str, history: list) -> str:
    """Answer one chat turn: retrieve context, then generate a reply.

    Queries the RAG collection with the user message, folds the top hits
    into the system prompt, replays the prior conversation, and returns
    only the newly generated text.
    """
    hits = collection.query(query_texts=[message], n_results=TOP_K)
    convo = [{"role": "system", "content": build_system_prompt(hits)}]
    # Gradio 5 messages-format history: list of {"role", "content"} dicts.
    convo.extend({"role": turn["role"], "content": turn["content"]} for turn in history)
    convo.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([prompt], return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the tokens past the prompt length.
    new_tokens = generated[0][len(encoded.input_ids[0]):]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Chat UI. type="messages" is required: chat() (and its history handling)
# expects role/content dicts, but without it Gradio 5's ChatInterface falls
# back to the deprecated tuple-pair history format, which would crash on
# entry["role"] access.
demo = gr.ChatInterface(
    fn=chat,
    type="messages",
    title="MiniScript Code Helper",
    description=(
        "Ask questions about the [MiniScript](https://miniscript.org) programming language. "
        "Powered by a fine-tuned Qwen2.5-Coder-7B-Instruct model with RAG over MiniScript documentation."
    ),
    examples=[
        "How do I define a function in MiniScript?",
        "How do I iterate over a list?",
        "What is the difference between `and` and `&&` in MiniScript?",
        "How do I read a file in MiniScript?",
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()