File size: 8,148 Bytes
9a1939c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b59c36e
 
 
9a1939c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
#!/usr/bin/env python3
"""Gradio Space: MiniScript Code Helper (LoRA + RAG).

Loads the fine-tuned Qwen2.5-Coder-7B-Instruct LoRA adapter and a ChromaDB
vector index built from MiniScript documentation, then serves a chat interface.
"""

import os
import re

os.environ.setdefault("USE_TF", "0")

import chromadb
import gradio as gr
import torch
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"  # HF repo of the quantized base model
ADAPTER_REPO = "JoeStrout/miniscript-code-helper-lora"  # LoRA adapter (also hosts the tokenizer)
RAG_DIR = "./RAG_sources"  # folder of .md/.txt documentation to index
DB_DIR = "./chroma_db"  # on-disk path for the persistent ChromaDB client
COLLECTION = "miniscript_docs"  # ChromaDB collection name
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model used for embeddings
TOP_K = 5  # number of RAG chunks retrieved per user query
MAX_NEW_TOKENS = 1024  # generation budget per chat response
MAX_CHUNK_CHARS = 1500  # soft cap on indexed chunk size (see split_long_chunk)

BASE_SYSTEM_PROMPT = "You are a helpful assistant specializing in MiniScript programming."

# ---------------------------------------------------------------------------
# RAG index builder (inline so app is self-contained)
# ---------------------------------------------------------------------------

def strip_leanpub(text: str) -> str:
    """Strip Leanpub-specific markup from *text*, keeping readable prose.

    Directive lines (``{chapterHead...}``, ``{pagebreak}``, etc.) are dropped,
    except that a ``{caption: "..."}`` directive is replaced by the caption in
    square brackets. Markdown image lines are removed, and ``Q>``/``A>``/
    ``D>``/``X>`` block-quote prefixes are peeled off their lines.
    """
    kept = []
    for raw in text.splitlines():
        # Leanpub directive line: keep only an embedded caption, if any.
        if re.match(r'^\s*\{(chapterHead|width|i:|caption|pagebreak|startingPageNum)', raw):
            cap = re.search(r'\{caption:\s*"([^"]+)"\}', raw)
            if cap:
                kept.append(f"[{cap.group(1)}]")
            continue
        # Drop standalone image references entirely.
        if re.match(r'^\s*!\[.*\]\(.*\)\s*$', raw):
            continue
        # Remove Leanpub block-quote markers but keep the line's content.
        kept.append(re.sub(r'^([QADX])>\s?', '', raw))
    return '\n'.join(kept)


