Spaces:
Runtime error
Runtime error
import io
import re
import sqlite3
from contextlib import closing

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from huggingface_hub import InferenceClient
# Inference client for the Zephyr-7B-beta chat model that writes the SQL.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Dataset exposed to the generated queries (a relative path, so it is
# resolved against the process working directory).
csv_file_path = "Movies.csv"
# Loaded once at import time; every request queries this same frame.
df = pd.read_csv(filepath_or_buffer=csv_file_path)
# Name of the SQLite table the dataframe is loaded into. The prompt must tell
# the model this name, otherwise the generated SQL cannot reference it.
_TABLE_NAME = "data"

# First fenced code block in a Markdown reply: skips the opening fence and its
# optional language-tag line, then captures everything up to the closing fence.
_SQL_FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)


def _extract_sql(text):
    """Return the SQL statement contained in a chat-model reply.

    Chat models commonly wrap code in Markdown fences and add explanatory
    prose, and SQLite can execute neither. If the reply has a fenced block,
    its contents are returned; otherwise the whole reply is returned.
    Surrounding whitespace is stripped in both cases.
    """
    match = _SQL_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


# Function to generate SQL queries
def generate_sql_query(
    prompt,
    table_metadata,
    system_message=None,
    max_tokens=512,
    temperature=0.7,
    top_p=0.95,
):
    """Ask the chat model for an SQLite query that answers *prompt*.

    Args:
        prompt: The user's natural-language request.
        table_metadata: Text describing the table's columns and dtypes.
        system_message: Optional system prompt sent before the request.
        max_tokens: Upper bound on generated tokens (cast to ``int``).
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling threshold forwarded to the model.

    Returns:
        The SQL text extracted from the model's reply (see ``_extract_sql``).
    """
    input_text = (
        f"Generate an SQL query for the SQLite table named '{_TABLE_NAME}' "
        f"with the following structure: {table_metadata}. "
        f"Reply with only the SQL query. Prompt: {prompt}"
    )
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    # The request itself is the user's turn of the conversation.
    messages.append({"role": "user", "content": input_text})

    stream = client.chat_completion(
        messages=messages,
        max_tokens=int(max_tokens),
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    # A streamed chunk can carry ``delta.content=None`` (e.g. the final one).
    # Concatenating None raises TypeError, so treat it as an empty string.
    response = "".join(
        chunk.choices[0].delta.content or ""
        for chunk in stream
        if chunk.choices
    )
    return _extract_sql(response)


# Function to execute SQL query on the dataframe
def execute_query(df, query):
    """Run *query* against *df* loaded into a throw-away in-memory SQLite DB.

    The frame is written (without its index) to a table named ``data``
    before the query runs.

    Returns:
        A ``DataFrame`` with the query result, or an error message ``str``
        when the query is empty or loading or querying fails. Callers tell
        the two cases apart with ``isinstance(result, str)``.
    """
    if not (query and query.strip()):
        return "No SQL query to execute."
    try:
        # A bare ``sqlite3.Connection`` context manager only commits or rolls
        # back and never closes; ``closing`` releases the in-memory database.
        with closing(sqlite3.connect(":memory:")) as conn:
            df.to_sql(_TABLE_NAME, conn, index=False, if_exists="replace")
            return pd.read_sql_query(query, conn)
    except Exception as e:  # UI boundary: the message is shown to the user
        return str(e)


# Function to create a plot from the result dataframe
def create_plot(df):
    """Render *df* with ``DataFrame.plot`` and return the chart as PNG data.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 that holds the PNG image.
        Exceptions from ``df.plot`` (e.g. no numeric data) propagate, but the
        figure is closed first.
    """
    fig, ax = plt.subplots()
    try:
        df.plot(ax=ax)
        buf = io.BytesIO()
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # each request leaks a figure in the long-running server process.
        plt.close(fig)
    buf.seek(0)
    return buf


# Gradio function to handle user input and interaction
def respond(user_prompt, system_message, max_tokens, temperature, top_p):
    """Gradio callback: prompt -> SQL -> query result -> preview and chart.

    The system message and sampling settings from the UI are forwarded to
    ``generate_sql_query``.

    Returns:
        A 3-tuple for (SQL textbox, HTML result, image): the generated SQL;
        either an HTML table of the first five result rows or an error
        message; and the chart as a uint8 image array, or ``None`` when the
        query failed or the result has no numeric data to plot.
    """
    table_metadata = str(df.dtypes.to_dict())
    try:
        sql_query = generate_sql_query(
            user_prompt,
            table_metadata,
            system_message=system_message,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:  # UI boundary: report model/API failures as text
        return "", f"Model request failed: {e}", None

    result_df = execute_query(df, sql_query)
    if isinstance(result_df, str):
        return sql_query, result_df, None  # Return the error message

    plot = None
    # DataFrame.plot raises on frames without numeric data (empty or text-only
    # results). In that case show the table without a chart.
    if not result_df.select_dtypes(include="number").empty:
        # gr.Image cannot display a BytesIO, so decode the PNG into an array.
        # imread returns floats in [0, 1] for PNG; scale them to 8-bit.
        rgba = plt.imread(create_plot(result_df), format="png")
        plot = (rgba * 255).round().astype("uint8")
    return sql_query, result_df.head().to_html(), plot
# Gradio UI components
def create_demo():
    """Build and return the Gradio Blocks app for the text-to-SQL demo.

    Components are created in the order they appear below: the prompt,
    system message and sampling inputs, then the SQL, result and plot
    outputs, then a Submit button that routes them through ``respond``.
    """
    with gr.Blocks() as app:
        # Inputs
        prompt_box = gr.Textbox(
            lines=2,
            placeholder="Enter your prompt here...",
            label="User Prompt",
        )
        system_box = gr.Textbox(
            value="You are an AI assistant that generates SQL queries based on user prompts.",
            label="System message",
        )
        tokens_slider = gr.Slider(
            minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
        )
        temperature_slider = gr.Slider(
            minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
        )
        top_p_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        )

        # Outputs
        sql_box = gr.Textbox(label="Generated SQL Query")
        result_html = gr.HTML(label="Query Result")
        plot_image = gr.Image(label="Result Plot")

        button = gr.Button("Submit")
        callback_inputs = [prompt_box, system_box, tokens_slider, temperature_slider, top_p_slider]
        callback_outputs = [sql_box, result_html, plot_image]
        button.click(fn=respond, inputs=callback_inputs, outputs=callback_outputs)
    return app
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, keep it bound to the
    # module-level name ``demo``, and start the Gradio server.
    demo = create_demo()
    demo.launch()