Y Phung Nguyen committed on
Commit
c8562d7
·
1 Parent(s): 3b38a6c

Optimise extensive QA round follow-up

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. pipeline.py +49 -3
  3. supervisor.py +5 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .env
2
  .setup.txt
3
- __pycache__/
 
 
1
  .env
2
  .setup.txt
3
+ __pycache__/
4
+ sample.txt
pipeline.py CHANGED
@@ -5,6 +5,7 @@ import time
5
  import logging
6
  import threading
7
  import concurrent.futures
 
8
  import gradio as gr
9
  import spaces
10
  from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
@@ -118,6 +119,46 @@ def _build_refined_query(base_query: str, insights: dict, insights_block: str) -
118
  return "\n\n".join([section for section in sections if section])
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
122
  questions = plan.get("questions", []) or []
123
  if not questions:
@@ -135,7 +176,8 @@ def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str,
135
  "answers": [],
136
  "decision_reason": plan.get("decision_reason", ""),
137
  "initial_hypotheses": plan.get("initial_hypotheses", []),
138
- "started_at": time.time()
 
139
  }
140
  _set_clinical_intake_state(session_id, state)
141
  first_prompt = _format_intake_question(
@@ -144,6 +186,8 @@ def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str,
144
  max_rounds=max_rounds,
145
  target_lang=state["original_language"]
146
  )
 
 
147
  return first_prompt
148
 
149
 
@@ -193,6 +237,8 @@ def _handle_clinical_answer(session_id: str, answer_text: str):
193
  max_rounds=state["max_rounds"],
194
  target_lang=state["original_language"]
195
  )
 
 
196
  return {"type": "question", "prompt": prompt}
197
 
198
 
@@ -235,7 +281,7 @@ def stream_chat(
235
  def elapsed():
236
  return time.time() - session_start
237
 
238
- user_id = request.session_hash
239
  index_dir = f"./{user_id}_index"
240
  has_rag_index = os.path.exists(index_dir)
241
 
@@ -285,7 +331,7 @@ def stream_chat(
285
  if not enable_clinical_intake:
286
  _clear_clinical_intake_state(user_id)
287
  else:
288
- intake_state = _get_clinical_intake_state(user_id)
289
  if intake_state and intake_state.get("awaiting_answer"):
290
  logger.info("[INTAKE] Awaiting patient response - processing answer")
291
  intake_result = _handle_clinical_answer(user_id, message)
 
5
  import logging
6
  import threading
7
  import concurrent.futures
8
+ import hashlib
9
  import gradio as gr
10
  import spaces
11
  from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
 
119
  return "\n\n".join([section for section in sections if section])
120
 
121
 
122
+ def _hash_prompt_text(text: str) -> str:
123
+ if not text:
124
+ return ""
125
+ digest = hashlib.sha1()
126
+ digest.update(text.strip().encode("utf-8"))
127
+ return digest.hexdigest()
128
+
129
+
130
+ def _extract_pending_intake_prompt(history: list) -> str:
131
+ if not history:
132
+ return ""
133
+ for turn in reversed(history):
134
+ if turn.get("role") != "assistant":
135
+ continue
136
+ content = turn.get("content", "")
137
+ if content.startswith("🩺 Question for clarity"):
138
+ return content
139
+ return ""
140
+
141
+
142
+ def _rehydrate_intake_state(session_id: str, history: list):
143
+ state = _get_clinical_intake_state(session_id)
144
+ if state or not history:
145
+ return state
146
+ pending_prompt = _extract_pending_intake_prompt(history)
147
+ if not pending_prompt:
148
+ return None
149
+ prompt_hash = _hash_prompt_text(pending_prompt)
150
+ if not prompt_hash:
151
+ return None
152
+ with _clinical_intake_lock:
153
+ for existing_id, existing_state in list(_clinical_intake_sessions.items()):
154
+ if existing_state.get("awaiting_answer") and existing_state.get("last_prompt_hash") == prompt_hash:
155
+ if existing_id != session_id:
156
+ _clinical_intake_sessions.pop(existing_id, None)
157
+ _clinical_intake_sessions[session_id] = existing_state
158
+ return existing_state
159
+ return None
160
+
161
+
162
  def _start_clinical_intake_session(session_id: str, plan: dict, base_query: str, original_language: str):
163
  questions = plan.get("questions", []) or []
164
  if not questions:
 
176
  "answers": [],
177
  "decision_reason": plan.get("decision_reason", ""),
178
  "initial_hypotheses": plan.get("initial_hypotheses", []),
179
+ "started_at": time.time(),
180
+ "last_prompt_hash": ""
181
  }
182
  _set_clinical_intake_state(session_id, state)
183
  first_prompt = _format_intake_question(
 
186
  max_rounds=max_rounds,
187
  target_lang=state["original_language"]
188
  )
189
+ state["last_prompt_hash"] = _hash_prompt_text(first_prompt)
190
+ _set_clinical_intake_state(session_id, state)
191
  return first_prompt
192
 
193
 
 
237
  max_rounds=state["max_rounds"],
238
  target_lang=state["original_language"]
239
  )
240
+ state["last_prompt_hash"] = _hash_prompt_text(prompt)
241
+ _set_clinical_intake_state(session_id, state)
242
  return {"type": "question", "prompt": prompt}
243
 
244
 
 
281
  def elapsed():
282
  return time.time() - session_start
283
 
284
+ user_id = request.session_hash or "anonymous"
285
  index_dir = f"./{user_id}_index"
286
  has_rag_index = os.path.exists(index_dir)
287
 
 
331
  if not enable_clinical_intake:
332
  _clear_clinical_intake_state(user_id)
333
  else:
334
+ intake_state = _rehydrate_intake_state(user_id, history)
335
  if intake_state and intake_state.get("awaiting_answer"):
336
  logger.info("[INTAKE] Awaiting patient response - processing answer")
337
  intake_result = _handle_clinical_answer(user_id, message)
supervisor.py CHANGED
@@ -168,12 +168,17 @@ def _prepare_clinical_question_plan(plan: dict, safe_rounds: int) -> dict:
168
  if not isinstance(questions, list):
169
  questions = []
170
  cleaned = []
 
171
  for idx, raw in enumerate(questions):
172
  if not isinstance(raw, dict):
173
  continue
174
  question_text = (raw.get("question") or "").strip()
175
  if not question_text:
176
  continue
 
 
 
 
177
  entry = dict(raw)
178
  entry["question"] = question_text
179
  entry["order"] = entry.get("order") or raw.get("id") or (idx + 1)
 
168
  if not isinstance(questions, list):
169
  questions = []
170
  cleaned = []
171
+ seen = set()
172
  for idx, raw in enumerate(questions):
173
  if not isinstance(raw, dict):
174
  continue
175
  question_text = (raw.get("question") or "").strip()
176
  if not question_text:
177
  continue
178
+ normalized = question_text.lower()
179
+ if normalized in seen:
180
+ continue
181
+ seen.add(normalized)
182
  entry = dict(raw)
183
  entry["question"] = question_text
184
  entry["order"] = entry.get("order") or raw.get("id") or (idx + 1)