import os
import faiss
import numpy as np
import gradio as gr
from typing import List, Tuple
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient

# ==============================
# Config
# ==============================
GEN_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")  # set in Space Secrets
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
CHUNK_SIZE = 900
CHUNK_OVERLAP = 150
TOP_K = 4

# ==============================
# Globals (lifetime of the Space)
# ==============================
emb = SentenceTransformer(EMB_MODEL_NAME)
index = None                 # FAISS index (inner product / cosine)
doc_chunks: List[str] = []   # text chunks
doc_meta: List[dict] = []    # {"file": "..."}
client = InferenceClient(model=GEN_MODEL, token=HF_TOKEN)

# ==============================
# Helpers
# ==============================
def _chunk_text(text: str, size: int, overlap: int) -> List[str]:
    chunks = []
    start = 0
    n = len(text)
    step = size - overlap
    while start < n:
        end = min(start + size, n)
        chunks.append(text[start:end])
        start += step
    return [c.strip() for c in chunks if c.strip()]


def _embed(texts: List[str]) -> np.ndarray:
    # 384-d for MiniLM; normalize for cosine/IP search
    X = emb.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
    return np.asarray(X, dtype=np.float32)


def _ensure_index(dim: int):
    global index
    index = faiss.IndexFlatIP(dim)  # cosine via normalized vectors


def _extract_text_from_pdf(path: str) -> str:
    reader = PdfReader(path)
    pages = []
    for p in reader.pages:
        t = p.extract_text() or ""
        pages.append(t)
    return "\n".join(pages)

# ==============================
# Build index
# ==============================
def build_from_pdfs(files) -> str:
    global index, doc_chunks, doc_meta
    doc_chunks, doc_meta = [], []

    # 1) read PDFs → 2) chunk → collect
    for f in files:
        # Gradio may pass plain file paths or tempfile wrappers, depending on version
        path = f if isinstance(f, str) else f.name
        try:
            text = _extract_text_from_pdf(path)
        except Exception as e:
            return f"Failed to read {os.path.basename(path)}: {e}"
        chunks = _chunk_text(text, CHUNK_SIZE, CHUNK_OVERLAP)
        for c in chunks:
            doc_chunks.append(c)
            doc_meta.append({"file": os.path.basename(path)})

    if not doc_chunks:
        return "No text extracted. Check your PDFs."

    # 3) embeddings → FAISS
    E = _embed(doc_chunks)
    _ensure_index(E.shape[1])
    index.add(E)
    return f"Indexed {len(doc_chunks)} chunks from {len(files)} file(s)."

# ==============================
# Retrieval + Generation
# ==============================
def _retrieve(query: str, k: int = TOP_K) -> Tuple[List[int], List[str]]:
    qv = _embed([query])              # shape (1, d)
    sims, idxs = index.search(qv, k)  # inner-product similarity
    ids = idxs[0].tolist()
    # Filter out -1 placeholders (FAISS pads results when fewer than k vectors are indexed)
    ids = [i for i in ids if i >= 0]
    return ids, [doc_chunks[i] for i in ids]


SYSTEM_PROMPT = (
    "You are a helpful assistant. Use the given CONTEXT to answer the QUESTION.\n"
    "If the answer is not in the context, say you don't know.\n"
    "Provide a concise answer and list source filenames as [source: file.pdf] at the end."
)


def _mistral_prompt(question: str, context: str) -> str:
    # Simple Mistral-instruct prompt format
    return (
        f"[INST] {SYSTEM_PROMPT}\n\n"
        f"QUESTION: {question}\n\n"
        f"CONTEXT:\n{context}\n"
        f"[/INST]"
    )


def answer(question: str) -> str:
    if not question.strip():
        return "Ask a question."
    if index is None or not doc_chunks:
        return "Upload PDFs and click **Build Index** first."
    ids, ctx_chunks = _retrieve(question, TOP_K)

    # Keep contexts reasonably short per chunk
    previews = []
    contexts = []
    files = []
    for rank, i in enumerate(ids, start=1):
        chunk = doc_chunks[i][:1000]
        fname = doc_meta[i]["file"]
        contexts.append(f"[{rank}] {fname}\n{chunk}")
        previews.append(f"[{rank}] {fname}")
        files.append(fname)

    context_str = "\n\n---\n".join(contexts)
    prompt = _mistral_prompt(question, context_str)

    try:
        # Use hosted Inference API; returns a single string
        out = client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.2,
            top_p=0.95,
            repetition_penalty=1.05,
            do_sample=True,
            return_full_text=False,
        )
        # Ensure sources are visible at the end
        unique_files = ", ".join(sorted(set(files)))
        return f"{out.strip()}\n\nSources: {unique_files}"
    except Exception as e:
        return f"Generation error: {e}\n(Verify your HUGGINGFACEHUB_API_TOKEN and model name.)"

# ==============================
# UI
# ==============================
with gr.Blocks(title="Mistral 7B PDF-RAG") as demo:
    gr.Markdown("# 📚 PDF-RAG (Mistral-7B-Instruct)\nUpload PDFs → Build Index → Ask questions. Answers cite sources.")
    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDF books")
            build_btn = gr.Button("Build Index", variant="primary")
            status = gr.Markdown()
        with gr.Column(scale=2):
            q = gr.Textbox(label="Ask a question", placeholder="What does the book say about ...?")
            ask_btn = gr.Button("Ask ➜")
            a = gr.Markdown()

    build_btn.click(build_from_pdfs, inputs=[files], outputs=[status])
    ask_btn.click(answer, inputs=[q], outputs=[a])
    q.submit(answer, inputs=[q], outputs=[a])  # hit Enter to ask

if __name__ == "__main__":
    demo.launch()
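
# ------------------------------------------------------------------
# A likely requirements.txt for this Space, inferred from the imports
# above. Package names are the standard PyPI names; versions are left
# unpinned here and may need adjusting for your runtime (assumption,
# not part of the original code):
#
#   gradio
#   faiss-cpu
#   numpy
#   pypdf
#   sentence-transformers
#   huggingface_hub
# ------------------------------------------------------------------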