import inspect
import io
import re
import subprocess
import sys
import uuid

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

from utils.default_values import get_system_prompt, get_guidelines_dict
from utils.epfl_meditron_utils import get_llm_response, gptq_model_options
from utils.openai_utils import get_available_engines, get_search_query_type_options
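# Note (inferred from the call site in display_llm_output below, not from the module itself):
# get_llm_response is assumed to take the model name, the generation parameters
# (temperature, do_sample, top_p, top_k, max_new_tokens, repetition_penalty) and the
# fully formatted prompt, in that order.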
POC_VERSION = "0.1.1"

st.set_page_config(page_title='Medgate Whisper PoC', page_icon='public/medgate.png')
def display_streamlit_sidebar():
    """Render the sidebar form with model selection and generation parameters."""
    st.sidebar.title("Local LLM PoC " + str(POC_VERSION))
    st.sidebar.write('**Parameters**')
    form = st.sidebar.form("config_form", clear_on_submit=True)
    model_name_or_path = form.selectbox("Select model", gptq_model_options(), index=st.session_state["model_index"])
    model_name_or_path_other = form.text_input('Or input any GPTQ model', value=st.session_state["model_name_or_path_other"])
    temperature = form.slider(label="Temperature", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["temperature"])
    do_sample = form.checkbox('do_sample', value=st.session_state["do_sample"])
    top_p = form.slider(label="top_p", min_value=0.0, max_value=1.0, step=0.01, value=st.session_state["top_p"])
    top_k = form.slider(label="top_k", min_value=1, max_value=1000, step=1, value=st.session_state["top_k"])
    max_new_tokens = form.slider(label="max_new_tokens", min_value=32, max_value=4096, step=1, value=st.session_state["max_new_tokens"])
    repetition_penalty = form.slider(label="repetition_penalty", min_value=0.0, max_value=5.0, step=0.01, value=st.session_state["repetition_penalty"])
    submitted = form.form_submit_button("Start session")

    if submitted:
        print('Parameters updated...')
        st.session_state['session_started'] = True
        st.session_state["session_events"] = []
        # A free-text model name takes precedence over the selectbox choice.
        if len(model_name_or_path_other) > 0:
            st.session_state["model_name"] = model_name_or_path_other
            st.session_state["model_name_or_path_other"] = model_name_or_path_other
        else:
            st.session_state["model_name"] = model_name_or_path
            st.session_state["model_index"] = gptq_model_options().index(model_name_or_path)
            st.session_state["model_name_or_path"] = model_name_or_path
        st.session_state["temperature"] = temperature
        st.session_state["do_sample"] = do_sample
        st.session_state["top_p"] = top_p
        st.session_state["top_k"] = top_k
        st.session_state["max_new_tokens"] = max_new_tokens
        st.session_state["repetition_penalty"] = repetition_penalty
        st.rerun()

def init_session_state():
    """Set default values for all session state keys used by the app."""
    print('init_session_state()')
    st.session_state['session_started'] = False
    st.session_state["session_events"] = []
    st.session_state["model_name_or_path"] = "TheBloke/meditron-7B-GPTQ"
    st.session_state["model_name_or_path_other"] = ""
    st.session_state["model_index"] = 0
    st.session_state["temperature"] = 0.01
    st.session_state["do_sample"] = True
    st.session_state["top_p"] = 0.95
    st.session_state["top_k"] = 40
    st.session_state["max_new_tokens"] = 4096
    st.session_state["repetition_penalty"] = 1.1
    st.session_state["system_prompt"] = "You are a medical expert that provides answers for a medically trained audience"
    st.session_state["prompt"] = ""
    st.session_state["llm_messages"] = []

def display_session_overview():
    """Show the query history and aggregate token/cost statistics for the session."""
    st.subheader('History of LLM queries')
    st.write(st.session_state["llm_messages"])
    st.subheader("Session costs overview")
    df_session_overview = pd.DataFrame.from_dict(st.session_state["session_events"])
    st.write(df_session_overview)
    if "prompt_tokens" in df_session_overview:
        prompt_tokens = df_session_overview["prompt_tokens"].sum()
        st.write("Prompt tokens: " + str(prompt_tokens))
        prompt_cost = df_session_overview["prompt_cost_chf"].sum()
        st.write("Prompt CHF: " + str(prompt_cost))
        completion_tokens = df_session_overview["completion_tokens"].sum()
        st.write("Completion tokens: " + str(completion_tokens))
        completion_cost = df_session_overview["completion_cost_chf"].sum()
        st.write("Completion CHF: " + str(completion_cost))
        total_cost = df_session_overview["total_cost_chf"].sum()
        st.write("Total costs CHF: " + str(total_cost))
        total_time = df_session_overview["response_time"].sum()
        st.write("Total compute time (ms): " + str(total_time))

def get_prompt_format(model_name):
    """Return the chat prompt template for the selected model family, or "" if unknown."""
    formatted_text = ""
    if model_name in ("TheBloke/Llama-2-13B-chat-GPTQ", "TheBloke/Llama-2-7B-Chat-GPTQ"):
        formatted_text = '''[INST] <<SYS>>
        {system_message}
        <</SYS>>
        {prompt}[/INST]
        '''
    if model_name in ("TheBloke/meditron-7B-GPTQ", "TheBloke/meditron-70B-GPTQ"):
        formatted_text = '''<|im_start|>system
        {system_message}<|im_end|>
        <|im_start|>user
        {prompt}<|im_end|>
        <|im_start|>assistant
        '''
    return inspect.cleandoc(formatted_text)


def format_prompt(template, system_message, prompt):
    """Fill the template with the system message and user prompt; fall back to plain concatenation."""
    if template == "":
        return f"{system_message} {prompt}"
    return template.format(system_message=system_message, prompt=prompt)

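# Illustrative sketch only (not part of the original app): how the two helpers above
# compose for the Llama-2 template; the system message and question are made up.
#
#   template = get_prompt_format("TheBloke/Llama-2-7B-Chat-GPTQ")
#   print(format_prompt(template, "You are a medical expert.", "What is sepsis?"))
#   # [INST] <<SYS>>
#   # You are a medical expert.
#   # <</SYS>>
#   # What is sepsis?[/INST]
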
def display_llm_output():
    """Render the prompt form and show the model response for a submitted prompt."""
    st.header("LLM")
    form = st.form('llm')
    prompt_format_str = get_prompt_format(st.session_state["model_name_or_path"])
    prompt_format = form.text_area('Prompt format', value=prompt_format_str, height=170)
    system_prompt = form.text_area('System message', value=st.session_state["system_prompt"], height=170)
    prompt = form.text_area('Prompt', value=st.session_state["prompt"], height=170)
    submitted = form.form_submit_button('Submit')

    if submitted:
        st.session_state["system_prompt"] = system_prompt
        st.session_state["prompt"] = prompt
        formatted_prompt = format_prompt(prompt_format, system_prompt, prompt)
        print(f"Formatted prompt: {formatted_prompt}")
        llm_response = get_llm_response(
            st.session_state["model_name"],
            st.session_state["temperature"],
            st.session_state["do_sample"],
            st.session_state["top_p"],
            st.session_state["top_k"],
            st.session_state["max_new_tokens"],
            st.session_state["repetition_penalty"],
            formatted_prompt)
        st.write(llm_response)

def main():
    print('Running Local LLM PoC Streamlit app...')
    session_inactive_info = st.empty()
    if "session_started" not in st.session_state or not st.session_state["session_started"]:
        # First load: populate defaults and show only the sidebar until a session is started.
        init_session_state()
        display_streamlit_sidebar()
    else:
        display_streamlit_sidebar()
        session_inactive_info.empty()
        display_llm_output()
        display_session_overview()


if __name__ == '__main__':
    main()
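
# To launch locally, run the app through the Streamlit CLI (the filename app.py is
# an assumption; use whatever this file is actually called in the repository):
#   streamlit run app.py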