|
|
import streamlit as st |
|
|
from gui_elements.paginator import paginator |
|
|
from gui_elements.stateful_widget import StatefulWidgets |
|
|
from gui_elements.output_manager import st_capture, TqdmToStreamlit |
|
|
|
|
|
import io |
|
|
import queue |
|
|
import time |
|
|
import threading |
|
|
import asyncio |
|
|
import pandas as pd |
|
|
|
|
|
import logging |
|
|
|
|
|
from contextlib import redirect_stderr, redirect_stdout |
|
|
|
|
|
from qstn.parser.llm_answer_parser import raw_responses |
|
|
from qstn.utilities.constants import QuestionnairePresentation |
|
|
from qstn.utilities.utils import create_one_dataframe |
|
|
from qstn.survey_manager import ( |
|
|
conduct_survey_sequential, |
|
|
conduct_survey_battery, |
|
|
conduct_survey_single_item, |
|
|
) |
|
|
|
|
|
from streamlit.runtime.scriptrunner import add_script_run_ctx |
|
|
|
|
|
from openai import AsyncOpenAI |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Guard: this page requires questionnaires uploaded on a previous page.
if "questionnaires" not in st.session_state:
    st.error(
        "You need to first upload a questionnaire and the population you want to survey."
    )
    # Bug fix: assign before st.stop(). st.stop() halts script execution by
    # raising, so any statement placed after it is dead code — the original
    # set `disabled = True` on an unreachable line.
    disabled = True
    st.stop()
else:
    disabled = False
|
|
|
|
|
|
|
|
@st.cache_data
def create_stateful_widget() -> StatefulWidgets:
    """Return a cached StatefulWidgets manager shared across script reruns.

    NOTE(review): ``st.cache_data`` serializes/copies its return value on
    every access; for a *stateful* manager object ``st.cache_resource`` is
    usually the appropriate cache. Confirm whether StatefulWidgets keeps its
    state externally (e.g. in st.session_state) before changing.
    """
    return StatefulWidgets()
|
|
|
|
|
|
|
|
# Shared widget-state manager (cached so widget defaults survive reruns).
state = create_stateful_widget()

# Render pagination controls and get the index of the selected questionnaire.
current_index = paginator(st.session_state.questionnaires, "overview_page")

# The questionnaire currently shown in the live-preview pane.
questionnaires = st.session_state.questionnaires[current_index]

# Left column: inference parameters; right column: live prompt preview.
col_llm, col_prompt_display = st.columns(2)
|
|
|
|
|
with col_llm:
    st.subheader("⚙️ Inference Parameters")

    # All widgets below are created with disabled=True, so this panel is a
    # read-only display. NOTE(review): the values actually used at run time
    # come from st.session_state.inference_config in the button handler —
    # confirm these widgets are intentionally display-only.
    with st.container(border=True):
        st.subheader("Core Settings")
        model_name = state.create(
            st.text_input,
            "model_name",
            "Model Name",
            disabled=True,
            help="The model to use for the inference call.",
        )

        temperature = state.create(
            st.slider,
            "temperature",
            "Temperature",
            min_value=0.0,
            max_value=2.0,
            step=0.01,
            initial_value=1.0,
            disabled=True,
            help="Controls randomness. Lower values are more deterministic and less creative.",
        )

        max_tokens = state.create(
            st.number_input,
            "max_tokens",
            "Max Tokens",
            initial_value=1024,
            min_value=1,
            disabled=True,
            help="The maximum number of tokens to generate in the completion.",
        )

        top_p = state.create(
            st.slider,
            "top_p",
            "Top P",
            min_value=0.0,
            max_value=1.0,
            step=0.01,
            initial_value=1.0,
            disabled=True,
            help="Controls nucleus sampling. The model considers tokens with top_p probability mass.",
        )

        seed = state.create(
            st.number_input,
            "seed",
            "Seed",
            initial_value=42,
            min_value=0,
            disabled=True,
            help="A specific seed for reproducibility of results.",
        )

    # Free-form JSON escape hatch for parameters without a dedicated widget.
    with st.expander("Advanced Inference Settings (JSON)"):
        advanced_inference_params_str = state.create(
            st.text_area,
            "advanced_inference_params_str",
            "JSON for other inference parameters",
            initial_value="",
            height=150,
            disabled=True,
            help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.',
        )
|
|
|
|
|
|
|
|
with col_prompt_display:
    st.subheader("📄 Live Preview")

