# QuickQuery / app.py
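"""
Minimal retrieval-augmented QA over a single PDF: the file is split into
overlapping character chunks, each chunk is embedded and stored in a ChromaDB
collection, the chunks closest to the question are retrieved, and an
extractive QA model picks the best answer span from them.
"""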
import uuid

import chromadb
import streamlit as st
import torch
from PyPDF2 import PdfReader
from transformers import AutoModel, AutoTokenizer, pipeline
# Load a sentence-transformer model for embeddings. The original checkpoint,
# cross-encoder/qnli-electra-base, is a cross-encoder trained for sentence-pair
# scoring rather than standalone embeddings; a bi-encoder such as
# sentence-transformers/all-MiniLM-L6-v2 is the usual choice.
@st.cache_resource
def load_embedding_model():
    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    return model, tokenizer
# Generate embeddings for text
def generate_embedding(model, tokenizer, text):
# Tokenize the text
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
# Generate embeddings
with torch.no_grad():
outputs = model(**inputs)
    # Mean-pool the last hidden state across the sequence to get one vector
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
return embeddings
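# Note: because each text is tokenized on its own (no padding), mean pooling
# over every position is safe here; with batched, padded inputs you would mask
# out padding via the attention mask before averaging.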
# Initialize the Hugging Face extractive question-answering pipeline,
# cached so the model is only downloaded and loaded once per session
@st.cache_resource
def load_qa_pipeline():
    return pipeline("question-answering", model="deepset/roberta-base-squad2")
# Extract text from PDF
def extract_pdf_text(pdf_file):
reader = PdfReader(pdf_file)
text = ""
    for page in reader.pages:
        # extract_text() can return None for pages without a text layer
        text += (page.extract_text() or "") + "\n"
return text
# Split text into chunks
def split_text_into_chunks(text, chunk_size=500, overlap=100):
chunks = []
for i in range(0, len(text), chunk_size - overlap):
chunks.append(text[i:i+chunk_size])
return chunks
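# With the defaults, the step is 400 characters, so consecutive chunks share
# 100 characters of overlap: chunk 0 covers [0, 500), chunk 1 covers
# [400, 900), chunk 2 covers [800, 1300), and so on.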
# Create ChromaDB collection with embeddings
def create_chroma_collection(chunks, model, tokenizer):
# Use persistent client to avoid memory issues
client = chromadb.PersistentClient(path="./chroma_db")
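    # PersistentClient stores its data on disk under ./chroma_db; main()
    # deletes the collection again after a successful run, so this directory
    # only holds transient per-question data.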
    # Create a unique collection name (a uuid avoids collisions between runs)
    collection_name = f"pdf_qa_collection_{uuid.uuid4().hex}"
# Create collection
collection = client.create_collection(name=collection_name)
    # Embed every chunk, then add them to the collection in one batched call
    chunk_embeddings = [
        generate_embedding(model, tokenizer, chunk).tolist() for chunk in chunks
    ]
    collection.add(
        ids=[f"chunk_{i}" for i in range(len(chunks))],
        documents=chunks,
        embeddings=chunk_embeddings,
    )
return client, collection, collection_name
# Retrieve most relevant context
def retrieve_context(collection, question, model, tokenizer, top_k=3):
# Generate embedding for the question
question_embedding = generate_embedding(model, tokenizer, question)
    # Query the collection, clamping n_results so it never exceeds the number
    # of stored chunks
    results = collection.query(
        query_embeddings=[question_embedding.tolist()],
        n_results=min(top_k, collection.count()),
    )
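    # query() returns parallel lists per query embedding, so
    # results['documents'][0] is the list of top_k chunk texts for our single
    # question, ordered from most to least similar.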
return results['documents'][0]
# Main Streamlit app
def main():
st.title("PDF Question Answering App")
# Load embedding model
embedding_model, tokenizer = load_embedding_model()
# File uploader
uploaded_file = st.file_uploader("Upload PDF", type=['pdf'])
# Question input
question = st.text_input("Enter your question")
# Run button
if st.button("Get Answer"):
if uploaded_file and question:
try:
# Load QA pipeline
qa_pipeline = load_qa_pipeline()
# Extract PDF text
pdf_text = extract_pdf_text(uploaded_file)
# Split text into chunks
text_chunks = split_text_into_chunks(pdf_text)
# Create ChromaDB collection with embeddings
client, collection, collection_name = create_chroma_collection(
text_chunks, embedding_model, tokenizer
)
# Retrieve context
contexts = retrieve_context(
collection, question, embedding_model, tokenizer
)
# Prepare answers
answers = []
for context in contexts:
result = qa_pipeline(question=question, context=context)
answers.append(result)
# Display best answer
best_answer = max(answers, key=lambda x: x['score'])
st.write("Answer:", best_answer['answer'])
st.write("Confidence Score:", best_answer['score'])
# Clean up ChromaDB collection
client.delete_collection(name=collection_name)
except Exception as e:
st.error(f"An error occurred: {e}")
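        else:
            # Give feedback when inputs are missing instead of failing silently
            st.warning("Please upload a PDF and enter a question.")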
if __name__ == "__main__":
main()