import logging

import pandas as pd

from tools import MenuTool, CartTool, OrderTool, greetings_function
from data_section import data_bifercation
from tools.prompts import tool_prompt_function
from config import settings
from utils import client
from context import ollama_context_query, summarised_output
class ReactAgent:
    """Single-step tool-calling agent over one store's menu.

    On construction the store's menu is loaded into a DataFrame that is shared
    by the menu, cart and order tools. Each call to ``handle_query`` asks the
    LLM which tools to call, runs them, and passes their outputs to
    ``summarised_output`` to build the final reply.
    """

    # Identifiers used when the caller does not supply its own store/brand.
    DEFAULT_STORE_ID = "66dff7a04b17303d454d4bbc"
    DEFAULT_BRAND_ID = "66cec85093c5b0896c9125c5"

    # Column names applied, in order, to the rows returned by data_bifercation.
    # NOTE(review): assumes each row is (category, item, price) in that order —
    # confirm against data_bifercation.
    MENU_COLUMNS = ("category", "item", "price")

    # Labels the context step may return instead of a rewritten query.
    _TOOL_LABELS = ("MenuTool", "CartTool", "OrderTool")

    # Tool declarations sent with every chat request (built once, not per call).
    _TOOL_SCHEMAS = [
        {
            "type": "function",
            "function": {
                "name": "menu_tool",
                "description": "Fetch the restaurant menu based on user input",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's natural language query for the menu",
                        }
                    },
                    "required": ["query"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "cart_tool",
                "description": "Manage the cart based on user input (add/remove/view)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's cart query to add/remove/view items",
                        },
                        "session_id": {
                            "type": "string",
                            "description": "current session id",
                        },
                    },
                    "required": ["query", "session_id"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "order_tool",
                "description": "Handle order and checkout functionality",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "User's request to place an order",
                        },
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    _log = logging.getLogger(__name__)

    def __init__(self, store_id=DEFAULT_STORE_ID, brand_id=DEFAULT_BRAND_ID):
        """Load the menu for the given store/brand and build the tools.

        Args:
            store_id: Store identifier passed to ``data_bifercation``.
            brand_id: Brand identifier passed to ``data_bifercation``.
        """
        self.store_id = store_id
        self.brand_id = brand_id
        main_data, category, items = data_bifercation(self.store_id, self.brand_id)
        self.items = items
        self.category = category
        df = pd.DataFrame(main_data)
        df.columns = list(self.MENU_COLUMNS)
        # Normalise names to lower case before the tools receive the frame.
        df["item"] = df["item"].str.lower()
        df["category"] = df["category"].str.lower()
        self.df = df
        self.menu_tool = MenuTool(df)
        self.cart_tool = CartTool(df)
        self.order_tool = OrderTool(df)
        self.llm_client = client

    @staticmethod
    def _extract_tool_calls(message):
        """Return the tool-call list from an LLM chat message, or ``[]``.

        ``message.get("tool_calls", [])`` is not enough: a key that is present
        with a ``None`` value would yield ``None`` and break iteration.
        """
        return message.get("tool_calls") or []

    def _run_tool_calls(self, tool_calls, session_id):
        """Run each requested tool in order, using that call's own arguments.

        Args:
            tool_calls: Sequence of tool calls, each indexable as
                ``call["function"]["name"]`` / ``call["function"]["arguments"]``.
            session_id: Session identifier forwarded to the tools.

        Returns:
            List of tool outputs, one per recognised call, in call order.
            Calls naming an unknown tool are logged and skipped.
        """
        responses = []
        for call in tool_calls:
            function = call["function"]
            name = function["name"]
            if name == "menu_tool":
                result = self.menu_tool.run(function["arguments"]["query"], session_id)
            elif name == "cart_tool":
                result = self.cart_tool.run(
                    function["arguments"]["query"], session_id=session_id
                )
            elif name == "order_tool":
                # The LLM's `query` argument is not forwarded; OrderTool.run is
                # called with the menu data and session id only.
                result = self.order_tool.run(
                    df=self.df,
                    session_id=session_id,
                    category=self.category,
                    items=self.items,
                    store_id=self.store_id,
                    brand_id=self.brand_id,
                )
            else:
                self._log.warning("Ignoring call to unknown tool %r", name)
                continue
            # Logged after the tool has run, so this is really its output.
            self._log.debug("%s response :: %s", name, result)
            responses.append(result)
        return responses

    def handle_query(self, session_id, query, chat_history):
        """Answer one user turn.

        Args:
            session_id: Session identifier forwarded to the prompt builder and tools.
            query: The raw user message.
            chat_history: Prior conversation, passed to the context and summary helpers.

        Returns:
            ``greetings_function(query)`` when ``ollama_context_query`` returns a
            falsy flag; otherwise the result of ``summarised_output`` over the
            collected tool outputs.
        """
        context_query, greet_bool = ollama_context_query(
            chat_history=chat_history, user_query=query
        )
        # NOTE(review): a *falsy* flag routes to the greeting reply; the name
        # suggests the opposite polarity — confirm against ollama_context_query.
        if not greet_bool:
            return greetings_function(query)

        # A bare tool label carries no rewritten text, so send the user's own
        # words; otherwise send the context step's rewritten query.
        if context_query in self._TOOL_LABELS:
            user_message_content = query
        else:
            user_message_content = context_query
        messages = [
            {
                "role": "system",
                "content": tool_prompt_function(
                    current_query=query, session_id=session_id
                ),
            },
            {"role": "user", "content": user_message_content},
        ]
        llm_response = self.llm_client.chat(
            model=settings.MODEL_NAME,
            messages=messages,
            tools=self._TOOL_SCHEMAS,
        )
        self._log.debug("query=%r llm_response=%r", query, llm_response)

        tool_calls = self._extract_tool_calls(llm_response["message"])
        self._log.debug("tool calls: %r", tool_calls)
        tool_responses = self._run_tool_calls(tool_calls, session_id)

        return summarised_output(
            messages=tool_responses,
            chat_history=chat_history,
            context_query=context_query,
            user_query=query,
        )