# NOTE(review): removed stray "Spaces: / Runtime error / Runtime error" lines —
# they were Hugging Face Spaces page text captured into the file, not Python code.
| #!/usr/bin/env python3 | |
| """Gradio Space: MiniScript Code Helper (LoRA + RAG). | |
| Loads the fine-tuned Qwen2.5-Coder-7B-Instruct LoRA adapter and a ChromaDB | |
| vector index built from MiniScript documentation, then serves a chat interface. | |
| """ | |
| import os | |
| import re | |
| os.environ.setdefault("USE_TF", "0") | |
| import chromadb | |
| import gradio as gr | |
| import torch | |
| from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"           # HF repo of the base LLM
ADAPTER_REPO = "JoeStrout/miniscript-code-helper-lora"  # fine-tuned LoRA adapter (also holds the tokenizer)
RAG_DIR = "./RAG_sources"             # directory of .md/.txt docs to index
DB_DIR = "./chroma_db"                # persistent ChromaDB storage path
COLLECTION = "miniscript_docs"        # ChromaDB collection name
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers embedding model
TOP_K = 5                             # number of chunks retrieved per query
MAX_NEW_TOKENS = 1024                 # generation length cap
MAX_CHUNK_CHARS = 1500                # soft upper bound on chunk size (chars)
BASE_SYSTEM_PROMPT = "You are a helpful assistant specializing in MiniScript programming."
| # --------------------------------------------------------------------------- | |
| # RAG index builder (inline so app is self-contained) | |
| # --------------------------------------------------------------------------- | |
def strip_leanpub(text: str) -> str:
    """Remove Leanpub markup from *text*, keeping captions as bracketed labels.

    Directive lines ({chapterHead...}, {width...}, etc.) and image lines are
    dropped; a {caption: "..."} directive survives as "[caption text]"; and
    Leanpub block-quote prefixes (Q>, A>, D>, X>) are stripped.
    """
    kept = []
    for raw in text.splitlines():
        # Leanpub directive line: drop it, but preserve any caption it carries.
        if re.match(r'^\s*\{(chapterHead|width|i:|caption|pagebreak|startingPageNum)', raw):
            cap = re.search(r'\{caption:\s*"([^"]+)"\}', raw)
            if cap:
                kept.append('[' + cap.group(1) + ']')
            continue
        # Pure image line: ![alt](url) — no text content worth indexing.
        if re.match(r'^\s*!\[.*\]\(.*\)\s*$', raw):
            continue
        # Strip a leading Q>/A>/D>/X> block-quote marker, keep the rest.
        kept.append(re.sub(r'^([QADX])>\s?', '', raw))
    return '\n'.join(kept)
def split_long_chunk(text, max_chars=None):
    """Split *text* into chunks of at most *max_chars* characters.

    Splits on blank-line paragraph boundaries and greedily packs paragraphs
    back together (joined with "\\n\\n") up to the limit.

    Fixes over the previous version:
    - A single paragraph longer than *max_chars* is now hard-split into
      max_chars-sized slices instead of being returned as one oversized chunk,
      so every returned chunk respects the limit.
    - The default for *max_chars* is resolved at call time (MAX_CHUNK_CHARS)
      rather than frozen at definition time.

    Returns a list of non-empty, stripped chunk strings.
    """
    if max_chars is None:
        max_chars = MAX_CHUNK_CHARS
    if len(text) <= max_chars:
        return [text]
    paragraphs = []
    for para in re.split(r'\n\n+', text):
        if len(para) <= max_chars:
            paragraphs.append(para)
        else:
            # Bug fix: previously an oversized paragraph was kept whole.
            paragraphs.extend(para[i:i + max_chars] for i in range(0, len(para), max_chars))
    chunks, current = [], ""
    for para in paragraphs:
        # +2 accounts for the "\n\n" separator that would join the pieces.
        if current and len(current) + len(para) + 2 > max_chars:
            chunks.append(current.strip())
            current = para
        else:
            current = current + "\n\n" + para if current else para
    if current.strip():
        chunks.append(current.strip())
    return chunks
def chunk_document(text: str, filename: str) -> list:
    """Split one document into section-tagged chunk dicts.

    Sections are delimited by markdown headings (#..####); text before the
    first heading is labeled with *filename*.  ``.txt`` files are treated as
    Leanpub sources (markup stripped, headings require trailing text).
    Each chunk is {"text", "source", "section"}.
    """
    from_txt = filename.endswith('.txt')
    if from_txt:
        text = strip_leanpub(text)

    chunks = []
    section = filename   # label for content seen before any heading
    buffer = []          # lines accumulated for the current section

    def emit():
        # Flush buffered lines as one or more size-limited chunks.
        body = '\n'.join(buffer).strip()
        if not body:
            return
        for piece in split_long_chunk(body):
            if piece.strip():
                chunks.append({"text": piece, "source": filename, "section": section})

    for line in text.splitlines():
        heading = None
        if from_txt:
            m = re.match(r'^(#{1,4})\s+(.*)', line)
            if m:
                heading = m.group(2).strip()
        elif re.match(r'^#{1,4}\s', line):
            heading = re.sub(r'^#+\s*', '', line).strip()
        # An empty heading (e.g. "# " alone) is falsy and treated as content,
        # matching the original behavior.
        if heading:
            emit()
            section = heading
            buffer = []
        else:
            buffer.append(line)
    emit()
    return chunks