def split_long_chunk(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list:
    """Split *text* into pieces of at most roughly *max_chars* characters.

    Splitting happens only at blank-line (paragraph) boundaries, so a single
    paragraph longer than *max_chars* is returned intact rather than cut.
    Text already within the limit is returned as a one-element list.
    """
    if len(text) <= max_chars:
        return [text]

    pieces = []
    pending = ""
    for para in re.split(r'\n\n+', text):
        # Start a new piece once adding this paragraph would bust the cap.
        if pending and len(pending) + len(para) + 2 > max_chars:
            pieces.append(pending.strip())
            pending = para
        elif pending:
            pending = pending + "\n\n" + para
        else:
            pending = para
    if pending.strip():
        pieces.append(pending.strip())
    return pieces


def chunk_document(text: str, filename: str) -> list:
    """Chunk a documentation file into section-tagged dicts for indexing.

    ``.txt`` files are assumed to be Leanpub manuscripts and are cleaned via
    :func:`strip_leanpub` first. The text is then split at markdown headings
    (1–4 ``#``), with each section further limited by :func:`split_long_chunk`.

    Returns:
        List of ``{"text", "source", "section"}`` dicts; ``section`` starts as
        the filename until the first heading is seen.
    """
    is_txt = filename.endswith('.txt')
    if is_txt:
        text = strip_leanpub(text)

    chunks = []
    section = filename
    buffer = []

    def emit():
        # Flush buffered lines as one or more size-limited chunks.
        body = '\n'.join(buffer).strip()
        if not body:
            return
        for piece in split_long_chunk(body):
            if piece.strip():
                chunks.append({"text": piece, "source": filename, "section": section})

    for line in text.splitlines():
        found = None
        if is_txt:
            m = re.match(r'^(#{1,4})\s+(.*)', line)
            if m:
                found = m.group(2).strip()
        elif re.match(r'^#{1,4}\s', line):
            found = re.sub(r'^#+\s*', '', line).strip()
        # An empty heading text is treated as ordinary content (falsy check),
        # matching how bare '#' lines should stay in the body.
        if found:
            emit()
            section = found
            buffer = []
        else:
            buffer.append(line)
    emit()
    return chunks


def build_rag_index():
    """Create — or reuse — the persistent ChromaDB collection of doc chunks.

    If a collection named ``COLLECTION`` already exists under ``DB_DIR`` it is
    returned as-is (no re-indexing). Otherwise every ``.md``/``.txt`` file in
    ``RAG_DIR`` is chunked via :func:`chunk_document` and added in batches.

    Returns:
        The ready-to-query ChromaDB collection.
    """
    print(f"Building ChromaDB index from {RAG_DIR}/ ...")
    embed = SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL)
    client = chromadb.PersistentClient(path=DB_DIR)

    # Fast path: a previous run already built the index on this disk.
    if any(c.name == COLLECTION for c in client.list_collections()):
        col = client.get_collection(name=COLLECTION, embedding_function=embed)
        print(f"  Reusing existing collection ({col.count()} chunks)")
        return col

    col = client.create_collection(
        name=COLLECTION,
        embedding_function=embed,
        metadata={"hnsw:space": "cosine"},
    )

    all_chunks = []
    for fname in sorted(f for f in os.listdir(RAG_DIR) if f.endswith(('.md', '.txt'))):
        with open(os.path.join(RAG_DIR, fname), encoding='utf-8') as fh:
            doc_chunks = chunk_document(fh.read(), fname)
        print(f"  {fname}: {len(doc_chunks)} chunks")
        all_chunks.extend(doc_chunks)

    # Add in batches of 100 to keep each embedding call bounded.
    batch_size = 100
    for start in range(0, len(all_chunks), batch_size):
        batch = all_chunks[start:start + batch_size]
        col.add(
            ids=[f"chunk_{start + j}" for j in range(len(batch))],
            documents=[c["text"] for c in batch],
            metadatas=[{"source": c["source"], "section": c["section"]} for c in batch],
        )
    print(f"  Indexed {col.count()} chunks total.")
    return col


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_model():
    """Load the tokenizer, the 4-bit quantized base model, and the LoRA adapter.

    The tokenizer comes from the adapter repo (it may carry adapter-specific
    special tokens). The base model is quantized to NF4 so the 7B weights fit
    on a modest GPU, then the LoRA adapter is applied on top.

    Returns:
        ``(tokenizer, model)`` ready for inference (model is in eval mode).
    """
    print(f"Loading tokenizer from {ADAPTER_REPO} ...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print(f"Loading base model {BASE_MODEL} in 4-bit ...")
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=quant_cfg,
        device_map="auto",
    )

    print(f"Loading LoRA adapter from {ADAPTER_REPO} ...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
    model.eval()  # inference only — disable dropout etc.
    print("Model ready!")
    return tokenizer, model


# ---------------------------------------------------------------------------
# Startup
# ---------------------------------------------------------------------------

# Build/load the RAG index and the model once at import time, so everything is
# ready before Gradio serves the first request (Spaces run this at startup).
collection = build_rag_index()
tokenizer, model = load_model()


# ---------------------------------------------------------------------------
# Chat logic
# ---------------------------------------------------------------------------

def build_system_prompt(results: dict) -> str:
    """Compose the system prompt, prepending retrieved RAG context when present.

    Args:
        results: Chroma-style query result whose ``"documents"`` and
            ``"metadatas"`` are lists-of-lists (index 0 = the single query).

    Returns:
        ``BASE_SYSTEM_PROMPT`` alone when nothing was retrieved; otherwise the
        prompt followed by each chunk labeled with its source and section.
    """
    if not results or not results["documents"] or not results["documents"][0]:
        return BASE_SYSTEM_PROMPT

    labeled = [
        f"[Source: {meta['source']}, Section: {meta['section']}]\n{doc}"
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    ]
    context = "\n\n".join(labeled)
    return (
        f"{BASE_SYSTEM_PROMPT}\n\n"
        f"Use the following reference material to help answer the user's question:\n\n"
        f"{context}"
    )


def chat(message: str, history: list) -> str:
    """Answer one chat turn: retrieve RAG context, then generate a reply.

    Args:
        message: The user's latest message.
        history: Gradio 5 history — a list of ``{"role", "content"}`` dicts.

    Returns:
        The model's decoded response (special tokens stripped).
    """
    hits = collection.query(query_texts=[message], n_results=TOP_K)
    system_prompt = build_system_prompt(hits)

    convo = [{"role": "system", "content": system_prompt}]
    convo.extend({"role": turn["role"], "content": turn["content"]} for turn in history)
    convo.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Decode only the newly generated tail, skipping the echoed prompt tokens.
    prompt_len = len(inputs.input_ids[0])
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# Gradio 5 ChatInterface: passes (message, history) to `chat`, where history
# is a list of {"role", "content"} dicts.
demo = gr.ChatInterface(
    fn=chat,
    title="MiniScript Code Helper",
    description=(
        "Ask questions about the [MiniScript](https://miniscript.org) programming language. "
        "Powered by a fine-tuned Qwen2.5-Coder-7B-Instruct model with RAG over MiniScript documentation."
    ),
    examples=[
        "How do I define a function in MiniScript?",
        "How do I iterate over a list?",
        "What is the difference between `and` and `&&` in MiniScript?",
        "How do I read a file in MiniScript?",
    ],
    # Example answers would require a full model generation each; don't cache.
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()