Spaces:
Sleeping
Sleeping
File size: 6,272 Bytes
d1c266e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import os
from dotenv import load_dotenv
load_dotenv()
import gradio as gr
from PIL import Image
from langchain_core.messages import HumanMessage, AIMessage
from graph import build_graph
from state import WorkflowState
# Module-level handles for the two graphs. They stay None when construction
# fails so that the module can still be imported and report the problem.
diagnosis_graph = reply_graph = None
initialization_error = None

try:
    # build_graph() is unpacked into the (diagnosis, reply) pair.
    diagnosis_graph, reply_graph = build_graph()
    print("Gradio app and graphs initialized successfully.")
except Exception as build_failure:
    # Keep the failure as text: chat_fn shows it in the conversation and the
    # __main__ guard refuses to launch while it is set.
    initialization_error = (
        f"CRITICAL ERROR: Could not build graphs. {build_failure}. Check API keys."
    )
    print(initialization_error)
def convert_history_to_langchain(chat_history):
    """Translate Gradio chat rows into a flat list of LangChain messages.

    Each row of ``chat_history`` is a ``(user, assistant)`` pair. A ``None``
    entry on either side is skipped. A user entry that is a ``dict`` or
    ``tuple`` (i.e. not plain text) is replaced by the fixed placeholder
    text "User uploaded an image."; any other user entry is passed through
    unchanged as the message content.

    Returns:
        A list of ``HumanMessage`` / ``AIMessage`` objects in row order, with
        the user message of a row preceding its assistant message.
    """
    converted = []
    for row in chat_history:
        human_turn, assistant_turn = row
        if human_turn is not None:
            # Non-text payloads cannot be forwarded as message content, so
            # they are summarised with a placeholder sentence.
            is_attachment = isinstance(human_turn, (dict, tuple))
            text = "User uploaded an image." if is_attachment else human_turn
            converted.append(HumanMessage(content=text))
        if assistant_turn is not None:
            converted.append(AIMessage(content=assistant_turn))
    return converted
def reset_state_on_start():
    """Build a blank ``WorkflowState`` for the beginning of a new diagnosis.

    Every field is set to its empty value: no image, empty history and
    symptom lists, symptom index 0, and empty prediction / treatment /
    diagnosis strings.
    """
    blank_fields = {
        "image": None,
        "chat_history": [],
        "disease_prediction": "",
        "symptoms_to_check": [],
        "symptoms_confirmed": [],
        "current_symptom_index": 0,
        "treatment_info": "",
        "final_diagnosis": "",
    }
    return WorkflowState(**blank_fields)
def chat_fn(message: str, chat_history: list, agent_state: dict, img_upload: Image.Image):
    """Handle one user turn and advance the agent's workflow state.

    This is a generator used as a Gradio event handler. Every ``yield`` is a
    4-tuple matching the event outputs
    ``(chatbot, agent_state, img_upload, txt_msg)``.

    The turn is routed one of three ways:

    * An image is present and the message is "start" (any case, surrounding
      whitespace ignored) or empty: the state is reset, the image is stored
      in it, and ``diagnosis_graph`` is run.
    * The current state has ``symptoms_to_check`` and no ``final_diagnosis``:
      ``reply_graph`` is run with the message as the user's answer.
    * Otherwise: a greeting asking for an image is added to the chat and the
      state is returned unchanged.

    Args:
        message: Text typed by the user ("start" when sent by the button).
            ``None`` is treated like an empty string.
        chat_history: Chatbot rows as ``[user, assistant]`` pairs; ``None``
            is treated like an empty history.
        agent_state: State kept between turns; empty, or one that already
            carries a ``final_diagnosis``, means a fresh state is used.
        img_upload: The uploaded image, or ``None`` when nothing is uploaded.

    Yields:
        ``(chat_history, new_state, image_update, textbox_update)``. When the
        run ends with a ``final_diagnosis`` or with an error, the state is
        reset to ``{}`` and the image is cleared and made editable again.
    """
    # Normalise the history before any branch touches it, so the
    # initialization-error branch below can append safely as well.
    chat_history = chat_history or []

    if initialization_error:
        chat_history.append((message, initialization_error))
        yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
        return

    # A missing message behaves like an empty one instead of crashing on
    # .lower() below.
    message = message or ""

    if not agent_state or agent_state.get("final_diagnosis"):
        current_state = reset_state_on_start()
    else:
        current_state = agent_state

    starts_new_diagnosis = img_upload is not None and (
        message.lower().strip() == "start" or message == ""
    )

    if starts_new_diagnosis:
        print("--- Running NEW diagnosis flow ---")
        current_state = reset_state_on_start()
        current_state["image"] = img_upload
        graph_to_run = diagnosis_graph
        # NOTE(review): the row carries the PIL image object itself inside a
        # tuple; confirm the installed Gradio Chatbot accepts that rather
        # than requiring a file path.
        chat_history.append([(img_upload,), None])
    elif current_state.get("symptoms_to_check") and not current_state.get("final_diagnosis"):
        print("--- Running REPLY symptom loop flow ---")
        graph_to_run = reply_graph
        chat_history.append([message, None])
    else:
        # Always add a new row for the greeting. Writing into
        # chat_history[-1] without appending first raised IndexError on an
        # empty history and overwrote the previous reply on a non-empty one
        # whenever the message was empty.
        greeting = "Hello! Please upload an image, then click 'Start Diagnosis'."
        chat_history.append([message or None, greeting])
        yield chat_history, agent_state, gr.update(value=img_upload, interactive=True), gr.update(value="", interactive=True)
        return

    current_state["chat_history"] = convert_history_to_langchain(chat_history)

    try:
        final_state = {}
        # Keep only the payload of the last streamed step.
        # NOTE(review): this assumes each streamed step's first value holds
        # a usable state including 'chat_history' — confirm against the
        # stream mode used by the graphs.
        for step in graph_to_run.stream(current_state, {"recursion_limit": 100}):
            final_state = list(step.values())[0]

        ai_response = final_state['chat_history'][-1].content
        chat_history[-1][1] = ai_response

        if final_state.get("final_diagnosis"):
            print("--- Agent Flow ENDED ---")
            yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
        else:
            # Mid-conversation: keep the image and lock it until the flow ends.
            yield chat_history, final_state, gr.update(value=img_upload, interactive=False), gr.update(value="", interactive=True)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the chat and reset
        # the session instead of letting the event handler crash.
        print(f"--- Graph Runtime Error --- \n{e}")
        error_msg = f"A runtime error occurred: {e}. Please check the console."
        chat_history[-1][1] = error_msg
        yield chat_history, {}, gr.update(value=None, interactive=True), gr.update(value="", interactive=True)
