expense_tracker / app.py
nagarmayank's picture
change to GROQ
5d0417f
raw
history blame
6.79 kB
from fastapi import FastAPI
import uvicorn
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import pygsheets
import json
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage
from langchain_ollama import ChatOllama
from langgraph.pregel import RetryPolicy
import json
from google.oauth2 import service_account
import os
from langchain_groq import ChatGroq
import groq
# Runtime configuration, resolved once at import time from the environment.
# Any variable that is not set resolves to None.
sheet_url = os.environ.get("SHEET_URL")
GOOGLESHEETS_CREDENTIALS = os.environ.get("GOOGLESHEETS_CREDENTIALS")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Groq-hosted chat model used by every endpoint.
MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
class TransactionParser(BaseModel):
    """This Pydantic class is used to parse the transaction message."""

    # This model is handed to `with_structured_output` in Agent.parse_message,
    # so the docstring and every Field description below act as instructions
    # to the LLM. Edit them as prompt text, not just as documentation.
    # All fields are plain strings; no format validation happens here.

    amount: str = Field(
        description=(
            "The amount of the transaction strictly in decimal format. "
            "If the transaction is a credit or a reversal, then include a negative sign. "
            "Otherwise it should always be positive. Do not insert a currency symbol."
        ),
        # Pydantic v2 spelling; the bare `example=` kwarg is a deprecated extra.
        examples=["123.45"],
    )
    dr_or_cr: str = Field(
        description=(
            "Identify if the transaction was debit (spent) or credit (received). "
            "Strictly choose one of the values - Debit or Credit"
        )
    )
    receiver: str = Field(
        description="The recipient of the transaction. Identify the Merchant Name from the value."
    )
    category: str = Field(
        description=(
            "The category of the transaction. The category of the transaction is linked to the Merchant Name. "
            "Strictly choose one of the values - "
            "Shopping,EMI,Education,Miscellaneous,Grocery,Utility,House Help,Travel,Transport"
        )
    )
    transaction_date: str = Field(
        description=(
            "The date of the transaction strictly in yyyy-mm-dd format. "
            "If the year is not provided then use current year."
        )
    )
    transaction_origin: str = Field(
        description="The origin of the transaction. Provide the card or account number as well."
    )
class AgentState(TypedDict):
    """State dictionary passed between the Agent graph nodes.

    `messages` is annotated with `operator.add` as its reducer. Each node returns
    a one-element list, which is concatenated onto the accumulated list rather
    than replacing it.
    """

    messages: Annotated[list[AnyMessage], operator.add]
class Agent:
    """Classify an incoming message and log transactions to a Google Sheet.

    The compiled LangGraph flows:

        classify_txn_type -> parse_message -> write_message -> END   (label == "Transaction")
        classify_txn_type -> END                                     (any other label)

    `state["messages"]` grows by concatenation (see AgentState). It holds:
    - the raw input message first,
    - then the classifier's reply,
    - then the parsed TransactionParser record,
    - and finally the string "Transaction Completed".
    """

    def __init__(self, model, system=""):
        """Build and compile the graph.

        Args:
            model: LangChain chat model. It is used via `invoke` for
                classification and `with_structured_output` for parsing.
            system: optional system prompt prepended to the classification call.
        """
        self.system = system
        self.model = model
        # max_attempts=2 -> each LLM-backed node is retried once on groq.BadRequestError.
        llm_retry = RetryPolicy(retry_on=[groq.BadRequestError], max_attempts=2)
        graph = StateGraph(AgentState)
        graph.add_node("classify_txn_type", self.classify_txn_type, retry=llm_retry)
        graph.add_node("parse_message", self.parse_message, retry=llm_retry)
        graph.add_node("write_message", self.write_message)
        graph.add_conditional_edges(
            "classify_txn_type",
            self.check_txn_and_decide,
            {True: "parse_message", False: END},
        )
        graph.add_edge("parse_message", "write_message")
        graph.add_edge("write_message", END)
        graph.set_entry_point("classify_txn_type")
        self.graph = graph.compile()

    @staticmethod
    def _original_text(state: AgentState):
        """Return the first (raw input) message as text, accepting a str or a message object."""
        first = state["messages"][0]
        return getattr(first, "content", first)

    def _open_worksheet(self):
        """Authorize with the service-account JSON and return the first worksheet at `sheet_url`.

        Raises:
            RuntimeError: if GOOGLESHEETS_CREDENTIALS or SHEET_URL is not set.
        """
        if not GOOGLESHEETS_CREDENTIALS or not sheet_url:
            raise RuntimeError(
                "GOOGLESHEETS_CREDENTIALS and SHEET_URL environment variables must be set"
            )
        scopes = (
            'https://www.googleapis.com/auth/spreadsheets',
            'https://www.googleapis.com/auth/drive',
        )
        service_account_info = json.loads(GOOGLESHEETS_CREDENTIALS)
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info, scopes=scopes
        )
        client = pygsheets.authorize(custom_credentials=credentials)
        return client.open_by_url(sheet_url)[0]

    def classify_txn_type(self, state: AgentState) -> AgentState:
        """Ask the model to classify the messages so far and append its reply."""
        print("Classifying transaction type...")
        messages = state["messages"]
        if self.system:
            # The system prompt goes to the model only; it is not stored in the state.
            messages = [SystemMessage(content=self.system)] + messages
        message = self.model.invoke(messages)
        print("Classifying transaction type completed.")
        return {"messages": [message]}

    def parse_message(self, state: AgentState) -> AgentState:
        """Extract a TransactionParser record from the original input message and append it."""
        print("Parsing transaction message...")
        message = self._original_text(state)
        system = "\nYou are a helpful assistant skilled at parsing transaction messages and providing structured responses.\n"
        human = "Categorize the transaction message and provide the output in a structured format: {topic}"
        prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
        chain = prompt | self.model.with_structured_output(TransactionParser)
        result = chain.invoke({"topic": message})
        print("Parsing transaction message completed.")
        return {"messages": [result]}

    def write_message(self, state: AgentState) -> AgentState:
        """Append the parsed transaction plus the raw message as a new row in columns A-G."""
        print("Writing transaction message to Google Sheets...")
        result = state["messages"][-1]
        wk = self._open_worksheet()
        # Count data rows (row 1 is the header) over the sheet's full height.
        # A fixed end row would cap the count and make later writes overwrite
        # the same row.
        nrows = wk.get_as_df(start='A1', end=f'G{wk.rows}').shape[0]
        row_values = [
            result.amount,
            result.dr_or_cr,
            result.receiver,
            result.category,
            result.transaction_date,
            result.transaction_origin,
            self._original_text(state),
        ]
        # One batched write instead of seven single-cell calls. extend=True grows
        # the grid if the sheet has no empty row left.
        wk.update_values(crange=f'A{nrows + 2}', values=[row_values], extend=True)
        print("Writing transaction message to Google Sheets completed.")
        return {"messages": ["Transaction Completed"]}

    def check_txn_and_decide(self, state: AgentState):
        """Return True when the classifier's last reply labels the message "Transaction".

        Accepted reply forms:
        - a JSON object like {"classification": "OTP"};
        - a JSON string;
        - a bare label;
        - any of the above wrapped in a Markdown code fence.

        Any other shape yields False instead of raising.
        """
        content = state["messages"][-1].content.strip()
        if content.startswith("```"):
            # Unwrap a fenced reply such as ```json {...} ```.
            content = content.strip("`").strip()
            if content.lower().startswith("json"):
                content = content[4:].strip()
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Not JSON: treat the whole reply as the label.
            parsed = content
        if isinstance(parsed, dict):
            parsed = parsed.get("classification")
        return isinstance(parsed, str) and parsed.strip() == "Transaction"
