File size: 6,787 Bytes
e01d6d5
2101fac
5fb65b7
 
 
 
 
 
 
 
 
 
 
 
 
5d0417f
 
5fb65b7
5d0417f
5fb65b7
5d0417f
 
5fb65b7
 
 
 
037e9c2
5fb65b7
 
 
037e9c2
5fb65b7
 
 
 
 
 
 
 
 
5d0417f
 
5fb65b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e04da09
5fb65b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e04da09
5fb65b7
 
 
 
 
e04da09
5fb65b7
e04da09
5fb65b7
 
 
 
 
 
 
 
e04da09
5fb65b7
 
 
e04da09
5fb65b7
 
 
5d0417f
 
5fb65b7
037e9c2
5fb65b7
 
 
 
5d0417f
 
5fb65b7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from fastapi import FastAPI
import uvicorn
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import pygsheets
import json
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage
from langchain_ollama import ChatOllama
from langgraph.pregel import RetryPolicy
import json
from google.oauth2 import service_account
import os
from langchain_groq import ChatGroq
import groq

# Runtime configuration pulled from the environment; each is None when unset.
sheet_url = os.environ.get("SHEET_URL")
GOOGLESHEETS_CREDENTIALS = os.environ.get("GOOGLESHEETS_CREDENTIALS")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Groq-hosted model name used by every ChatGroq instance in this module.
MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"

class TransactionParser(BaseModel):
    """Structured-output schema the LLM fills in when parsing a transaction message.

    This class is passed to ``with_structured_output``, so the field descriptions
    below are sent to the model and act as its parsing instructions.
    """

    amount: str = Field(
        description="The amount of the transaction strictly in decimal format. If the transaction is a credit or a reversal, then include negative sign. Otherwise it should always be positive. Do not insert currency symbol.",
        # Pydantic v2 deprecates arbitrary extra kwargs on Field (e.g. ``example=``);
        # json_schema_extra emits the same "example" key in the JSON schema.
        json_schema_extra={"example": "123.45"},
    )
    dr_or_cr: str = Field(description="Identify if the transaction was debit (spent) or credit (received). Strictly choose one of the values - Debit or Credit")
    receiver: str = Field(description="The recipient of the transaction. Identify the Merchant Name from the value.")
    category: str = Field(description="The category of the transaction. The category of the transaction is linked to the Merchant Name. Strictly choose from one of the values - Shopping,EMI,Education,Miscellaneous,Grocery,Utility,House Help,Travel,Transport")
    transaction_date: str = Field(description="The date of the transaction strictly in yyyy-mm-dd format. If the year is not provided then use current year.")
    transaction_origin: str = Field(description="The origin of the transaction. Provide the card or account number as well.")

class AgentState(TypedDict):
    """Graph state shared by every node of the transaction agent."""

    # Each node returns {"messages": [...]}; the operator.add reducer concatenates
    # that list onto the existing history rather than replacing it. Despite the
    # AnyMessage annotation, entries are not only chat messages: the raw input
    # passed to graph.invoke, a TransactionParser result, and the plain string
    # "Transaction Completed" are also appended by the Agent nodes.
    messages: Annotated[list[AnyMessage], operator.add]