# Display label -> (survey_manager function key, presentation enum).
# NOTE(review): this mapping is re-declared inside the button handler below;
# consider hoisting a single module-level constant.
survey_method_options = {
    "Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM),
    "Battery": ("battery", QuestionnairePresentation.BATTERY),
    "Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL),
}

survey_method_display = state.create(
    st.selectbox,
    "survey_method",
    "Questionnaire Method",
    options=list(survey_method_options.keys()),
    initial_value="Single item",
    help="Choose how to conduct the questionnaire: Single item (one at a time), Battery (all questions together), or Sequential (with conversation history)."
)

# Resolve the display label into the internal method key and enum.
selected_method_name, selected_questionnaire_type = survey_method_options[survey_method_display]
|
|
|
|
|
with st.container(border=True):
    if selected_questionnaire_type == QuestionnairePresentation.SINGLE_ITEM:
        # NOTE(review): reaches into the private ``_questions`` attribute —
        # a public question-count accessor would be cleaner.
        num_questions = len(questionnaires._questions)
        # Cap the preview at the first three items.
        num_previews = min(3, num_questions)

        if num_previews > 1:
            st.write(f"**Preview of first {num_previews} items:**")
        else:
            st.write("**Preview:**")

        for i in range(num_previews):
            if num_previews > 1:
                st.write(f"**Item {i+1}:**")

            current_system_prompt, current_prompt = questionnaires.get_prompt_for_questionnaire_type(
                selected_questionnaire_type,
                item_id=i
            )
            # Presumably intended to force Markdown line breaks — note that
            # Markdown requires TWO trailing spaces before "\n"; confirm the
            # replacement string is what was intended.
            current_system_prompt = current_system_prompt.replace("\n", " \n")
            current_prompt = current_prompt.replace("\n", " \n")
            st.write(current_system_prompt)
            st.write(current_prompt)

            # Divider between preview items, but not after the last one.
            if i < num_previews - 1:
                st.divider()
    else:
        # Battery/Sequential: one combined prompt, no per-item id needed.
        current_system_prompt, current_prompt = questionnaires.get_prompt_for_questionnaire_type(selected_questionnaire_type)
        current_system_prompt = current_system_prompt.replace("\n", " \n")
        current_prompt = current_prompt.replace("\n", " \n")
        st.write(current_system_prompt)
        st.write(current_prompt)
|
|
|
|
|
|
|
|
if st.button("Confirm and Run Questionnaire", type="primary", use_container_width=True):
    st.write("Starting inference...")

    # NOTE(review): these two variables are never used — the client is built
    # from st.session_state.client_config below. Presumably leftovers from a
    # local vLLM setup; confirm before removing.
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"

    client = AsyncOpenAI(**st.session_state.client_config)

    # Copy so popping "model" does not mutate the session-state original.
    inference_config = st.session_state.inference_config.copy()

    # "model" is passed separately as client_model_name, not via **kwargs.
    model_name = inference_config.pop("model")

    progress_text = st.empty()

    # Cross-thread channels: worker pushes log lines / the final result,
    # the main script thread drains them while rerunning the UI.
    log_queue = queue.Queue()
    result_queue = queue.Queue()
|
|
|
|
|
class QueueWriter:
    """Minimal file-like sink that forwards non-blank writes to a queue.

    Handed to ``redirect_stderr`` so output produced inside the worker
    thread can be drained and rendered by the main Streamlit thread.
    """

    def __init__(self, q):
        self.q = q

    def write(self, message):
        # Skip whitespace-only writes (bare newlines, carriage returns),
        # but forward the original unstripped text so formatting survives.
        if not message.strip():
            return
        self.q.put(message)

    def flush(self):
        # Nothing is buffered locally; present only to satisfy the
        # file-object protocol expected by redirect_stderr.
        pass