# System prompt for the classification node. It asks the model for one JSON
# object naming one of four categories; Agent.check_txn_and_decide reads the
# "classification" key.
prompt = (
    "You are a smart assistant adept at classifying different messages. "
    "You will be penalized heavily for incorrect classification. "
    "Your task is to classify the message into one of the following categories: "
    "Transaction, OTP, Promotional, Scheduled. "
    "Output the classification in a structured format like below. "
    '{"classification": "OTP"} '
)
# ASGI application; the route decorators below register on this instance.
app: FastAPI = FastAPI()
@app.get("/")
def greetings():
    """Landing endpoint: return a short usage hint as JSON."""
    usage_hint = (
        "Hello, this is a transaction bot. "
        "Please send a POST request to /write_message with the transaction data."
    )
    return {"message": usage_hint}
@app.post("/write_message")
def write_message(data: dict):
    """Run the classification/logging graph on `data["message"]`.

    Returns a JSON message saying whether a row was written to the sheet.
    A body without a non-empty string `message` is rejected with HTTP 422.
    """
    from fastapi import HTTPException

    message = data.get("message")
    if not isinstance(message, str) or not message.strip():
        # Reject bad input up front instead of failing with a KeyError (HTTP 500).
        raise HTTPException(
            status_code=422,
            detail="Request body must contain a non-empty string 'message'.",
        )
    model = ChatGroq(temperature=1, groq_api_key=GROQ_API_KEY, model_name=MODEL)
    # model = ChatOllama(model="gemma3:1b", temperature=1)
    transaction_bot = Agent(model, system=prompt)
    final_state = transaction_bot.graph.invoke({"messages": [message]})
    # Agent.write_message appends "Transaction Completed" as its last message.
    # Otherwise the graph ended right after classification and nothing was written.
    if final_state["messages"][-1] != "Transaction Completed":
        return {"message": "Message was not classified as a transaction; nothing was written"}
    return {"message": "Transaction completed successfully"}
@app.get("/ask")
def ask(prompt: str):
    """Forward the `prompt` query parameter to the Groq chat model and return its reply message."""
    # NOTE(review): this endpoint is unauthenticated and spends GROQ_API_KEY on
    # arbitrary caller input; consider restricting access.
    llm = ChatGroq(temperature=1, groq_api_key=GROQ_API_KEY, model_name=MODEL)
    # llm = ChatOllama(model="gemma3:1b", temperature=1)
    reply = llm.invoke(prompt)
    return reply
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces at port 7860.
    uvicorn.run(app, log_level="info", host="0.0.0.0", port=7860)