# qstn_gui/src/pages/03_Inference_Setting.py
import streamlit as st
import json
from gui_elements.stateful_widget import StatefulWidgets
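# NOTE: StatefulWidgets is a project-local helper. Its interface is assumed here from the
# usage below: state.create(widget_fn, key, label, initial_value=..., **widget_kwargs)
# creates the widget and persists its value in st.session_state under `key`.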
# --- Page Configuration ---
st.set_page_config(
page_title="Inference Settings",
layout="wide"
)
st.title("AsyncOpenAI API Client & Inference Configurator")
st.markdown("Use the widgets below to configure the `AsyncOpenAI` client and the inference parameters for an API call. Advanced or less common options can be added as a JSON object.")
st.divider()
# --- Column Layout ---
col1, col2 = st.columns(2)
defaults = {
# Client Config
"api_key": "", "organization": "", "project": "", "base_url": "",
"timeout": 20, "max_retries": 2,
"advanced_client_params_str": '',
# Inference Config
"model_name": "", "temperature": 1.0, "max_tokens": 1024,
"top_p": 1.0, "seed": 42,
"advanced_inference_params_str": ''
}
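# Seed st.session_state with these defaults before any widget is created, so widgets
# bound to the same keys pick up the stored values on the first render.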
for key, value in defaults.items():
if key not in st.session_state:
st.session_state[key] = value
state = StatefulWidgets()
# ==============================================================================
# COLUMN 1: OPENAI CLIENT CONFIGURATION
# ==============================================================================
with col1:
st.header("1. Client Configuration")
with col2:
st.header("2. Inference Configuration")
with col1:
with st.container(border=True):
st.subheader("Core Settings")
api_key = state.create(
st.text_input,
"api_key",
"API Key",
initial_value="",
type="password",
placeholder="sk-...",
help="Your OpenAI API key. It is handled securely by Streamlit."
)
organization = state.create(
st.text_input,
"organization",
"Organization ID",
initial_value="",
placeholder="org-...",
help="Optional identifier for your organization."
)
project = state.create(
st.text_input,
"project",
"Project ID",
initial_value="",
placeholder="proj_...",
help="Optional identifier for your project."
)
base_url = state.create(
st.text_input,
"base_url",
"Base URL",
initial_value="",
placeholder="https://api.openai.com/v1",
help="The base URL for the API. Leave empty for the default."
)
timeout = state.create(
st.number_input,
"timeout",
"Timeout (seconds)",
initial_value=20,
min_value=1,
help="The timeout for API requests in seconds."
)
max_retries = state.create(
st.number_input,
"max_retries",
"Max Retries",
initial_value=2,
min_value=0,
help="The maximum number of times to retry a failed request."
)
with st.expander("Advanced Client Settings (JSON)"):
advanced_client_params_str = state.create(
st.text_area,
"advanced_client_params_str",
"JSON for other client parameters",
initial_value="",
placeholder='{\n "default_headers": {"X-Custom-Header": "value"}\n}',
height=150,
help='Enter any other client init parameters like "default_headers" or "default_query" as a valid JSON object.'
)
# ==============================================================================
# COLUMN 2: INFERENCE PARAMETERS
# ==============================================================================
with col2:
with st.container(border=True):
st.subheader("Core Settings")
model_name = state.create(
st.text_input,
"model_name",
"Model Name",
#initial_value="meta-llama/Llama-3.1-70B-Instruct",
placeholder="meta-llama/Llama-3.1-70B-Instruct",
help="The model to use for the inference call."
)
temperature = state.create(
st.slider,
"temperature",
"Temperature",
min_value=0.0,
max_value=2.0,
step=0.01,
initial_value=1.0,
help="Controls randomness. Lower values are more deterministic and less creative."
)
max_tokens = state.create(
st.number_input,
"max_tokens",
"Max Tokens",
initial_value=1024,
min_value=1,
help="The maximum number of tokens to generate in the completion."
)
top_p = state.create(
st.slider,
"top_p",
"Top P",
min_value=0.0,
max_value=1.0,
step=0.01,
initial_value=1.0,
help="Controls nucleus sampling. The model considers tokens with top_p probability mass."
)
seed = state.create(
st.number_input,
"seed",
"Seed",
initial_value=42,
min_value=0,
help="A specific seed for reproducibility of results."
)
with st.expander("Advanced Inference Settings (JSON)"):
advanced_inference_params_str = state.create(
st.text_area,
"advanced_inference_params_str",
"JSON for other inference parameters",
initial_value="",
placeholder='{\n "stop": ["\\n", " Human:"],\n "presence_penalty": 0\n}',
height=150,
help='Enter any other valid inference parameters like "stop", "logit_bias", or "frequency_penalty" as a JSON object.'
)
# ==============================================================================
# GENERATION AND DISPLAY LOGIC
# ==============================================================================
st.divider()
if st.button("Generate Configuration & Code", type="primary", use_container_width=True):
# --- Process Client Config ---
client_config = {
"api_key": api_key
}
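# api_key is always included, even if the field was left blank; an empty key will
# typically fail once the client is actually used for a request.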
# Add optional string parameters if they are not empty
if organization: client_config["organization"] = organization
if project: client_config["project"] = project
if base_url: client_config["base_url"] = base_url
# Add numeric parameters
client_config["timeout"] = timeout
client_config["max_retries"] = max_retries
try:
if advanced_client_params_str:
advanced_client_params = json.loads(advanced_client_params_str)
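# Keys supplied in the advanced JSON override the core client settings above on collision.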
client_config.update(advanced_client_params)
except json.JSONDecodeError:
st.error("Invalid JSON detected in Advanced Client Settings. Please correct it.")
st.stop()
# --- Process Inference Config ---
inference_config = {
"model": model_name,
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"seed": seed
}
try:
if advanced_inference_params_str:
advanced_inference_params = json.loads(advanced_inference_params_str)
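# As with the client config, advanced JSON keys override the core inference settings on collision.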
inference_config.update(advanced_inference_params)
except json.JSONDecodeError:
st.error("Invalid JSON detected in Advanced Inference Settings. Please correct it.")
st.stop()
st.session_state.client_config = client_config
st.session_state.inference_config = inference_config
st.success("Configuration generated successfully!")
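# Optional review step (a sketch using only the values built above): echo the generated
# configuration on the page, masking the API key so it is not shown in plain text.
display_client_config = {**client_config, "api_key": "***" if client_config["api_key"] else ""}
st.subheader("Generated Configuration")
st.json({"client_config": display_client_config, "inference_config": inference_config})
# A downstream page could consume the stored configs roughly like this (sketch, assuming
# the official `openai` package is installed):
#   from openai import AsyncOpenAI
#   client = AsyncOpenAI(**st.session_state.client_config)
#   response = await client.chat.completions.create(
#       messages=[{"role": "user", "content": "..."}],
#       **st.session_state.inference_config,
#   )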