|
|
|
|
|
|
|
|
def run_async_in_thread(
    result_q, client, questionnaires, model_name, survey_method_name, **inference_config
):
    """Worker-thread entry point: run the chosen survey and report back.

    Captures stderr (tqdm progress, warnings) into ``log_queue`` via
    QueueWriter and always puts exactly one item on ``result_q`` — either
    the survey result or the Exception that aborted it.
    """
    queue_writer = QueueWriter(log_queue)

    # Bug fix: initialize before the try block. Previously, if a
    # non-Exception BaseException escaped, ``result`` was unbound and the
    # finally clause raised NameError, masking the real error.
    result = None

    # Dispatch table; unknown method names fall back to single-item,
    # matching the original if/elif chain's else branch.
    survey_funcs = {
        "single_item": conduct_survey_single_item,
        "battery": conduct_survey_battery,
        "sequential": conduct_survey_sequential,
    }
    survey_func = survey_funcs.get(survey_method_name, conduct_survey_single_item)

    try:
        with redirect_stderr(queue_writer):
            result = survey_func(
                client,
                llm_prompts=questionnaires,
                client_model_name=model_name,
                api_concurrency=100,
                **inference_config,
            )
    except Exception as e:
        # Hand the exception to the main thread via the result queue.
        result = e
        # NOTE(review): st.* calls from a worker thread only render if the
        # thread carries a Streamlit script-run context.
        st.error(e)
    finally:
        # Always unblock the main thread, even on failure.
        result_q.put(result)
|
|
|
|
|
while not log_queue.empty(): |
|
|
log_queue.get() |
|
|
while not result_queue.empty(): |
|
|
result_queue.get() |
|
|
|
|
|
|
|
|
survey_method_display = st.session_state.get("survey_method", "Single item") |
|
|
survey_method_options = { |
|
|
"Single item": ("single_item", QuestionnairePresentation.SINGLE_ITEM), |
|
|
"Battery": ("battery", QuestionnairePresentation.BATTERY), |
|
|
"Sequential": ("sequential", QuestionnairePresentation.SEQUENTIAL), |
|
|
} |
|
|
selected_method_name, _ = survey_method_options.get(survey_method_display, ("single_item", QuestionnairePresentation.SINGLE_ITEM)) |
|
|
|
|
|
thread = threading.Thread( |
|
|
target=run_async_in_thread, |
|
|
args=(result_queue, client, st.session_state.questionnaires, model_name, selected_method_name), |
|
|
kwargs=inference_config, |
|
|
) |
|
|
thread.start() |
|
|
|
|
|
    all_questions_placeholder = st.empty()
    progress_placeholder = st.empty()

    # Poll until the worker finishes, mirroring its captured stderr in the UI.
    while thread.is_alive():
        try:
            log_message = log_queue.get_nowait()

            # "[A" is presumably the remnant of the ANSI cursor-up escape
            # (ESC[A) emitted by tqdm redraws — confirm; those lines are
            # skipped rather than shown.
            if "[A" not in log_message and "Processing Prompts" not in log_message:
                all_questions_placeholder.text(log_message.strip().replace("\r", ""))

            elif "Processing Prompts" in log_message:
                # tqdm progress lines get their own placeholder.
                progress_placeholder.text(log_message.strip().replace("\r", ""))

        except queue.Empty:
            # No new output yet — keep polling.
            pass
        time.sleep(0.1)  # ~10 Hz refresh; avoids a busy-wait
    thread.join()
|
|
|
|
|
all_questions_placeholder.empty() |
|
|
progress_placeholder.empty() |
|
|
|
|
|
try: |
|
|
final_output = result_queue.get_nowait() |
|
|
except queue.Empty: |
|
|
st.error("Could not retrieve result from the asynchronous task.") |
|
|
|
|
|
st.success("Finished inferencing!") |
|
|
|
|
|
responses = raw_responses(final_output) |
|
|
|
|
|
df = create_one_dataframe(responses) |
|
|
|
|
|
|
|
|
st.session_state.results_dataframe = df |
|
|
st.session_state.inference_completed = True |
|
|
|
|
|
st.dataframe(df) |
|
|
|
|
|
|
|
|
# Offer a CSV download once at least one inference run has completed.
if st.session_state.get("inference_completed"):
    st.divider()
    st.subheader("💾 Save Results")

    # Seed the default filename exactly once per session.
    if "save_filename" not in st.session_state:
        st.session_state.save_filename = "questionnaire_results.csv"

    save_filename = st.text_input(
        "Save File",
        value=st.session_state.save_filename,
        key="save_filename_input",
        help="Enter the filename for the results. Should be a CSV file (e.g., results.csv)."
    )

    # Normalise the name so the download always carries a .csv extension.
    if save_filename and not save_filename.endswith('.csv'):
        save_filename = f"{save_filename}.csv"

    csv = st.session_state.results_dataframe.to_csv(index=False)

    st.download_button(
        label="Save Results",
        data=csv,
        file_name=save_filename or "questionnaire_results.csv",
        mime="text/csv",
        type="primary",
        use_container_width=True,
        help="Click to save the results to your computer. You can choose the directory and filename in the save dialog."
    )
|
|
|