Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Sun Sep 22 15:43:16 2024 | |
| @author: Raphaël d'Assignies (rdassignies@protonmail.ch) | |
| """ | |
| import json | |
| from typing import Literal, Optional, List, Union, Any | |
| from langchain_openai import ChatOpenAI | |
| import pandas as pd | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langgraph.graph import END, StateGraph, START | |
| from langchain_core.output_parsers import StrOutputParser | |
| from pydantic import BaseModel, Field | |
| from models import NatureJugement | |
| from nodes import (GradeResults, GraphState, generate_query_node, | |
| generate_results_node, query_feedback_node, | |
| evaluate_query_node, evaluate_results_node) | |
| import streamlit as st | |
| # Instanciate pipeline | |
| pipeline = StateGraph(GraphState) | |
| pipeline.add_node('generate_query', generate_query_node) | |
| pipeline.add_node('generate_results', generate_results_node) | |
| pipeline.add_node('query_feedback', query_feedback_node) | |
| # Only query | |
| #pipeline.add_edge(START,'generate_query') | |
| #pipeline.add_edge('generate_query', generate_query_node) | |
| #pipeline.add_edge('generate_query', END) | |
| # Full scenario | |
| pipeline.add_edge(START,'generate_query') | |
| pipeline.add_conditional_edges( | |
| 'generate_query', | |
| evaluate_query_node, | |
| {'error_query' : 'generate_query', | |
| 'ok' : 'generate_results' | |
| }) | |
| pipeline.add_conditional_edges( | |
| 'generate_results', | |
| evaluate_results_node, | |
| { | |
| "yes": END, | |
| "no": 'query_feedback', | |
| "max_generation_reached": END | |
| } | |
| ) | |
| # Création du graph | |
| graph = pipeline.compile() | |
| # Load le dataframe | |
| df = pd.read_json('bodacc.json', orient='table') | |
| # Initialise le dictionnaire | |
| inputs = { | |
| 'df_head': df.head().to_csv(), | |
| 'df': df | |
| } | |
| # Créé un dictionnaire des sorties vide | |
| outputs = {} | |
| # Titre de l'application | |
| st.title("Chat with BODACC !") | |
| # Message d'avertissement | |
| warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} " | |
| f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.") | |
| st.warning(warning_message) | |
| # Interface utilisateur pour entrer la requête | |
| user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours") | |
| # Afficher les résultats avec Streamlit | |
| inputs["instructions"] = user_query | |
| # Afficher un bouton pour démarrer la recherche | |
| if st.button("Lancer la recherche"): | |
| config = {"configurable": {"thread_id": "2"}} | |
| # Étape 1 : Afficher le message "Je réfléchis..." | |
| st.write("Je réfléchis...") | |
| # Stream des résultats au fur et à mesure | |
| with st.spinner('Recherche en cours...'): | |
| for output in graph.stream(inputs, stream_mode='values', debug=False): | |
| # Ajouter les résultats au dictionnaire outputs | |
| for k, v in output.items(): | |
| if k not in outputs: | |
| outputs[k] = [] | |
| outputs[k].append(v) | |
| # Ne pas afficher les messages pour les clés non pertinentes (comme error_query) | |
| if 'query' in output and len(output['query'])>0: | |
| st.write(f"query : {output['query']}") | |
| #st.write(outputs.get('query_feedbacks', 'pas de feedback')) | |
| #st.write(outputs.get('results_feedbacks', 'pas de resultfeedback')) | |
| if "results" in output and len(output["results"]) > 0: | |
| records = json.loads(output['results']) | |
| st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.") | |
| # Après la fin du traitement | |
| if "results" in outputs and len(outputs["results"]) > 0: | |
| # Agréger tous les résultats accumulés | |
| all_results = [] | |
| for res in outputs["results"]: | |
| json_data = json.loads(res) # Convertir chaque ensemble de résultats en JSON | |
| all_results.extend(json_data) # Accumuler tous les résultats | |
| results_df = pd.DataFrame(all_results) # Créer un DataFrame avec tous les résultats accumulés | |
| # Afficher un aperçu des résultats (jusqu'à 5 premiers) | |
| num_results = len(results_df) | |
| st.write(f"J'ai trouvé {num_results} résultats.") | |
| if num_results > 0: | |
| preview_count = min(5, num_results) # Gérer le cas où il y a moins de 5 résultats | |
| st.write(f"Voici un aperçu des {preview_count} premiers résultats :") | |
| st.write(results_df.head(preview_count)) | |
| trunc = outputs.get('truncated', 'pas de traunc') | |
| if trunc[0] == True: | |
| st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ") | |
| # Convertir tous les résultats en CSV | |
| csv = results_df.to_csv(index=False) | |
| # Ajouter un bouton pour télécharger tous les résultats | |
| st.download_button( | |
| label="Télécharger le résultat complet au format CSV", | |
| data=csv, | |
| file_name="results.csv", | |
| mime="text/csv" | |
| ) | |
| else: | |
| # Si aucun résultat n'est trouvé | |
| st.write("Aucun résultat trouvé.") | |