# qstn_gui/src/pages/04_Final_Overview.py
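"""Final Overview page.

Displays the configured inference parameters next to a live preview of the
generated prompts, runs the questionnaire against the configured LLM endpoint
in a background thread, and offers the results as a CSV download.
"""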
import streamlit as st
from gui_elements.paginator import paginator
from gui_elements.stateful_widget import StatefulWidgets
import queue
import time
import threading
import asyncio
from contextlib import redirect_stderr
from qstn.parser.llm_answer_parser import raw_responses
from qstn.utilities.constants import QuestionnairePresentation
from qstn.utilities.utils import create_one_dataframe
from qstn.survey_manager import (
conduct_survey_sequential,
conduct_survey_battery,
conduct_survey_single_item,
)
from streamlit.runtime.scriptrunner import add_script_run_ctx
from openai import AsyncOpenAI
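# Session-state keys this page relies on (populated on the earlier pages):
#   questionnaires   - list of questionnaire objects to survey
#   client_config    - keyword arguments for AsyncOpenAI (e.g. the base_url of a vLLM server)
#   inference_config - model name plus the sampling parameters for the call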
if "questionnaires" not in st.session_state:
st.error(
"You need to first upload a questionnaire and the population you want to survey."
)
st.stop()
disabled = True
else:
disabled = False
# cache_resource keeps a single shared instance; cache_data would hand back a
# fresh copy on every rerun, so the widget state would not stick.
@st.cache_resource
def create_stateful_widget() -> StatefulWidgets:
    return StatefulWidgets()
state = create_stateful_widget()
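# state.create(widget_fn, key, label, ...) wraps a plain Streamlit widget so
# that its value is kept in session state across page switches
# (see gui_elements.stateful_widget).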
current_index = paginator(st.session_state.questionnaires, "overview_page")
questionnaire = st.session_state.questionnaires[current_index]
col_llm, col_prompt_display = st.columns(2)
with col_llm:
st.subheader("⚙️ Inference Parameters")
with st.container(border=True):
st.subheader("Core Settings")
model_name = state.create(
st.text_input,
"model_name",
"Model Name",
# initial_value="meta-llama/Llama-3.1-70B-Instruct",
# placeholder="meta-llama/Llama-3.1-70B-Instruct",
disabled=True,
help="The model to use for the inference call.",
)
temperature = state.create(
st.slider,
"temperature",
"Temperature",
min_value=0.0,
max_value=2.0,
step=0.01,
initial_value=1.0,
disabled=True,
help="Controls randomness. Lower values are more deterministic and less creative.",
)
max_tokens = state.create(
st.number_input,
"max_tokens",
"Max Tokens",
initial_value=1024,
min_value=1,
disabled=True,
help="The maximum number of tokens to generate in the completion.",
)
top_p = state.create(
st.slider,
"top_p",
"Top P",
min_value=0.0,
max_value=1.0,
step=0.01,
initial_value=1.0,
disabled=True,
help="Controls nucleus sampling. The model considers tokens with top_p probability mass.",
)
seed = state.create(
st.number_input,
"seed",
"Seed",
initial_value=42,
min_value=0,
disabled=True,
help="A specific seed for reproducibility of results.",
)
with st.expander("Advanced Inference Settings (JSON)"):
advanced_inference_params_str = state.create(
st.text_area,
"advanced_inference_params_str",
"JSON for other inference parameters",
initial_value="",
# placeholder='{\n "stop": ["\\n", " Human:"],\n "presence_penalty": 0\n}',
height=150,
disabled=True,
help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.',
)
with col_prompt_display:
st.subheader("📄 Live Preview")
# Survey method selector
survey_method_options = {
"Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM),
"Battery": ("battery", QuestionnairePresentation.BATTERY),
"Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL),
}
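    # Each entry maps the display label to (method name used to pick the
    # conduct_survey_* function, QuestionnairePresentation enum used to render
    # the matching prompt preview).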
survey_method_display = state.create(
st.selectbox,
"survey_method",
"Questionnaire Method",
options=list(survey_method_options.keys()),
initial_value="Single item",
help="Choose how to conduct the questionnaire: Single item (one at a time), Battery (all questions together), or Sequential (with conversation history)."
)
# Get the method name and questionnaire type from selection
selected_method_name, selected_questionnaire_type = survey_method_options[survey_method_display]
with st.container(border=True):
# For single item mode, show multiple previews (up to 3 items)
if selected_questionnaire_type == QuestionnairePresentation.SINGLE_ITEM:
            num_questions = len(questionnaire._questions)
            num_previews = min(3, num_questions)  # preview at most 3 items
if num_previews > 1:
st.write(f"**Preview of first {num_previews} items:**")
else:
st.write("**Preview:**")
for i in range(num_previews):
if num_previews > 1:
st.write(f"**Item {i+1}:**")
                current_system_prompt, current_prompt = questionnaire.get_prompt_for_questionnaire_type(
                    selected_questionnaire_type,
                    item_id=i,
                )
                # Two trailing spaces before "\n" force a Markdown hard line break.
                current_system_prompt = current_system_prompt.replace("\n", "  \n")
                current_prompt = current_prompt.replace("\n", "  \n")
                st.write(current_system_prompt)
                st.write(current_prompt)
# Add separator between items (except for the last one)
if i < num_previews - 1:
st.divider()
else:
# For battery and sequential, show single preview as before
            current_system_prompt, current_prompt = questionnaire.get_prompt_for_questionnaire_type(
                selected_questionnaire_type
            )
            current_system_prompt = current_system_prompt.replace("\n", "  \n")
            current_prompt = current_prompt.replace("\n", "  \n")
            st.write(current_system_prompt)
            st.write(current_prompt)
if st.button("Confirm and Run Questionnaire", type="primary", use_container_width=True):
st.write("Starting inference...")
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = AsyncOpenAI(**st.session_state.client_config)
inference_config = st.session_state.inference_config.copy()
model_name = inference_config.pop("model")
progress_text = st.empty()
log_queue = queue.Queue()
result_queue = queue.Queue()
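    # log_queue carries tqdm/stderr lines captured in the worker thread;
    # result_queue carries the survey result (or the exception it raised).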
    class QueueWriter:
        """Minimal file-like object that forwards non-empty writes to a queue."""

        def __init__(self, q):
            self.q = q

        def write(self, message):
            if message.strip():
                self.q.put(message)

        def flush(self):
            # Required by the file-like interface; nothing to flush.
            pass
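    # A QueueWriter is handed to redirect_stderr in the worker below so that
    # tqdm progress output (written to stderr) lands on log_queue, where the
    # polling loop in the main thread can render it.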
    # Helper that runs the (async) survey in a worker thread.
    def run_async_in_thread(
        result_q, client, questionnaires, model_name, survey_method_name, **inference_config
    ):
        queue_writer = QueueWriter(log_queue)
        # Redirect stderr into the queue: Streamlit UI calls are not safe from
        # worker threads, so the main thread renders the captured output.
        # TODO: expose api_concurrency in the GUI instead of hard-coding it.
        result = None
        try:
            with redirect_stderr(queue_writer):
                # Dispatch to the survey strategy picked in the preview selector.
                survey_funcs = {
                    "single_item": conduct_survey_single_item,
                    "battery": conduct_survey_battery,
                    "sequential": conduct_survey_sequential,
                }
                survey_func = survey_funcs.get(survey_method_name, conduct_survey_single_item)
                # The conduct_survey_* helpers are coroutines driven by the
                # AsyncOpenAI client, so run them on an event loop in this thread.
                result = asyncio.run(
                    survey_func(
                        client,
                        llm_prompts=questionnaires,
                        client_model_name=model_name,
                        api_concurrency=100,
                        **inference_config,
                    )
                )
        except Exception as e:
            # No Streamlit calls from this thread; hand the exception to the
            # main thread through the result queue instead.
            result = e
        finally:
            result_q.put(result)
    # Drain any stale messages left over from a previous run.
    while not log_queue.empty():
        log_queue.get()
    while not result_queue.empty():
        result_queue.get()
    # selected_method_name was already derived from the survey-method selector above.
    thread = threading.Thread(
        target=run_async_in_thread,
        args=(result_queue, client, st.session_state.questionnaires, model_name, selected_method_name),
        kwargs=inference_config,
    )
    # Attach the Streamlit script context so the worker thread does not trip
    # Streamlit's missing-ScriptRunContext warnings.
    add_script_run_ctx(thread)
    thread.start()
all_questions_placeholder = st.empty()
progress_placeholder = st.empty()
    while thread.is_alive():
        try:
            # The main thread may write to the UI directly. tqdm animates with
            # carriage returns (\r), so only clean, complete lines are shown.
            log_message = log_queue.get_nowait()
            # Hacky routing for now; QSTN should emit clearly parsable messages.
            if "[A" not in log_message and "Processing Prompts" not in log_message:
                all_questions_placeholder.text(log_message.strip().replace("\r", ""))
            elif "Processing Prompts" in log_message:
                progress_placeholder.text(log_message.strip().replace("\r", ""))
        except queue.Empty:
            pass
        time.sleep(0.1)
    thread.join()
    all_questions_placeholder.empty()
    progress_placeholder.empty()
    try:
        final_output = result_queue.get_nowait()
    except queue.Empty:
        st.error("Could not retrieve result from the asynchronous task.")
        st.stop()
    if isinstance(final_output, Exception):
        # The worker caught this exception and passed it back via the queue.
        st.error(final_output)
        st.stop()
    st.success("Finished inferencing!")
responses = raw_responses(final_output)
df = create_one_dataframe(responses)
# Store the dataframe in session state for saving later
st.session_state.results_dataframe = df
st.session_state.inference_completed = True
st.dataframe(df)
# Show the save section once inference has completed; the flag and dataframe
# live in session state, so this survives reruns triggered by the widgets below.
if st.session_state.get("inference_completed"):
st.divider()
st.subheader("💾 Save Results")
# Text input for filename
if "save_filename" not in st.session_state:
st.session_state.save_filename = "questionnaire_results.csv"
save_filename = st.text_input(
"Save File",
value=st.session_state.save_filename,
key="save_filename_input",
help="Enter the filename for the results. Should be a CSV file (e.g., results.csv)."
)
# Ensure filename ends with .csv
if save_filename and not save_filename.endswith('.csv'):
save_filename = save_filename + '.csv'
# Convert dataframe to CSV string for download
csv = st.session_state.results_dataframe.to_csv(index=False)
st.download_button(
label="Save Results",
data=csv,
file_name=save_filename if save_filename else "questionnaire_results.csv",
mime="text/csv",
type="primary",
use_container_width=True,
help="Click to save the results to your computer. You can choose the directory and filename in the save dialog."
)