ShanenThomas committed on
Commit 4f729f0 ·
verified ·
1 Parent(s): caf69ac

Create app.py

Files changed (1)
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
import os
import faiss
import numpy as np
import gradio as gr
from typing import List, Tuple
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# ==============================
# Config
# ==============================
GEN_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")  # set in Space Secrets
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 900       # characters per chunk
CHUNK_OVERLAP = 150    # characters shared between consecutive chunks
TOP_K = 4              # chunks retrieved per question
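# Illustrative sizing note: all-MiniLM-L6-v2 emits 384-dim float32 vectors,
# so each indexed chunk costs 384 * 4 = 1536 bytes in the flat index
# (roughly 1.5 MB per 1,000 chunks), comfortably small for a CPU Space.
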
# ==============================
# Globals (lifetime of the Space)
# ==============================
emb = SentenceTransformer(EMB_MODEL_NAME)
index = None                 # FAISS index (inner product / cosine)
doc_chunks: List[str] = []   # text chunks
doc_meta: List[dict] = []    # {"file": "..."}
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)
# ==============================
# Helpers
# ==============================
def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
    chunks = []
    start = 0
    n = len(text)
    step = max(1, size - overlap)  # guard against overlap >= size looping forever
    while start < n:
        end = min(start + size, n)
        chunks.append(text[start:end])
        start += step
    return [c.strip() for c in chunks if c.strip()]
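# Worked example: with size=900 and overlap=150 the window advances by
# 900 - 150 = 750 characters, so a 2,000-character text yields slices
# [0:900], [750:1650], [1500:2000]; the overlap keeps a sentence that spans
# a boundary retrievable from at least one chunk.
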
def _embed(texts: List[str]) -> np.ndarray:
    # 384-d for MiniLM; normalize for cosine/IP search
    X = emb.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(X, dtype=np.float32)
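# Since the vectors are L2-normalized, the inner product computed by the
# FAISS index below equals cosine similarity: <a, b> = cos(theta) in [-1, 1].
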
def _ensure_index(dim: int):
    global index
    index = faiss.IndexFlatIP(dim)  # fresh index; cosine via normalized vectors
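# IndexFlatIP performs exact brute-force search, which is fine at this scale;
# for much larger corpora an approximate index such as faiss.IndexIVFFlat
# would be the usual substitute (a suggestion, not a requirement here).
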
def _extract_text_from_pdf(path: str) -> str:
    reader = PdfReader(path)
    pages = []
    for p in reader.pages:
        t = p.extract_text() or ""
        pages.append(t)
    return "\n".join(pages)
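# Note: pypdf only reads a PDF's embedded text layer, so scanned or
# image-only PDFs come back empty here and surface later as
# "No text extracted."
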
# ==============================
# Build index
# ==============================
def build_from_pdfs(files) -> str:
    global index, doc_chunks, doc_meta
    doc_chunks, doc_meta = [], []

    # 1) read PDFs → 2) chunk → collect
    for f in files:
        # Gradio may hand back tempfile-like objects or plain path strings
        path = f if isinstance(f, str) else f.name
        try:
            text = _extract_text_from_pdf(path)
        except Exception as e:
            return f"Failed to read {os.path.basename(path)}: {e}"
        chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
        for c in chunks:
            doc_chunks.append(c)
            doc_meta.append({"file": os.path.basename(path)})

    if not doc_chunks:
        return "No text extracted. Check your PDFs."

    # 3) embeddings → FAISS
    E = _embed(doc_chunks)
    _ensure_index(E.shape[1])
    index.add(E)

    return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."
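# Rebuilding replaces the previous corpus wholesale: doc_chunks/doc_meta are
# cleared and a fresh index is created, so chunks from an earlier upload can
# never leak into answers about the new corpus.
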
# ==============================
# Retrieval + Generation
# ==============================
def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
    qv = _embed([query])              # shape (1, d)
    sims, idxs = index.search(qv, k)  # inner-product similarity
    ids = idxs[0].tolist()
    # FAISS pads with -1 when fewer than k vectors are indexed; drop those
    ids = [i for i in ids if i >= 0]
    return ids, [doc_chunks[i] for i in ids]
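# Shape example: index.search(qv, 4) returns (sims, idxs), each of shape
# (1, 4); with only two chunks indexed, idxs could be [[1, 0, -1, -1]],
# and the filter above drops the -1 padding.
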
SYSTEM_PROMPT = (
    "You are a helpful assistant. Use the given CONTEXT to answer the QUESTION.\n"
    "If the answer is not in the context, say you don't know.\n"
    "Provide a concise answer and list source filenames as [source: file.pdf] at the end."
)

def _mistral_prompt(question: str, context: str) -> str:
    # Simple Mistral-instruct prompt format
    return (
        f"[INST] {SYSTEM_PROMPT}\n\n"
        f"QUESTION: {question}\n\n"
        f"CONTEXT:\n{context}\n"
        f"[/INST]"
    )
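# Rendered prompt, sketched with placeholder values:
#   [INST] <system prompt>
#
#   QUESTION: <user question>
#
#   CONTEXT:
#   [1] book.pdf
#   <chunk text>
#   [/INST]
# This hand-rolled [INST] template matches Mistral-Instruct's single-turn
# format; client.chat_completion with a messages list is a possible
# alternative, left as an aside.
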
def answer(question: str) -> str:
    if not question.strip():
        return "Ask a question."
    if index is None or not doc_chunks:
        return "Upload PDFs and click **Build Index** first."

    ids, ctx_chunks = _retrieve(question, TOP_K)

    contexts = []
    files = []
    for rank, (i, chunk) in enumerate(zip(ids, ctx_chunks), start=1):
        chunk = chunk[:1000]  # keep each context chunk reasonably short
        fname = doc_meta[i]["file"]
        contexts.append(f"[{rank}] {fname}\n{chunk}")
        files.append(fname)

    context_str = "\n\n---\n".join(contexts)
    prompt = _mistral_prompt(question, context_str)

    try:
        # Hosted Inference API; returns a single string
        out = client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.2,
            top_p=0.95,
            repetition_penalty=1.05,
            do_sample=True,
            return_full_text=False,
        )
        # Ensure sources are visible at the end
        unique_files = ", ".join(sorted(set(files)))
        return f"{out.strip()}\n\nSources: {unique_files}"
    except Exception as e:
        return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model name.)"
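# Hypothetical console smoke test (outside the UI): build_from_pdfs accepts
# plain path strings or objects with a .name attribute, so for example:
#   build_from_pdfs(["sample.pdf"])
#   print(answer("What is the book about?"))
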
# ==============================
# UI
# ==============================
with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
    gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")

    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDF books")
            build_btn = gr.Button("Build Index", variant="primary")
            status = gr.Markdown()
        with gr.Column(scale=2):
            q = gr.Textbox(label="Ask a question", placeholder="What does the book say about ...?")
            ask_btn = gr.Button("Ask ➜")
            a = gr.Markdown()

    build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
    ask_btn.click(answer, inputs=[q], outputs=[a])
    q.submit(answer, inputs=[q], outputs=[a])  # hit Enter to ask

if __name__ == "__main__":
    demo.launch()
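# Assumed local setup, if run outside a Space:
#   pip install gradio faiss-cpu numpy pypdf sentence-transformers huggingface_hub
#   export HUGGINGFACEHUB_API_TOKEN=hf_...   # gated model access
#   python app.py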