| from operator import itemgetter |
| import chainlit as cl |
| from langchain.schema.runnable import RunnablePassthrough |
| from langchain.vectorstores import FAISS |
| from langchain.chains import RetrievalQA |
| from langchain.chat_models import ChatOpenAI |
| from langchain.prompts.chat import ( |
| ChatPromptTemplate, |
| SystemMessagePromptTemplate, |
| HumanMessagePromptTemplate, |
| ) |
|
|
| from utils import ArxivLoader, PineconeIndexer |
|
|
# System prompt for the RAG chain; "{context}" is filled with the documents
# returned by the retriever, and the model is told to answer only from them.
system_template = """
Use the provided context to answer the user's query.

You may not answer the user's query unless there is specific context in the following text.

If you do not know the answer, or cannot answer, please respond with "I don't know".

Context:
{context}
"""

# Two-message chat prompt: the system message carries the retrieved context,
# the human message carries the user's raw question. Together they expect the
# input variables "context" and "question".
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate(messages=messages)

# NOTE(review): chain_type_kwargs is not referenced elsewhere in this file; it
# looks like a leftover from a RetrievalQA-based setup — confirm before removing.
chain_type_kwargs = dict(prompt=prompt)
|
|
@cl.author_rename
def rename(orig_author: str):
    """Return the display name the chat UI should show for ``orig_author``.

    Only the author "RetrievalQA" is renamed; any other author name is passed
    through unchanged.
    """
    # NOTE(review): main() builds an LCEL chain rather than a RetrievalQA chain,
    # so this author may never appear — confirm which author names Chainlit reports.
    if orig_author == "RetrievalQA":
        return "Learning about Nuclear Fission"
    return orig_author
|
|
def _build_retriever():
    """Run the arXiv loader and return a retriever over the Pinecone vector store.

    Blocking helper: every call in here is synchronous, so ``start_chat`` runs
    it in a worker thread instead of on the event loop.
    """
    # NOTE(review): ArxivLoader.main() presumably fetches the papers and
    # populates the index — confirm in utils.
    axloader = ArxivLoader()
    axloader.main()

    pi = PineconeIndexer()
    pi.load_embedder()
    retriever = pi.get_vectorstore().as_retriever()
    # Debug output: print the index statistics to stdout.
    print(pi.index.describe_index_stats())
    return retriever


@cl.on_chat_start
async def start_chat():
    """Prepare the per-session LLM and retriever when a chat session starts.

    Sends a progress message, builds the retriever (via ``_build_retriever``)
    and a ``gpt-3.5-turbo`` chat model with temperature 0, stores both in the
    user session under ``"llm"`` and ``"retriever"`` for ``main``, and then
    tells the user the application is ready.
    """
    msg = cl.Message(content="Initializing the Application...")
    await msg.send()

    # The loader/indexer calls are synchronous. Calling them directly in this
    # coroutine would block Chainlit's event loop, and with it every other
    # session, until they finish, so run them in a worker thread instead.
    retriever = await cl.make_async(_build_retriever)()

    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
    )

    # Store the resources before announcing readiness so a message sent right
    # after "ready" always finds them in the session.
    cl.user_session.set("llm", llm)
    cl.user_session.set("retriever", retriever)

    msg = cl.Message(content="Application is ready !")
    await msg.send()
|
|
@cl.on_message
async def main(message: cl.Message):
    """Answer a chat message using retrieval-augmented generation.

    Uses the ``"llm"`` and ``"retriever"`` stored in the user session by
    ``start_chat``. The retriever fetches documents for the message text. The
    module-level ``prompt`` is filled with them as ``{context}`` and with the
    text as ``{question}``. The model's reply is then sent back to the chat.
    """
    llm = cl.user_session.get("llm")
    retriever = cl.user_session.get("retriever")

    if llm is None or retriever is None:
        # Session setup has not finished (or failed). Composing the chain with
        # None would raise a TypeError, so reply instead of crashing.
        await cl.Message(
            content="The application is still initializing, please try again in a moment."
        ).send()
        return

    # {"question": q} -> {"question": q, "context": retrieved docs}
    #                 -> {"response": model reply, "context": retrieved docs}
    # RunnablePassthrough.assign keeps the incoming "question" key and adds the
    # retrieved documents as "context"; the trailing dict is coerced into a
    # parallel step by the `|` composition.
    retrieval_augmented_qa_chain = RunnablePassthrough.assign(
        context=itemgetter("question") | retriever,
    ) | {
        "response": prompt | llm,
        "context": itemgetter("context"),
    }

    # Await the async entry point instead of calling the synchronous invoke(),
    # which would block the event loop during retrieval and the LLM call.
    answer = await retrieval_augmented_qa_chain.ainvoke({"question": message.content})

    await cl.Message(content=answer["response"].content).send()
|
|
|
|
|
|
|
|
|
|
|
|