Spaces:
Running
Running
File size: 6,787 Bytes
e01d6d5 2101fac 5fb65b7 5d0417f 5fb65b7 5d0417f 5fb65b7 5d0417f 5fb65b7 037e9c2 5fb65b7 037e9c2 5fb65b7 5d0417f 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 e04da09 5fb65b7 5d0417f 5fb65b7 037e9c2 5fb65b7 5d0417f 5fb65b7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
from fastapi import FastAPI
import uvicorn
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
import pygsheets
import json
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import SystemMessage, HumanMessage, AnyMessage
from langchain_ollama import ChatOllama
from langgraph.pregel import RetryPolicy
import json
from google.oauth2 import service_account
import os
from langchain_groq import ChatGroq
import groq
# Runtime configuration, read once from the environment at import time.
# Each name is None when the corresponding variable is unset.
sheet_url, GOOGLESHEETS_CREDENTIALS, GROQ_API_KEY = (
    os.environ.get(var_name)
    for var_name in ("SHEET_URL", "GOOGLESHEETS_CREDENTIALS", "GROQ_API_KEY")
)
# Groq model name handed to ChatGroq by the HTTP endpoints.
MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"
class TransactionParser(BaseModel):
    """This Pydantic class is used to parse the transaction message."""

    # NOTE: this model is passed to `with_structured_output`, which turns the
    # class docstring and every Field description into the schema the LLM sees.
    # Rewording them changes model behaviour, so they are kept verbatim.
    # All fields are plain strings: nothing here validates the formats or the
    # fixed choices the descriptions ask for.

    # Decimal string with no currency symbol; negative for credits/reversals.
    amount: str = Field(
        description="The amount of the transaction strictly in decimal format. If the transaction is a credit or a reversal, then include negative sign. Otherwise it should always be positive. DO not insert currency symbol.",
        # `examples` (a list) is the Pydantic v2 / JSON Schema keyword; the
        # former `example=` extra keyword argument is deprecated in v2.
        examples=["123.45"],
    )
    # Expected to be "Debit" or "Credit" (requested, not enforced).
    dr_or_cr: str = Field(description="Identify if the transaction was debit (spent) or credit (received). Strictly choose one of the values - Debit or Credit")
    # Merchant / payee name.
    receiver: str = Field(description="The recipient of the transaction. Identify the Merchant Name from the value.")
    # One of the listed categories (requested, not enforced).
    category: str = Field(description="The category of the transaction. The category of the transaction is linked to the Merchant Name. Strictly choose from one the of values - Shopping,EMI,Education,Miscellaneous,Grocery,Utility,House Help,Travel,Transport")
    # Requested as yyyy-mm-dd.
    transaction_date: str = Field(description="The date of the transaction strictly in yyyy-mm-dd format. If the year is not provided then use current year.")
    # Card or account the transaction came from.
    transaction_origin: str = Field(description="The origin of the transaction. Provide the card or account number as well.")
