| import os |
| import dotenv |
| from time import time |
| import streamlit as st |
| import logging |
|
|
| |
# Redirect every Hugging Face cache location to /tmp — presumably because the
# deployment target's default cache dir is not writable (TODO confirm host).
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/.cache/huggingface"

# Pre-create the writable working directories used throughout this module:
# HF cache, persistent Chroma store, and the staging area for uploaded files.
os.makedirs("/tmp/.cache/huggingface", exist_ok=True)
os.makedirs("/tmp/chroma_persistent_db", exist_ok=True)
os.makedirs("/tmp/source_files", exist_ok=True)

# Module-level logger shared by all helpers in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| from langchain_community.document_loaders.text import TextLoader |
| from langchain_community.document_loaders import ( |
| WebBaseLoader, |
| PyPDFLoader, |
| Docx2txtLoader, |
| ) |
| from langchain_community.vectorstores import Chroma |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_huggingface import HuggingFaceEmbeddings |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langchain.chains import create_history_aware_retriever, create_retrieval_chain |
| from langchain.chains.combine_documents import create_stuff_documents_chain |
|
|
# Load environment variables (e.g. API keys) from a local .env file, if present.
dotenv.load_dotenv()

# NOTE(review): some langchain versions read USER_AGENT when WebBaseLoader is
# imported; this assignment happens after those imports above — verify it still
# takes effect for web requests.
os.environ["USER_AGENT"] = "myagent"
# Maximum number of RAG sources (uploaded files + URLs) allowed per session.
DB_DOCS_LIMIT = 10
|
|
def clean_temp_files(folders=("/tmp/source_files",)):
    """Delete all regular files inside the given temp folders.

    Generalized from the original hard-coded folder list: callers may pass
    their own iterable of directories; the default preserves the previous
    behavior. Subdirectories are left untouched. Errors are logged and
    skipped per entry, so one failing unlink no longer aborts the rest of
    the cleanup (the original bailed out of everything on the first error).

    Args:
        folders: iterable of directory paths to empty of regular files.
    """
    for folder in folders:
        try:
            entries = os.listdir(folder)
        except OSError as e:
            # Missing/unreadable folder: log and move on, as before.
            logger.warning(f"Error cleaning temp files: {e}")
            continue
        for filename in entries:
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except OSError as e:
                logger.warning(f"Error cleaning temp files: {e}")
|
|
def stream_llm_response(llm_stream, messages):
    """Yield chunks from the streaming LLM response as they arrive, then
    append the fully assembled assistant reply to the chat history."""
    collected = []
    for chunk in llm_stream.stream(messages):
        collected.append(chunk.content)
        yield chunk
    # Record the complete reply once streaming has finished.
    full_reply = "".join(collected)
    st.session_state.messages.append({"role": "assistant", "content": full_reply})
