Spaces:
Sleeping
Sleeping
import os
import uuid

import chromadb
import numpy as np
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import pipeline, AutoModel, AutoTokenizer
# Load the pretrained model and tokenizer used to embed text
def load_embedding_model():
    """Load the model and tokenizer used to compute text embeddings.

    Both objects are loaded from the same pretrained checkpoint on every
    call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): the checkpoint name suggests a cross-encoder rather than
    # a sentence-embedding model -- confirm this is the intended choice.
    checkpoint = "cross-encoder/qnli-electra-base"
    return (
        AutoModel.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
# Embed a piece of text with the given model and tokenizer
def generate_embedding(model, tokenizer, text):
    """Return a mean-pooled embedding of ``text`` as a NumPy array.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must
            expose a ``last_hidden_state`` tensor.
        tokenizer: Callable that converts ``text`` into the keyword inputs
            expected by ``model`` (called with ``return_tensors="pt"``).
        text: Text to embed; tokenization truncates at 512 tokens.

    Returns:
        The last hidden state averaged over dim 1 (presumably the token
        axis), with size-1 dimensions squeezed out, as a NumPy array.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Inference only: skip gradient tracking during the forward pass.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    pooled = torch.mean(hidden_states, dim=1)
    return torch.squeeze(pooled).numpy()
# Build the Hugging Face question-answering pipeline
def load_qa_pipeline():
    """Create the extractive question-answering pipeline.

    Returns:
        A ``transformers`` "question-answering" pipeline backed by the
        ``deepset/roberta-base-squad2`` model.
    """
    qa_model_name = "deepset/roberta-base-squad2"
    return pipeline("question-answering", model=qa_model_name)
# Extract text from PDF
def extract_pdf_text(pdf_file):
    """Return the text of every page in a PDF, each page followed by a newline.

    Args:
        pdf_file: The PDF source, passed unchanged to ``PdfReader``.

    Returns:
        A single string containing each page's extracted text terminated by
        ``"\\n"``, in page order. A PDF with no pages yields ``""``.
    """
    reader = PdfReader(pdf_file)
    # Defensive: treat a falsy extraction result (e.g. None for a page
    # without a text layer) as an empty page instead of failing on
    # ``None + str``. Joining once also avoids repeated string concatenation.
    return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)
# Split text into chunks
def split_text_into_chunks(text, chunk_size=500, overlap=100):
    """Split ``text`` into fixed-size, overlapping character chunks.

    Chunk starts advance by ``chunk_size - overlap`` characters, so each
    chunk shares its last ``overlap`` characters with the start of the next
    one. The final chunks may be shorter than ``chunk_size``.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of chunk strings in order; empty when ``text`` is empty.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than
            ``chunk_size``.
    """
    # A non-positive step would either raise a cryptic range() error or
    # silently produce no chunks at all; a negative overlap would skip text.
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < chunk_size "
            f"(got chunk_size={chunk_size}, overlap={overlap})"
        )

    step = chunk_size - overlap
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
# Create ChromaDB collection with embeddings
def create_chroma_collection(chunks, model, tokenizer):
    """Create a ChromaDB collection holding one embedded entry per chunk.

    Args:
        chunks: Iterable of text chunks to store.
        model: Embedding model, forwarded to ``generate_embedding``.
        tokenizer: Tokenizer, forwarded to ``generate_embedding``.

    Returns:
        A ``(client, collection, collection_name)`` tuple. The caller is
        responsible for deleting the collection when finished with it.

    Raises:
        Exception: Any error raised while embedding or adding a chunk is
            re-raised after the partially filled collection is deleted.
    """
    # Persistent (on-disk) client, stored under ./chroma_db.
    client = chromadb.PersistentClient(path="./chroma_db")

    # The store persists between runs, so the name must not collide with a
    # leftover collection; a UUID makes that practically impossible.
    collection_name = f"pdf_qa_collection_{uuid.uuid4().hex}"
    collection = client.create_collection(name=collection_name)

    try:
        for i, chunk in enumerate(chunks):
            embedding = generate_embedding(model, tokenizer, chunk)
            collection.add(
                ids=[f"chunk_{i}"],
                documents=[chunk],
                embeddings=[embedding.tolist()],
            )
    except Exception:
        # Do not leave a half-populated collection behind in the on-disk DB.
        client.delete_collection(name=collection_name)
        raise

    return client, collection, collection_name
# Look up the stored chunks closest to the question
def retrieve_context(collection, question, model, tokenizer, top_k=3):
    """Return the documents in ``collection`` nearest to ``question``.

    Args:
        collection: Object exposing ``query(query_embeddings=..., n_results=...)``.
        question: The question text to embed and search with.
        model: Embedding model, forwarded to ``generate_embedding``.
        tokenizer: Tokenizer, forwarded to ``generate_embedding``.
        top_k: Number of results requested from the collection.

    Returns:
        The first entry of the query result's ``'documents'`` field, i.e.
        the documents matched for the single query embedding.
    """
    query_vector = generate_embedding(model, tokenizer, question).tolist()
    hits = collection.query(query_embeddings=[query_vector], n_results=top_k)
    # Only one query embedding was sent, so take the first result list.
    return hits['documents'][0]
# Main Streamlit app
def main():
    """Render the PDF question-answering page and answer on demand.

    The page shows a PDF uploader, a question box and a "Get Answer" button.
    When the button is pressed with both inputs present, the PDF text is
    extracted and chunked, the chunks are embedded into a temporary ChromaDB
    collection, the closest chunks are retrieved for the question, and the
    QA pipeline is run on each of them. The highest-scoring answer and its
    score are displayed.

    The temporary collection is always deleted once retrieval and answering
    finish, even if they fail. Any exception raised while answering is shown
    with ``st.error`` instead of propagating.
    """
    st.title("PDF Question Answering App")

    embedding_model, tokenizer = load_embedding_model()

    uploaded_file = st.file_uploader("Upload PDF", type=['pdf'])
    question = st.text_input("Enter your question")

    if not st.button("Get Answer"):
        return

    # Tell the user what is missing instead of silently doing nothing.
    if not (uploaded_file and question):
        st.warning("Please upload a PDF and enter a question.")
        return

    try:
        qa_pipeline = load_qa_pipeline()

        pdf_text = extract_pdf_text(uploaded_file)
        # Without any text there is nothing to index or answer from.
        if not pdf_text.strip():
            st.warning("No extractable text was found in the PDF.")
            return

        text_chunks = split_text_into_chunks(pdf_text)
        client, collection, collection_name = create_chroma_collection(
            text_chunks, embedding_model, tokenizer
        )
        try:
            contexts = retrieve_context(
                collection, question, embedding_model, tokenizer
            )
            answers = [
                qa_pipeline(question=question, context=context)
                for context in contexts
            ]
        finally:
            # The collection lives in a persistent store, so remove it even
            # when retrieval or answering raised.
            client.delete_collection(name=collection_name)

        # max() on an empty list would raise a cryptic error.
        if not answers:
            st.warning("No relevant passages were found for this question.")
            return

        best_answer = max(answers, key=lambda x: x['score'])
        st.write("Answer:", best_answer['answer'])
        st.write("Confidence Score:", best_answer['score'])
    except Exception as e:
        # Top-level UI boundary: report the failure on the page.
        st.error(f"An error occurred: {e}")
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()