# State carried through the LangGraph workflow.
# `operator.add` is the reducer for "messages": the list a node returns is
# concatenated onto the existing list instead of replacing it.
# The nodes in this file also append non-message values (a TransactionParser,
# a plain str), so in practice the element type is looser than list[AnyMessage].
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[list[AnyMessage], operator.add]},
)
class Agent:
    """LangGraph workflow that classifies a message and records transactions.

    Graph: ``classify_txn_type`` -> ``check_txn_and_decide`` ->
    ``parse_message`` -> ``write_message`` -> END. When the classifier does
    not answer "Transaction", the run ends right after classification.

    Args:
        model: LangChain chat model used by both LLM nodes; it must support
            ``invoke`` and ``with_structured_output``.
        system: Optional system prompt prepended to the classification call.
    """

    def __init__(self, model, system=""):
        self.system = system
        self.model = model
        # Both LLM-backed nodes get one extra attempt on groq.BadRequestError
        # (presumably malformed tool/JSON output from the model - confirm).
        llm_retry = RetryPolicy(retry_on=[groq.BadRequestError], max_attempts=2)
        graph = StateGraph(AgentState)
        graph.add_node("classify_txn_type", self.classify_txn_type, retry=llm_retry)
        graph.add_node("parse_message", self.parse_message, retry=llm_retry)
        graph.add_node("write_message", self.write_message)
        graph.add_conditional_edges(
            "classify_txn_type",
            self.check_txn_and_decide,
            {True: "parse_message", False: END},
        )
        graph.add_edge("parse_message", "write_message")
        graph.add_edge("write_message", END)
        graph.set_entry_point("classify_txn_type")
        self.graph = graph.compile()

    @staticmethod
    def _original_text(state: AgentState):
        """Return the text of the first message in ``state`` (the user input).

        The HTTP endpoint usually passes a plain ``str``. Message objects are
        unwrapped through their ``content`` attribute.
        """
        first = state["messages"][0]
        return getattr(first, "content", first)

    def classify_txn_type(self, state: AgentState) -> AgentState:
        """Ask the model to classify the conversation so far.

        The system prompt (if any) is placed in front. The model's reply is
        appended to ``messages``.
        """
        print("Classifying transaction type...")
        messages = state["messages"]
        if self.system:
            messages = [SystemMessage(content=self.system)] + messages
        message = self.model.invoke(messages)
        print("Classifying transaction type completed.")
        return {"messages": [message]}

    def parse_message(self, state: AgentState) -> AgentState:
        """Extract structured transaction fields from the original message.

        Returns a state update whose single element is the
        ``TransactionParser`` instance produced by structured output.
        """
        print("Parsing transaction message...")
        message = self._original_text(state)
        system = "You are a helpful assistant skilled at parsing transaction messages and providing structured responses."
        human = "Categorize the transaction message and provide the output in a structed format: {topic}"
        prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
        chain = prompt | self.model.with_structured_output(TransactionParser)
        result = chain.invoke({"topic": message})
        print("Parsing transaction message completed.")
        return {"messages": [result]}

    def write_message(self, state: AgentState) -> AgentState:
        """Append the parsed transaction as one row of the first worksheet.

        Columns A-G hold: amount, dr_or_cr, receiver, category,
        transaction_date, transaction_origin, and the original message text.

        Raises:
            RuntimeError: if GOOGLESHEETS_CREDENTIALS or SHEET_URL is not set.
        """
        print("Writing transaction message to Google Sheets...")
        result = state["messages"][-1]  # TransactionParser from parse_message
        if not GOOGLESHEETS_CREDENTIALS or not sheet_url:
            raise RuntimeError(
                "GOOGLESHEETS_CREDENTIALS and SHEET_URL must be set to write to Google Sheets."
            )
        scopes = (
            "https://www.googleapis.com/auth/spreadsheets",
            "https://www.googleapis.com/auth/drive",
        )
        credentials = service_account.Credentials.from_service_account_info(
            json.loads(GOOGLESHEETS_CREDENTIALS), scopes=scopes
        )
        client = pygsheets.authorize(custom_credentials=credentials)
        worksheet = client.open_by_url(sheet_url)[0]
        row = [
            result.amount,
            result.dr_or_cr,
            result.receiver,
            result.category,
            result.transaction_date,
            result.transaction_origin,
            self._original_text(state),
        ]
        # One values.append call: the Sheets API finds the table anchored at A1
        # and inserts this row after its last row. Counting rows within A1:G999
        # stopped growing at row 1000, so every later write overwrote that row.
        worksheet.append_table(values=[row], start="A1", dimension="ROWS", overwrite=False)
        print("Writing transaction message to Google Sheets completed.")
        return {"messages": ["Transaction Completed"]}

    def check_txn_and_decide(self, state: AgentState) -> bool:
        """Return True when the classifier's latest reply says "Transaction".

        Accepts the ``{"classification": ...}`` JSON shape requested by the
        system prompt, including JSON wrapped in extra text or markdown
        fences, or a bare label. Anything unrecognised counts as
        non-transaction instead of raising.
        """
        content = state["messages"][-1].content
        text = (content if isinstance(content, str) else str(content)).strip()
        label = text
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            try:
                parsed = json.loads(text[start:end + 1])
            except json.JSONDecodeError:
                parsed = None
            # Valid JSON of another shape (list, str, missing key) used to raise
            # KeyError/TypeError here and abort the whole request.
            if isinstance(parsed, dict) and "classification" in parsed:
                label = str(parsed["classification"])
        return label.strip().strip("\"'.").casefold() == "transaction"
# System prompt for the classification node. The JSON example at the end shows
# the reply shape that Agent.check_txn_and_decide parses. The text is a single
# line ending with a trailing space.
prompt = (
    "You are a smart assistant adept at classifying different messages. "
    "You will be penalized heavily for incorrect classification. "
    "Your task is to classify the message into one of the following categories: "
    "Transaction, OTP, Promotional, Scheduled. "
    "Output the classification in a structured format like below. "
    '{"classification": "OTP"} '
)
# FastAPI application exposing the HTTP endpoints defined below.
app = FastAPI()


@app.get("/")
def greetings():
    """Landing endpoint: return a short usage hint for the bot."""
    usage = (
        "Hello, this is a transaction bot. "
        "Please send a POST request to /write_message with the transaction data."
    )
    return {"message": usage}
@app.post("/write_message")
def write_message(data: dict):
    """Run the transaction agent on ``data["message"]``.

    Expects a JSON body such as ``{"message": "<notification text>"}``. The
    message is classified first; only messages classified as "Transaction"
    are parsed and appended to the Google Sheet.

    Raises:
        HTTPException: 400 when ``message`` is missing or is not a non-empty
            string (previously a bare KeyError produced a 500).
    """
    from fastapi import HTTPException

    message = data.get("message")
    if not isinstance(message, str) or not message.strip():
        raise HTTPException(
            status_code=400,
            detail="Request body must contain a non-empty string field 'message'.",
        )
    # NOTE(review): temperature=1 makes classification/parsing non-deterministic;
    # consider 0. Local alternative: ChatOllama(model="gemma3:1b", temperature=1)
    model = ChatGroq(temperature=1, groq_api_key=GROQ_API_KEY, model_name=MODEL)
    transaction_bot = Agent(model, system=prompt)
    final_state = transaction_bot.graph.invoke({"messages": [message]})
    # Agent.write_message is the only node that appends this marker. Without it,
    # the graph ended after classification (OTP, promotional, ...) and nothing
    # was written, so success must not be reported.
    if final_state["messages"][-1] == "Transaction Completed":
        return {"message": "Transaction completed successfully"}
    return {"message": "Message was not classified as a transaction; nothing was written"}
@app.get("/ask")
def ask(prompt: str):
    """Send the ``prompt`` query parameter straight to the Groq chat model.

    Returns the model's reply object as produced by ``invoke``.
    """
    # Local alternative: ChatOllama(model="gemma3:1b", temperature=1)
    llm = ChatGroq(temperature=1, groq_api_key=GROQ_API_KEY, model_name=MODEL)
    reply = llm.invoke(prompt)
    return reply
if __name__ == "__main__":
    # Serve the FastAPI app directly with uvicorn on all interfaces, port 7860.
    # (The trailing " |" scrape artifact, a syntax error, has been removed.)
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")