|
|
def load_doc_to_db():
    """Ingest user-uploaded documents (from ``st.session_state.rag_docs``)
    into the vector database.

    Each new upload is staged under /tmp/source_files, loaded with a
    type-appropriate LangChain loader (PDF / DOCX / plain text or markdown),
    split and added to the vector DB, then the staged copy is removed.
    Uploads already listed in ``st.session_state.rag_sources`` are skipped,
    and the total source count is capped at DB_DOCS_LIMIT.

    Fix: uploaded file names come from the client and were previously used
    verbatim to build the /tmp path; ``os.path.basename`` now strips any
    directory components to prevent path traversal outside the staging dir.
    """
    if "rag_docs" in st.session_state and st.session_state.rag_docs:
        docs = []
        for doc_file in st.session_state.rag_docs:
            if doc_file.name in st.session_state.rag_sources:
                continue  # already ingested this session
            if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
                st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
                continue

            # basename() guards against traversal in client-supplied names
            # (e.g. "../../etc/passwd").
            safe_name = os.path.basename(doc_file.name)
            file_path = os.path.join("/tmp/source_files", safe_name)
            try:
                with open(file_path, "wb") as file:
                    file.write(doc_file.getbuffer())

                if doc_file.type == "application/pdf":
                    loader = PyPDFLoader(file_path)
                elif doc_file.name.endswith(".docx"):
                    loader = Docx2txtLoader(file_path)
                elif doc_file.type in ["text/plain", "text/markdown"]:
                    loader = TextLoader(file_path)
                else:
                    st.warning(f"Unsupported document type: {doc_file.type}")
                    continue

                docs.extend(loader.load())
                st.session_state.rag_sources.append(doc_file.name)
                logger.info(f"Successfully loaded document: {doc_file.name}")
            except Exception as e:
                st.toast(f"Error loading document {doc_file.name}: {str(e)}", icon="⚠️")
                logger.error(f"Error loading document: {e}")
            finally:
                # Remove the staged copy whether or not loading succeeded.
                if os.path.exists(file_path):
                    os.remove(file_path)

        if docs:
            _split_and_load_docs(docs)
            st.toast("Documents loaded successfully.", icon="✅")
            clean_temp_files()
|
|
def load_url_to_db():
    """Ingest the URL in ``st.session_state.rag_url`` into the vector DB.

    Skips URLs already recorded in ``st.session_state.rag_sources`` and
    enforces the DB_DOCS_LIMIT cap on total sources.
    """
    if "rag_url" not in st.session_state or not st.session_state.rag_url:
        return
    url = st.session_state.rag_url
    if url in st.session_state.rag_sources:
        return  # already ingested
    if len(st.session_state.rag_sources) >= DB_DOCS_LIMIT:
        st.error(f"Max documents reached ({DB_DOCS_LIMIT}).")
        return

    docs = []
    try:
        docs.extend(WebBaseLoader(url).load())
        st.session_state.rag_sources.append(url)
        logger.info(f"Successfully loaded URL: {url}")
    except Exception as e:
        st.error(f"Error loading from URL {url}: {str(e)}")
        logger.error(f"Error loading URL: {e}")

    if docs:
        _split_and_load_docs(docs)
        st.toast(f"Loaded content from URL: {url}", icon="✅")
|
|
def initialize_vector_db(docs):
    """Build (or rebuild) the persistent Chroma vector store from *docs*.

    Embeds with BAAI/bge-large-en-v1.5 on CPU (embeddings not normalized),
    persists the collection under /tmp/chroma_persistent_db, and returns
    the Chroma instance.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
        cache_folder="/tmp/.cache"
    )

    vector_db = Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        persist_directory="/tmp/chroma_persistent_db",
        collection_name="persistent_collection",
    )
    # Flush the collection to disk so it survives app restarts.
    vector_db.persist()
    logger.info("Vector database initialized and persisted")
    return vector_db
|
|
def _split_and_load_docs(docs):
    """Chunk *docs* (1000 chars, 200 overlap) and add them to the session's
    vector DB, creating the DB on first use."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(docs)

    if "vector_db" in st.session_state:
        # DB already exists for this session: append and flush to disk.
        st.session_state.vector_db.add_documents(chunks)
        st.session_state.vector_db.persist()
        logger.info("Added new documents to existing vector database")
    else:
        st.session_state.vector_db = initialize_vector_db(chunks)
|
|
def _get_context_retriever_chain(vector_db, llm):
    """Wrap *vector_db*'s retriever in a history-aware retriever: the LLM
    first rewrites the conversation into a standalone search query."""
    query_rewrite_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="messages"),
        ("user", "{input}"),
        ("user", "Given the above conversation, generate a search query to find relevant information.")
    ])
    base_retriever = vector_db.as_retriever()
    return create_history_aware_retriever(llm, base_retriever, query_rewrite_prompt)
|
|
| def get_conversational_rag_chain(llm): |
| retriever_chain = _get_context_retriever_chain |