Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Sun Oct 13 10:30:56 2024 | |
| @author: legalchain | |
| """ | |
| from typing import Literal, Optional, List, Union, Any | |
| from langchain_openai import ChatOpenAI | |
| import pandas as pd | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langgraph.graph import END, StateGraph, START | |
| from langchain_core.output_parsers import StrOutputParser | |
| from pydantic import BaseModel, Field | |
| from models import NatureJugement | |
| from prompts import df_prompt, feed_back_prompt, reflection_prompt | |
# Chat model instance shared by the node functions of this module.
llm = ChatOpenAI(model="gpt-4o-mini")

# Threshold compared against the state's `generation_num` when results are
# evaluated.
MAX_GENERATIONS: int = 3

# Row cap applied to the filtered DataFrame before it is serialized into the
# state's `results`.
MAX_ROWS: int = 200
class Query(BaseModel):
    # Structured-output schema holding the pandas filter expression produced
    # by the LLM.
    # NOTE: documented with `#` comments on purpose. pydantic exposes a model
    # docstring as the schema description, and this schema is handed to the
    # LLM through `with_structured_output`, so a docstring would alter the
    # request.
    query: str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ")

    def clean_query(self):
        """Return the expression located inside the ``df[...]`` wrapper.

        The pattern is greedy: the captured text runs from the first ``df[``
        to the last ``]`` found on the same line of the query.

        Returns:
            str: the text between the brackets.

        Raises:
            ValueError: if the query contains no ``df[...]`` wrapper.
        """
        import re

        # NOTE(review): the previous implementation first called
        # `.replace("\\'", "\\'")`, which substitutes a sequence by itself and
        # therefore never changed the query. It is dropped here; confirm
        # whether un-escaping quotes was actually intended.
        match = re.search(r"df\[(.*)\]", self.query)
        if match is None:
            # Fail with an explicit message instead of the opaque
            # "'NoneType' object has no attribute 'group'".
            raise ValueError(
                f"Aucune condition de la forme df[...] trouvée dans la requête : {self.query!r}"
            )
        return match.group(1)
class GradeResults(BaseModel):
    # Structured-output schema for the grading step: a single "yes"/"no"
    # verdict on the results. Documented with `#` comments so that the schema
    # generated from this model stays unchanged.
    binary_score: Literal["yes", "no"] = Field(
        description=(
            "Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'"
        ),
    )
class GraphState(BaseModel):
    # State object read and updated by the node functions of this module.

    # Source data. The nodes index it like a pandas DataFrame; `get_df` also
    # accepts JSON input.
    df: Any
    df_head: str
    instructions: Optional[str] = None
    # The default is a comma-separated string built from the NatureJugement
    # values, so `str` is accepted in addition to the originally declared
    # `List`.
    nature_jugement: Union[str, List] = ', '.join([e.value for e in NatureJugement])
    # `region` and `dep` are overwritten by __init__ from the content of `df`.
    region: str = ''
    dep: str = ''
    query: Optional[str] = None
    results: Union[str, List[str]] = []
    query_feedbacks: Optional[str] = None
    # The default is None, so the annotation must allow it.
    results_feedbacks: Optional[bool] = None
    generation_num: int = 0
    retrieval_num: int = 0
    search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM"
    # Empty string means "no error"; otherwise holds whatever the node stored.
    error_query: Optional[Any] = ""
    error_results: Optional[Any] = ""
    truncated: bool = False

    def get_df(self) -> pd.DataFrame:
        """Return `df` as a pandas DataFrame.

        A value that is already a DataFrame is returned unchanged. A string
        that starts with ``{`` or ``[`` is parsed as literal JSON. Anything
        else is passed to ``pd.read_json`` as before.
        """
        if isinstance(self.df, pd.DataFrame):
            return self.df
        if isinstance(self.df, str) and self.df.lstrip().startswith(("{", "[")):
            # Recent pandas versions deprecate passing literal JSON text
            # directly to read_json; wrapping it in StringIO is the supported
            # form.
            from io import StringIO

            return pd.read_json(StringIO(self.df))
        return pd.read_json(self.df)

    def __init__(self, **data):
        """Validate the fields, then derive `region` and `dep` from `df`.

        Raises:
            KeyError: if `df` lacks the 'region_nom_officiel' or
                'departement_nom_officiel' column.
        """
        super().__init__(**data)
        # Accept the JSON text form handled by get_df as well as an object
        # that can be indexed by column name directly.
        frame = self.get_df() if isinstance(self.df, str) else self.df
        distinct_regions = frame['region_nom_officiel'].dropna().unique().tolist()
        distinct_departements = frame['departement_nom_officiel'].dropna().unique().tolist()
        # Comma-separated lists; str() guards against non-string cell values,
        # which str.join would reject.
        self.region = ', '.join(map(str, distinct_regions))
        self.dep = ', '.join(map(str, distinct_departements))
def generate_query_node(state: GraphState):
    """Ask the LLM for a DataFrame filter query and store its cleaned form.

    Args:
        state: current graph state; `df_head`, `instructions`,
            `query_feedbacks`, `error_query`, `nature_jugement`, `dep` and
            `region` are injected into the prompt.

    Returns:
        dict: ``{"query": <condition>, "error_query": ""}`` on success, or
        ``{"error_query": <exception>}`` when the call, the parsing of the
        structured output, or the cleaning of the query fails.
    """
    prompt = ChatPromptTemplate.from_messages(messages=df_prompt)
    generate_df_query = prompt | llm.with_structured_output(
        Query,
        include_raw=True,  # keeps the parsing error available in the output
    )
    try:
        query_generate = generate_df_query.invoke({
            'df_head': state.df_head,
            'instructions': state.instructions,
            'feedback': state.query_feedbacks,
            'error': state.error_query,
            'nature_jugement': state.nature_jugement,
            'dep': state.dep,
            'region': state.region,
        })
        # With include_raw=True a parsing failure is reported under
        # "parsing_error" and "parsed" is None. Surface that error rather than
        # the AttributeError that calling a method on None would raise.
        parsing_error = query_generate.get('parsing_error')
        if parsing_error is not None:
            return {'error_query': parsing_error}
        parsed = query_generate['parsed']
        if parsed is None:
            raise ValueError("Le modèle n'a renvoyé aucune requête structurée.")
        return {
            "query": parsed.clean_query(),
            "error_query": "",  # clear any error left by a previous attempt
        }
    except Exception as e:
        # Node boundary: the error is stored in the state instead of being
        # raised, so that it can be fed back into the next prompt.
        return {'error_query': e}
def evaluate_query_node(state: GraphState):
    """Report whether the last query generation left an error in the state.

    Returns:
        str: ``"ok"`` when `state.error_query` equals the empty string,
        otherwise a fixed French error message.
    """
    has_error = state.error_query != ""
    return (
        "Il y a une erreur dans la requête. Je me suis sûrement trompé. Veuillez réessayer."
        if has_error
        else "ok"
    )
def generate_results_node(state: GraphState):
    """Apply the generated query to `state.df` and serialize the matching rows.

    Every successful outcome returns a complete update (`results`,
    `truncated`, `error_results`), so values written by an earlier attempt
    cannot linger in the state.

    Returns:
        dict: always contains ``generation_num`` incremented by one, plus
        - no matching row: ``results=[]``, ``truncated=False``,
          ``error_results=""``;
        - matching rows: ``results`` as a JSON records string limited to
          MAX_ROWS rows, ``truncated`` telling whether rows were cut, and
          ``error_results=""``;
        - failure: ``error_results`` set to the caught exception (other
          fields are left untouched).
    """
    next_generation = state.generation_num + 1
    try:
        query = state.query
        print("query ", query)
        print('je suis dans generate', type(state.df))
        # SECURITY NOTE(review): `query` is text generated by the LLM and is
        # executed with eval(). Only `df` is injected, but Python builtins
        # remain reachable. Consider a restricted evaluator if the prompt
        # inputs are not trusted.
        mask = eval(query, {"df": state.df})
        new_df = state.df[mask]
        print("new_df", new_df.empty)
        if new_df.empty:
            # Reset the result fields explicitly: returning only the counter
            # would leave the results of a previous query in the state.
            return {
                "results": [],
                "truncated": False,
                "error_results": "",
                "generation_num": next_generation,
            }
        truncated = len(new_df) > MAX_ROWS
        if truncated:
            new_df = new_df.head(MAX_ROWS)
        return {
            'results': new_df.to_json(orient='records'),
            "truncated": truncated,
            "error_results": "",  # clear any error left by a previous attempt
            "generation_num": next_generation,
        }
    except Exception as e:
        # Node boundary: store the error in the state instead of raising.
        return {
            'error_results': e,
            "generation_num": next_generation,
        }
def evaluate_results_node(state: GraphState):
    """Grade the current results, unless the generation budget is exhausted.

    Returns:
        str: ``"max_generation_reached"`` when `state.generation_num` exceeds
        MAX_GENERATIONS, otherwise the ``binary_score`` ("yes" or "no") of the
        GradeResults produced by the LLM.
    """
    # Checked first so that no LLM call is made when its verdict would be
    # discarded anyway.
    # NOTE(review): with `>` the limit triggers only once generation_num is
    # strictly above MAX_GENERATIONS — confirm whether `>=` was intended.
    if state.generation_num > MAX_GENERATIONS:
        return "max_generation_reached"

    prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt)
    generate_eval = prompt_eval | llm.with_structured_output(
        GradeResults,
        include_raw=False,  # the chain returns the parsed GradeResults directly
    )
    evaluation = generate_eval.invoke({
        'df_head': state.df_head,
        'results': state.results,
        'instructions': state.instructions,
    })
    return evaluation.binary_score
def query_feedback_node(state: GraphState):
    """Ask the LLM for textual feedback on the current query and its results.

    The feedback is printed and returned under the ``query_feedbacks`` key,
    prefixed with a fixed French label.
    """
    chain = (
        ChatPromptTemplate.from_messages(messages=feed_back_prompt)
        | llm
        | StrOutputParser()
    )
    prompt_inputs = {
        "df_head": state.df_head,
        "instructions": state.instructions,
        "results": state.results,
        "query": state.query,
    }
    raw_feedback = chain.invoke(prompt_inputs)
    labelled_feedback = f"Evaluation de la recherche : {raw_feedback}"
    print(labelled_feedback)
    return {"query_feedbacks": labelled_feedback}