class Agent:
    """LangGraph workflow that records transaction messages in a Google Sheet.

    Flow: ``classify_txn_type`` asks the model to label the input. When
    ``check_txn_and_decide`` sees the label "Transaction", the graph continues to
    ``parse_message`` (structured extraction into ``TransactionParser``) and then
    ``write_message`` (append one row to the sheet). Any other label ends the graph.
    """

    def __init__(self, model, system=""):
        """Build and compile the graph.

        Args:
            model: LangChain chat model; ``invoke`` and ``with_structured_output``
                are called on it.
            system: optional system prompt prepended to the classification call.
        """
        self.system = system
        self.model = model
        # Both LLM-backed nodes get one extra attempt on groq.BadRequestError.
        llm_retry = RetryPolicy(retry_on=[groq.BadRequestError], max_attempts=2)
        graph = StateGraph(AgentState)
        graph.add_node("classify_txn_type", self.classify_txn_type, retry=llm_retry)
        graph.add_node("parse_message", self.parse_message, retry=llm_retry)
        graph.add_node("write_message", self.write_message)
        graph.add_conditional_edges(
            "classify_txn_type",
            self.check_txn_and_decide,
            {True: "parse_message", False: END},
        )
        graph.add_edge("parse_message", "write_message")
        graph.add_edge("write_message", END)
        graph.set_entry_point("classify_txn_type")
        self.graph = graph.compile()

    def classify_txn_type(self, state: AgentState) -> AgentState:
        """Ask the model to classify the messages so far and append its reply."""
        print("Classifying transaction type...")
        messages = state["messages"]
        if self.system:
            messages = [SystemMessage(content=self.system)] + messages

        message = self.model.invoke(messages)
        print("Classifying transaction type completed.")
        return {"messages": [message]}

    def parse_message(self, state: AgentState) -> AgentState:
        """Extract transaction fields from the original input into a TransactionParser."""
        print("Parsing transaction message...")
        # messages[0] is the input given to graph.invoke; later entries are model output.
        message = state["messages"][0]
        system = """
        You are a helpful assistant skilled at parsing transaction messages and providing structured responses.
        """
        human = "Categorize the transaction message and provide the output in a structured format: {topic}"

        parse_prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
        chain = parse_prompt | self.model.with_structured_output(TransactionParser)
        result = chain.invoke({"topic": message})
        print("Parsing transaction message completed.")
        return {"messages": [result]}

    @staticmethod
    def _open_first_worksheet():
        """Authorize with the service-account JSON from the environment and return worksheet 0 of ``sheet_url``.

        Raises:
            RuntimeError: if GOOGLESHEETS_CREDENTIALS is unset or empty.
        """
        if not GOOGLESHEETS_CREDENTIALS:
            raise RuntimeError("GOOGLESHEETS_CREDENTIALS environment variable is not set")
        scopes = ('https://www.googleapis.com/auth/spreadsheets', 'https://www.googleapis.com/auth/drive')
        service_account_info = json.loads(GOOGLESHEETS_CREDENTIALS)
        credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=scopes)
        client = pygsheets.authorize(custom_credentials=credentials)
        return client.open_by_url(sheet_url)[0]

    def write_message(self, state: AgentState) -> AgentState:
        """Append the parsed transaction as a new row (columns A-G) of the first worksheet.

        The last message must be the ``TransactionParser`` produced by
        ``parse_message``; the first message (the raw input) goes in column G.

        Raises:
            RuntimeError: if GOOGLESHEETS_CREDENTIALS is unset or empty.
        """
        print("Writing transaction message to Google Sheets...")
        result = state["messages"][-1]
        wk = self._open_first_worksheet()
        # Read to the sheet's actual grid height instead of a fixed G999: with the
        # fixed bound the row count stopped growing near 1000 rows and every new
        # transaction overwrote the same row.
        df = wk.get_as_df(start='A1', end=f'G{wk.rows}')
        next_row = df.shape[0] + 2  # +1 for the header row, +1 for the first empty row
        row_values = [
            result.amount,
            result.dr_or_cr,
            result.receiver,
            result.category,
            result.transaction_date,
            result.transaction_origin,
            state["messages"][0],
        ]
        # One batched write instead of seven single-cell API calls; extend the grid
        # if the new row falls past its current end.
        wk.update_values(f'A{next_row}:G{next_row}', [row_values], extend=True)
        print("Writing transaction message to Google Sheets completed.")
        return {"messages": ["Transaction Completed"]}

    @staticmethod
    def _extract_classification(content):
        """Pull the classification label out of the classifier's raw reply text.

        Handles a JSON object with a "classification" key, a JSON string, bare
        text, and any of those wrapped in a Markdown code fence.
        """
        text = content.strip()
        if text.startswith("```"):
            # Drop the fence backticks and an optional "json" language tag.
            text = text.strip("`").strip()
            if text[:4].casefold() == "json":
                text = text[4:].strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return text
        if isinstance(parsed, dict):
            return str(parsed.get("classification", ""))
        if isinstance(parsed, str):
            return parsed
        return text

    def check_txn_and_decide(self, state: AgentState):
        """Return True when the classifier's last reply labels the message "Transaction".

        The comparison ignores surrounding whitespace and letter case. Malformed
        or unexpected replies yield False instead of raising.
        """
        label = self._extract_classification(state['messages'][-1].content)
        return label.strip().casefold() == "transaction"
    
# System prompt for the classification node; the model is asked to answer with a
# JSON object such as {"classification": "OTP"}.
prompt = (
    "You are a smart assistant adept at classifying different messages. "
    "You will be penalized heavily for incorrect classification. "
    "Your task is to classify the message into one of the following categories: "
    "Transaction, OTP, Promotional, Scheduled. "
    "Output the classification in a structured format like below. "
    '{"classification": "OTP"} '
)
app = FastAPI()


def _make_model():
    """Create the Groq chat model used by the LLM-backed endpoints."""
    return ChatGroq(temperature=1, groq_api_key=GROQ_API_KEY, model_name=MODEL)


@app.get("/")
def greetings():
    """Describe how to use the bot."""
    return {"message": "Hello, this is a transaction bot. Please send a POST request to /write_message with the transaction data."}


@app.post("/write_message")
def write_message(data: dict):
    """Classify ``data["message"]`` and, if it is a transaction, record it in the sheet.

    Responds 422 when ``message`` is missing, not a string, or blank (previously a
    missing key surfaced as a 500 from an uncaught KeyError).
    """
    from fastapi import HTTPException

    message = data.get('message')
    if not isinstance(message, str) or not message.strip():
        raise HTTPException(status_code=422, detail="Request body must contain a non-empty string field 'message'.")
    transaction_bot = Agent(_make_model(), system=prompt)
    final_state = transaction_bot.graph.invoke({"messages": [message]})
    # Agent.write_message appends "Transaction Completed" as its last entry; for
    # non-transactions the graph ends right after classification instead.
    if final_state["messages"][-1] != "Transaction Completed":
        return {"message": "Message was not classified as a transaction; nothing was written"}
    return {"message": "Transaction completed successfully"}


@app.get("/ask")
def ask(prompt: str):
    """Send ``prompt`` (query parameter) directly to the model and return its reply."""
    return _make_model().invoke(prompt)

if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces; the port is hard-coded to 7860.
    uvicorn.run(app, log_level="info", host="0.0.0.0", port=7860)