def clear_all():
    """Return reset values for the chatbot, agent state, image and textbox.

    The tuple order matches the outputs wired to the clear button:
    empty chat rows, empty state, no image, empty text.
    """
    empty_chat, empty_state = [], {}
    no_image, empty_text = None, ""
    return empty_chat, empty_state, no_image, empty_text
# UI layout and event wiring.
# NOTE(review): the nesting below is reconstructed in the conventional way
# (controls in the left column, conversation in the right column, event
# bindings at Blocks level) — confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Agentic Skin AI") as demo:
    gr.Markdown("# 🩺 Multimodal Agentic Skin Disease AI")
    gr.Markdown(
        "**Disclaimer:** This is a demo project and NOT a medical device. "
        "Consult a real doctor for any medical concerns."
    )

    # State carried between turns; an empty dict means "no active diagnosis".
    agent_state = gr.State({})

    with gr.Row():
        # Left column: image input and the control buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Image Input")
            img_upload = gr.Image(type="pil", label="Upload Skin Image", interactive=True)
            btn_start = gr.Button("Start Diagnosis", variant="primary")
            btn_clear = gr.Button("Clear All & Start New")
            gr.Markdown(
                "**Instructions:**\n"
                "1. Upload an image.\n"
                "2. Click **Start Diagnosis**.\n"
                "3. Answer the agent's questions in the textbox."
            )

        # Right column: the conversation and the reply box.
        with gr.Column(scale=2):
            gr.Markdown("### 2. Agent Conversation")
            chatbot = gr.Chatbot(label="Agent Conversation", height=500, bubble_full_width=False, avatar_images=None)
            txt_msg = gr.Textbox(
                label="Your message (Yes / No / etc.)",
                placeholder="Answer the agent's questions here...",
                interactive=True
            )

    # The chat and clear handlers all write to the same four components.
    turn_outputs = [chatbot, agent_state, img_upload, txt_msg]

    # Pressing Enter in the textbox sends the typed message.
    txt_msg.submit(
        fn=chat_fn,
        inputs=[txt_msg, chatbot, agent_state, img_upload],
        outputs=turn_outputs
    )

    # The start button reuses chat_fn; a hidden textbox supplies the fixed
    # message "start" in place of user text.
    start_signal = gr.Textbox(value="start", visible=False)
    btn_start.click(
        fn=chat_fn,
        inputs=[start_signal, chatbot, agent_state, img_upload],
        outputs=turn_outputs
    )

    btn_clear.click(
        fn=clear_all,
        inputs=None,
        outputs=turn_outputs
    )

    # A new upload wipes the conversation and state and puts a hint in the
    # textbox.
    img_upload.upload(
        fn=lambda: ([], {}, "Click 'Start Diagnosis' to begin."),
        inputs=None,
        outputs=[chatbot, agent_state, txt_msg]
    )
# Script entry point: only launch the UI when the graphs were built.
if __name__ == "__main__":
    if not initialization_error:
        demo.launch(debug=True)
    else:
        # Graph construction failed at import time; report and do not launch.
        print("\n\n*** CANNOT LAUNCH APP: Agent failed to initialize. ***")
        print(f"*** ERROR: {initialization_error} ***")
|