def build_rag_index():
    """Build (or reuse) the persistent ChromaDB index over RAG_DIR documents.

    Reuses an existing collection when present; otherwise reads every .md/.txt
    file in RAG_DIR, chunks it with chunk_document(), and adds the chunks in
    batches.  Returns the ChromaDB collection used for retrieval.
    """
    print(f"Building ChromaDB index from {RAG_DIR}/ ...")
    embedding_fn = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=DB_DIR)
    # Bug fix: chromadb < 0.6 returns Collection objects from
    # list_collections(); >= 0.6 returns plain name strings.  `c.name` on a
    # str raises AttributeError at startup — handle both shapes.
    existing = [getattr(c, "name", c) for c in client.list_collections()]
    if COLLECTION in existing:
        col = client.get_collection(name=COLLECTION, embedding_function=embedding_fn)
        print(f" Reusing existing collection ({col.count()} chunks)")
        return col
    col = client.create_collection(
        name=COLLECTION,
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},  # cosine distance suits sentence embeddings
    )
    source_files = sorted(f for f in os.listdir(RAG_DIR) if f.endswith(('.md', '.txt')))
    all_chunks = []
    for fname in source_files:
        with open(os.path.join(RAG_DIR, fname), encoding='utf-8') as f:
            text = f.read()
        chunks = chunk_document(text, fname)
        print(f" {fname}: {len(chunks)} chunks")
        all_chunks.extend(chunks)
    # Add in batches; ids are stable positional identifiers across batches.
    BATCH = 100
    for i in range(0, len(all_chunks), BATCH):
        batch = all_chunks[i:i + BATCH]
        col.add(
            ids=[f"chunk_{i + j}" for j in range(len(batch))],
            documents=[c["text"] for c in batch],
            metadatas=[{"source": c["source"], "section": c["section"]} for c in batch],
        )
    print(f" Indexed {col.count()} chunks total.")
    return col
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
def load_model():
    """Load the 4-bit quantized base model plus the LoRA adapter.

    The tokenizer comes from the adapter repo so any added tokens match the
    fine-tune.  Returns a (tokenizer, model) pair with the model in eval mode.
    """
    print(f"Loading tokenizer from {ADAPTER_REPO} ...")
    tok = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print(f"Loading base model {BASE_MODEL} in 4-bit ...")
    # NF4 quantization with bfloat16 compute keeps the 7B model within Space VRAM.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_config,
        device_map="auto",
    )

    print(f"Loading LoRA adapter from {ADAPTER_REPO} ...")
    lora_model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
    lora_model.eval()
    print("Model ready!")
    return tok, lora_model
# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------
# Both steps run once at import time so the Space is ready before Gradio serves.
collection = build_rag_index()   # ChromaDB collection queried by chat()
tokenizer, model = load_model()  # LoRA-adapted model and its tokenizer
| # --------------------------------------------------------------------------- | |
| # Chat logic | |
| # --------------------------------------------------------------------------- | |
def build_system_prompt(results: dict) -> str:
    """Compose the system prompt, embedding retrieved documentation chunks.

    *results* is a ChromaDB query result (parallel "documents"/"metadatas"
    lists, one inner list per query).  Falls back to the bare base prompt
    when nothing was retrieved.
    """
    if not results or not results["documents"] or not results["documents"][0]:
        return BASE_SYSTEM_PROMPT
    docs = results["documents"][0]
    metas = results["metadatas"][0]
    context = "\n\n".join(
        f"[Source: {meta['source']}, Section: {meta['section']}]\n{doc}"
        for doc, meta in zip(docs, metas)
    )
    return (
        f"{BASE_SYSTEM_PROMPT}\n\n"
        f"Use the following reference material to help answer the user's question:\n\n"
        f"{context}"
    )
def chat(message: str, history: list) -> str:
    """Answer *message* using RAG context plus the running chat *history*."""
    # Retrieve the TOP_K most relevant documentation chunks for this question.
    hits = collection.query(query_texts=[message], n_results=TOP_K)
    sys_prompt = build_system_prompt(hits)

    # Gradio 5 history entries are already {"role": ..., "content": ...} dicts.
    convo = [{"role": "system", "content": sys_prompt}]
    convo.extend({"role": h["role"], "content": h["content"]} for h in history)
    convo.append({"role": "user", "content": message})

    prompt_text = tokenizer.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the newly generated tokens, skipping the prompt echo.
    prompt_len = len(model_inputs.input_ids[0])
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# ChatInterface wires chat() (RAG retrieval + LoRA generation) to a chat UI.
demo = gr.ChatInterface(
    fn=chat,
    title="MiniScript Code Helper",
    description=(
        "Ask questions about the [MiniScript](https://miniscript.org) programming language. "
        "Powered by a fine-tuned Qwen2.5-Coder-7B-Instruct model with RAG over MiniScript documentation."
    ),
    examples=[
        "How do I define a function in MiniScript?",
        "How do I iterate over a list?",
        "What is the difference between `and` and `&&` in MiniScript?",
        "How do I read a file in MiniScript?",
    ],
    cache_examples=False,  # each example hits the LLM; don't pre-run them at startup
)

if __name__ == "__main__":
    demo.launch()