Spaces:
Running
Running
Merge pull request #11 from The-Obstacle-Is-The-Way/feat/phase7-hypothesis-agent
Browse files- src/agents/hypothesis_agent.py +144 -0
- src/orchestrator_magentic.py +152 -121
- src/prompts/hypothesis.py +68 -0
- src/utils/models.py +41 -0
- src/utils/text_utils.py +132 -0
- tests/unit/agents/test_hypothesis_agent.py +105 -0
- tests/unit/utils/test_text_utils.py +133 -0
src/agents/hypothesis_agent.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hypothesis agent for mechanistic reasoning."""
|
| 2 |
+
|
| 3 |
+
from collections.abc import AsyncIterable
|
| 4 |
+
from typing import TYPE_CHECKING, Any
|
| 5 |
+
|
| 6 |
+
from agent_framework import (
|
| 7 |
+
AgentRunResponse,
|
| 8 |
+
AgentRunResponseUpdate,
|
| 9 |
+
AgentThread,
|
| 10 |
+
BaseAgent,
|
| 11 |
+
ChatMessage,
|
| 12 |
+
Role,
|
| 13 |
+
)
|
| 14 |
+
from pydantic_ai import Agent
|
| 15 |
+
|
| 16 |
+
from src.agent_factory.judges import get_model
|
| 17 |
+
from src.prompts.hypothesis import SYSTEM_PROMPT, format_hypothesis_prompt
|
| 18 |
+
from src.utils.models import HypothesisAssessment
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from src.services.embeddings import EmbeddingService
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class HypothesisAgent(BaseAgent):  # type: ignore[misc]
    """Generates mechanistic hypotheses based on evidence.

    Participates in the Magentic workflow: reads shared evidence from
    ``evidence_store["current"]``, asks an LLM (via pydantic-ai) for
    structured ``HypothesisAssessment`` output, and appends the resulting
    hypotheses to ``evidence_store["hypotheses"]`` for other agents.
    """

    def __init__(
        self,
        evidence_store: dict[str, Any],
        embedding_service: "EmbeddingService | None" = None,  # NEW: for diverse selection
    ) -> None:
        """Create the agent.

        Args:
            evidence_store: Shared mutable store; this agent reads the
                "current" key and writes the "hypotheses" key.
            embedding_service: Optional service used for diverse (MMR)
                evidence selection when building the prompt.
        """
        super().__init__(
            name="HypothesisAgent",
            description="Generates scientific hypotheses about drug mechanisms to guide research",
        )
        self._evidence_store = evidence_store
        self._embeddings = embedding_service  # Used for MMR evidence selection
        self._agent: Agent[None, HypothesisAssessment] | None = None  # Lazy init

    def _get_agent(self) -> Agent[None, HypothesisAssessment]:
        """Lazy initialization of LLM agent to avoid requiring API keys at import."""
        if self._agent is None:
            self._agent = Agent(
                model=get_model(),  # Uses configured LLM (OpenAI/Anthropic)
                output_type=HypothesisAssessment,
                system_prompt=SYSTEM_PROMPT,
            )
        return self._agent

    async def run(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AgentRunResponse:
        """Generate hypotheses based on current evidence.

        Returns an assistant message summarizing the hypotheses; the full
        structured assessment is attached under
        ``additional_properties["assessment"]``.  If no evidence has been
        collected yet, returns an instructional message instead of calling
        the LLM.
        """
        # Extract query
        query = self._extract_query(messages)

        # Get current evidence
        evidence = self._evidence_store.get("current", [])

        if not evidence:
            # Guard: hypothesizing without evidence would just hallucinate.
            return AgentRunResponse(
                messages=[
                    ChatMessage(
                        role=Role.ASSISTANT,
                        text="No evidence available yet. Search for evidence first.",
                    )
                ],
                response_id="hypothesis-no-evidence",
            )

        # Generate hypotheses with diverse evidence selection
        prompt = await format_hypothesis_prompt(query, evidence, embeddings=self._embeddings)
        result = await self._get_agent().run(prompt)
        assessment = result.output  # pydantic-ai returns .output for structured output

        # Store hypotheses in shared context
        # (appended, not replaced, so earlier rounds are preserved)
        existing = self._evidence_store.get("hypotheses", [])
        self._evidence_store["hypotheses"] = existing + assessment.hypotheses

        # Format response
        response_text = self._format_response(assessment)

        return AgentRunResponse(
            messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
            response_id=f"hypothesis-{len(assessment.hypotheses)}",
            additional_properties={"assessment": assessment.model_dump()},
        )

    def _format_response(self, assessment: HypothesisAssessment) -> str:
        """Format hypothesis assessment as markdown.

        Sections: all hypotheses, then (if present) the primary hypothesis,
        knowledge gaps, and recommended next searches.
        """
        lines = ["## Generated Hypotheses\n"]

        for i, h in enumerate(assessment.hypotheses, 1):
            lines.append(f"### Hypothesis {i} (Confidence: {h.confidence:.0%})")
            lines.append(f"**Mechanism**: {h.drug} -> {h.target} -> {h.pathway} -> {h.effect}")
            lines.append(f"**Suggested searches**: {', '.join(h.search_suggestions)}\n")

        if assessment.primary_hypothesis:
            lines.append("### Primary Hypothesis")
            h = assessment.primary_hypothesis
            lines.append(f"{h.drug} -> {h.target} -> {h.pathway} -> {h.effect}\n")

        if assessment.knowledge_gaps:
            lines.append("### Knowledge Gaps")
            for gap in assessment.knowledge_gaps:
                lines.append(f"- {gap}")

        if assessment.recommended_searches:
            lines.append("\n### Recommended Next Searches")
            for search in assessment.recommended_searches:
                lines.append(f"- `{search}`")

        return "\n".join(lines)

    def _extract_query(
        self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None
    ) -> str:
        """Extract query from messages.

        Accepts the union of shapes the framework may pass; for a list,
        scans from newest to oldest and returns the first USER message
        (or the first plain string encountered).  Returns "" if nothing
        usable is found.
        """
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, ChatMessage):
            return messages.text or ""
        elif isinstance(messages, list):
            for msg in reversed(messages):
                if isinstance(msg, ChatMessage) and msg.role == Role.USER:
                    return msg.text or ""
                elif isinstance(msg, str):
                    return msg
        return ""

    async def run_stream(
        self,
        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
        *,
        thread: AgentThread | None = None,
        **kwargs: Any,
    ) -> AsyncIterable[AgentRunResponseUpdate]:
        """Streaming wrapper.

        This agent does not stream token-by-token: it runs ``run`` to
        completion and yields the whole response as a single update.
        """
        result = await self.run(messages, thread=thread, **kwargs)
        yield AgentRunResponseUpdate(messages=result.messages, response_id=result.response_id)
|
src/orchestrator_magentic.py
CHANGED
|
@@ -6,8 +6,13 @@ the agent_framework provides an AnthropicChatClient.
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from collections.abc import AsyncGenerator
|
|
|
|
| 9 |
|
| 10 |
import structlog
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from agent_framework import (
|
| 12 |
MagenticAgentDeltaEvent,
|
| 13 |
MagenticAgentMessageEvent,
|
|
@@ -18,6 +23,7 @@ from agent_framework import (
|
|
| 18 |
)
|
| 19 |
from agent_framework.openai import OpenAIChatClient
|
| 20 |
|
|
|
|
| 21 |
from src.agents.judge_agent import JudgeAgent
|
| 22 |
from src.agents.search_agent import SearchAgent
|
| 23 |
from src.orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol
|
|
@@ -28,6 +34,11 @@ from src.utils.models import AgentEvent, Evidence
|
|
| 28 |
logger = structlog.get_logger()
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
class MagenticOrchestrator:
|
| 32 |
"""
|
| 33 |
Magentic-based orchestrator - same API as Orchestrator.
|
|
@@ -51,50 +62,38 @@ class MagenticOrchestrator:
|
|
| 51 |
self._max_rounds = max_rounds
|
| 52 |
self._evidence_store: dict[str, list[Evidence]] = {"current": []}
|
| 53 |
|
| 54 |
-
|
| 55 |
-
"""
|
| 56 |
-
Run the Magentic workflow - same API as simple Orchestrator.
|
| 57 |
-
|
| 58 |
-
Yields AgentEvent objects for real-time UI updates.
|
| 59 |
-
"""
|
| 60 |
-
logger.info("Starting Magentic orchestrator", query=query)
|
| 61 |
-
|
| 62 |
-
yield AgentEvent(
|
| 63 |
-
type="started",
|
| 64 |
-
message=f"Starting research (Magentic mode): {query}",
|
| 65 |
-
iteration=0,
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
# Initialize embedding service (optional)
|
| 69 |
-
embedding_service = None
|
| 70 |
try:
|
| 71 |
from src.services.embeddings import get_embedding_service
|
| 72 |
|
| 73 |
-
|
| 74 |
logger.info("Embedding service enabled")
|
|
|
|
| 75 |
except ImportError:
|
| 76 |
logger.info("Embedding service not available (dependencies missing)")
|
| 77 |
except Exception as e:
|
| 78 |
logger.warning("Failed to initialize embedding service", error=str(e))
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
judge_agent
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# Note: MagenticBuilder requires OpenAI - validate key exists
|
| 88 |
if not settings.openai_api_key:
|
| 89 |
raise ConfigurationError(
|
| 90 |
"Magentic mode requires OPENAI_API_KEY. "
|
| 91 |
"Set the key or use mode='simple' with Anthropic."
|
| 92 |
)
|
| 93 |
|
| 94 |
-
|
| 95 |
MagenticBuilder()
|
| 96 |
.participants(
|
| 97 |
searcher=search_agent,
|
|
|
|
| 98 |
judge=judge_agent,
|
| 99 |
)
|
| 100 |
.with_standard_manager(
|
|
@@ -108,114 +107,67 @@ class MagenticOrchestrator:
|
|
| 108 |
.build()
|
| 109 |
)
|
| 110 |
|
| 111 |
-
|
|
|
|
| 112 |
semantic_note = ""
|
| 113 |
-
if
|
| 114 |
semantic_note = """
|
| 115 |
The system has semantic search enabled. When evidence is found:
|
| 116 |
1. Related concepts will be automatically surfaced
|
| 117 |
2. Duplicates are removed by meaning, not just URL
|
| 118 |
3. Use the surfaced related concepts to refine searches
|
| 119 |
"""
|
| 120 |
-
|
| 121 |
-
task = f"""Research drug repurposing opportunities for: {query}
|
| 122 |
{semantic_note}
|
| 123 |
-
|
| 124 |
-
1.
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
-
|
| 133 |
-
-
|
| 134 |
-
- Specific drug candidates
|
| 135 |
"""
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
iteration = 0
|
| 138 |
try:
|
| 139 |
-
# workflow.run_stream returns an async generator of workflow events
|
| 140 |
-
# We use 'await' in the for loop for async generator
|
| 141 |
async for event in workflow.run_stream(task):
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
event.message.text
|
| 148 |
-
if event.message and hasattr(event.message, "text")
|
| 149 |
-
else ""
|
| 150 |
-
)
|
| 151 |
-
# kind might be 'plan', 'instruction', etc.
|
| 152 |
-
kind = getattr(event, "kind", "manager")
|
| 153 |
-
|
| 154 |
-
if message_text:
|
| 155 |
-
yield AgentEvent(
|
| 156 |
-
type="judging",
|
| 157 |
-
message=f"Manager ({kind}): {message_text[:100]}...",
|
| 158 |
-
iteration=iteration,
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
elif isinstance(event, MagenticAgentMessageEvent):
|
| 162 |
-
# Complete agent response
|
| 163 |
-
iteration += 1
|
| 164 |
-
agent_name = event.agent_id or "unknown"
|
| 165 |
-
msg_text = (
|
| 166 |
-
event.message.text
|
| 167 |
-
if event.message and hasattr(event.message, "text")
|
| 168 |
-
else ""
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
if "search" in agent_name.lower():
|
| 172 |
-
# Check if we found evidence (based on SearchAgent logic)
|
| 173 |
-
yield AgentEvent(
|
| 174 |
-
type="search_complete",
|
| 175 |
-
message=f"Search agent: {msg_text[:100]}...",
|
| 176 |
-
iteration=iteration,
|
| 177 |
-
)
|
| 178 |
-
elif "judge" in agent_name.lower():
|
| 179 |
-
yield AgentEvent(
|
| 180 |
-
type="judge_complete",
|
| 181 |
-
message=f"Judge agent: {msg_text[:100]}...",
|
| 182 |
-
iteration=iteration,
|
| 183 |
-
)
|
| 184 |
-
|
| 185 |
-
elif isinstance(event, MagenticFinalResultEvent):
|
| 186 |
-
# Final workflow result
|
| 187 |
-
final_text = (
|
| 188 |
-
event.message.text
|
| 189 |
-
if event.message and hasattr(event.message, "text")
|
| 190 |
-
else "No result"
|
| 191 |
-
)
|
| 192 |
-
yield AgentEvent(
|
| 193 |
-
type="complete",
|
| 194 |
-
message=final_text,
|
| 195 |
-
data={"iterations": iteration},
|
| 196 |
-
iteration=iteration,
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
elif isinstance(event, MagenticAgentDeltaEvent):
|
| 200 |
-
# Streaming token chunks from agents (optional "typing" effect)
|
| 201 |
-
# Only emit if we have actual text content
|
| 202 |
-
if event.text:
|
| 203 |
-
yield AgentEvent(
|
| 204 |
-
type="streaming",
|
| 205 |
-
message=event.text,
|
| 206 |
-
data={"agent_id": event.agent_id},
|
| 207 |
-
iteration=iteration,
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
elif isinstance(event, WorkflowOutputEvent):
|
| 211 |
-
# Alternative final output event
|
| 212 |
-
if event.data:
|
| 213 |
-
yield AgentEvent(
|
| 214 |
-
type="complete",
|
| 215 |
-
message=str(event.data),
|
| 216 |
-
iteration=iteration,
|
| 217 |
-
)
|
| 218 |
-
|
| 219 |
except Exception as e:
|
| 220 |
logger.error("Magentic workflow failed", error=str(e))
|
| 221 |
yield AgentEvent(
|
|
@@ -223,3 +175,82 @@ Focus on finding:
|
|
| 223 |
message=f"Workflow error: {e!s}",
|
| 224 |
iteration=iteration,
|
| 225 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
from collections.abc import AsyncGenerator
|
| 9 |
+
from typing import TYPE_CHECKING, Any
|
| 10 |
|
| 11 |
import structlog
|
| 12 |
+
|
| 13 |
+
if TYPE_CHECKING:
|
| 14 |
+
from src.services.embeddings import EmbeddingService
|
| 15 |
+
|
| 16 |
from agent_framework import (
|
| 17 |
MagenticAgentDeltaEvent,
|
| 18 |
MagenticAgentMessageEvent,
|
|
|
|
| 23 |
)
|
| 24 |
from agent_framework.openai import OpenAIChatClient
|
| 25 |
|
| 26 |
+
from src.agents.hypothesis_agent import HypothesisAgent
|
| 27 |
from src.agents.judge_agent import JudgeAgent
|
| 28 |
from src.agents.search_agent import SearchAgent
|
| 29 |
from src.orchestrator import JudgeHandlerProtocol, SearchHandlerProtocol
|
|
|
|
| 34 |
logger = structlog.get_logger()
|
| 35 |
|
| 36 |
|
| 37 |
+
def _truncate(text: str, max_len: int = 100) -> str:
|
| 38 |
+
"""Truncate text with ellipsis only if needed."""
|
| 39 |
+
return f"{text[:max_len]}..." if len(text) > max_len else text
|
| 40 |
+
|
| 41 |
+
|
| 42 |
class MagenticOrchestrator:
|
| 43 |
"""
|
| 44 |
Magentic-based orchestrator - same API as Orchestrator.
|
|
|
|
| 62 |
self._max_rounds = max_rounds
|
| 63 |
self._evidence_store: dict[str, list[Evidence]] = {"current": []}
|
| 64 |
|
| 65 |
+
    def _init_embedding_service(self) -> "EmbeddingService | None":
        """Initialize embedding service if available.

        Returns the embedding service, or None when the optional embedding
        dependencies are missing or initialization fails.  Failure is
        deliberately non-fatal: the orchestrator degrades to non-semantic
        search rather than crashing at startup.
        """
        try:
            # Imported lazily so the orchestrator still works when the
            # optional embedding dependencies are not installed.
            from src.services.embeddings import get_embedding_service

            service = get_embedding_service()
            logger.info("Embedding service enabled")
            return service
        except ImportError:
            # Optional dependency absent -- run without semantic features.
            logger.info("Embedding service not available (dependencies missing)")
        except Exception as e:
            # Any other startup error (e.g. model load) is logged and swallowed.
            logger.warning("Failed to initialize embedding service", error=str(e))
        return None
|
| 78 |
|
| 79 |
+
def _build_workflow(
|
| 80 |
+
self,
|
| 81 |
+
search_agent: SearchAgent,
|
| 82 |
+
hypothesis_agent: HypothesisAgent,
|
| 83 |
+
judge_agent: JudgeAgent,
|
| 84 |
+
) -> Any:
|
| 85 |
+
"""Build the Magentic workflow with participants."""
|
|
|
|
| 86 |
if not settings.openai_api_key:
|
| 87 |
raise ConfigurationError(
|
| 88 |
"Magentic mode requires OPENAI_API_KEY. "
|
| 89 |
"Set the key or use mode='simple' with Anthropic."
|
| 90 |
)
|
| 91 |
|
| 92 |
+
return (
|
| 93 |
MagenticBuilder()
|
| 94 |
.participants(
|
| 95 |
searcher=search_agent,
|
| 96 |
+
hypothesizer=hypothesis_agent,
|
| 97 |
judge=judge_agent,
|
| 98 |
)
|
| 99 |
.with_standard_manager(
|
|
|
|
| 107 |
.build()
|
| 108 |
)
|
| 109 |
|
| 110 |
+
def _format_task(self, query: str, has_embeddings: bool) -> str:
|
| 111 |
+
"""Format the task instruction for the manager."""
|
| 112 |
semantic_note = ""
|
| 113 |
+
if has_embeddings:
|
| 114 |
semantic_note = """
|
| 115 |
The system has semantic search enabled. When evidence is found:
|
| 116 |
1. Related concepts will be automatically surfaced
|
| 117 |
2. Duplicates are removed by meaning, not just URL
|
| 118 |
3. Use the surfaced related concepts to refine searches
|
| 119 |
"""
|
| 120 |
+
return f"""Research drug repurposing opportunities for: {query}
|
|
|
|
| 121 |
{semantic_note}
|
| 122 |
+
Workflow:
|
| 123 |
+
1. SearcherAgent: Find initial evidence from PubMed and web. SEND ONLY A SIMPLE KEYWORD QUERY.
|
| 124 |
+
2. HypothesisAgent: Generate mechanistic hypotheses (Drug -> Target -> Pathway -> Effect).
|
| 125 |
+
3. SearcherAgent: Use hypothesis-suggested queries for targeted search.
|
| 126 |
+
4. JudgeAgent: Evaluate if evidence supports hypotheses.
|
| 127 |
+
5. Repeat until confident or max rounds.
|
| 128 |
+
|
| 129 |
+
Focus on:
|
| 130 |
+
- Identifying specific molecular targets
|
| 131 |
+
- Understanding mechanism of action
|
| 132 |
+
- Finding supporting/contradicting evidence for hypotheses
|
|
|
|
| 133 |
"""
|
| 134 |
|
| 135 |
+
async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
|
| 136 |
+
"""
|
| 137 |
+
Run the Magentic workflow - same API as simple Orchestrator.
|
| 138 |
+
|
| 139 |
+
Yields AgentEvent objects for real-time UI updates.
|
| 140 |
+
"""
|
| 141 |
+
logger.info("Starting Magentic orchestrator", query=query)
|
| 142 |
+
|
| 143 |
+
yield AgentEvent(
|
| 144 |
+
type="started",
|
| 145 |
+
message=f"Starting research (Magentic mode): {query}",
|
| 146 |
+
iteration=0,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Initialize services and agents
|
| 150 |
+
embedding_service = self._init_embedding_service()
|
| 151 |
+
search_agent = SearchAgent(
|
| 152 |
+
self._search_handler, self._evidence_store, embedding_service=embedding_service
|
| 153 |
+
)
|
| 154 |
+
judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
|
| 155 |
+
hypothesis_agent = HypothesisAgent(
|
| 156 |
+
self._evidence_store, embedding_service=embedding_service
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Build workflow and task
|
| 160 |
+
workflow = self._build_workflow(search_agent, hypothesis_agent, judge_agent)
|
| 161 |
+
task = self._format_task(query, embedding_service is not None)
|
| 162 |
+
|
| 163 |
iteration = 0
|
| 164 |
try:
|
|
|
|
|
|
|
| 165 |
async for event in workflow.run_stream(task):
|
| 166 |
+
agent_event = self._process_event(event, iteration)
|
| 167 |
+
if agent_event:
|
| 168 |
+
if isinstance(event, MagenticAgentMessageEvent):
|
| 169 |
+
iteration += 1
|
| 170 |
+
yield agent_event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
logger.error("Magentic workflow failed", error=str(e))
|
| 173 |
yield AgentEvent(
|
|
|
|
| 175 |
message=f"Workflow error: {e!s}",
|
| 176 |
iteration=iteration,
|
| 177 |
)
|
| 178 |
+
|
| 179 |
+
def _process_event(self, event: Any, iteration: int) -> AgentEvent | None:
|
| 180 |
+
"""Process a workflow event and return an AgentEvent if applicable."""
|
| 181 |
+
if isinstance(event, MagenticOrchestratorMessageEvent):
|
| 182 |
+
message_text = (
|
| 183 |
+
event.message.text if event.message and hasattr(event.message, "text") else ""
|
| 184 |
+
)
|
| 185 |
+
kind = getattr(event, "kind", "manager")
|
| 186 |
+
if message_text:
|
| 187 |
+
return AgentEvent(
|
| 188 |
+
type="judging",
|
| 189 |
+
message=f"Manager ({kind}): {_truncate(message_text)}",
|
| 190 |
+
iteration=iteration,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
elif isinstance(event, MagenticAgentMessageEvent):
|
| 194 |
+
agent_name = event.agent_id or "unknown"
|
| 195 |
+
msg_text = (
|
| 196 |
+
event.message.text if event.message and hasattr(event.message, "text") else ""
|
| 197 |
+
)
|
| 198 |
+
return self._agent_message_event(agent_name, msg_text, iteration + 1)
|
| 199 |
+
|
| 200 |
+
elif isinstance(event, MagenticFinalResultEvent):
|
| 201 |
+
final_text = (
|
| 202 |
+
event.message.text
|
| 203 |
+
if event.message and hasattr(event.message, "text")
|
| 204 |
+
else "No result"
|
| 205 |
+
)
|
| 206 |
+
return AgentEvent(
|
| 207 |
+
type="complete",
|
| 208 |
+
message=final_text,
|
| 209 |
+
data={"iterations": iteration},
|
| 210 |
+
iteration=iteration,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
elif isinstance(event, MagenticAgentDeltaEvent):
|
| 214 |
+
if event.text:
|
| 215 |
+
return AgentEvent(
|
| 216 |
+
type="streaming",
|
| 217 |
+
message=event.text,
|
| 218 |
+
data={"agent_id": event.agent_id},
|
| 219 |
+
iteration=iteration,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
elif isinstance(event, WorkflowOutputEvent):
|
| 223 |
+
if event.data:
|
| 224 |
+
return AgentEvent(
|
| 225 |
+
type="complete",
|
| 226 |
+
message=str(event.data),
|
| 227 |
+
iteration=iteration,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
return None
|
| 231 |
+
|
| 232 |
+
def _agent_message_event(self, agent_name: str, msg_text: str, iteration: int) -> AgentEvent:
|
| 233 |
+
"""Create an AgentEvent for an agent message."""
|
| 234 |
+
if "search" in agent_name.lower():
|
| 235 |
+
return AgentEvent(
|
| 236 |
+
type="search_complete",
|
| 237 |
+
message=f"Search agent: {_truncate(msg_text)}",
|
| 238 |
+
iteration=iteration,
|
| 239 |
+
)
|
| 240 |
+
elif "hypothes" in agent_name.lower():
|
| 241 |
+
return AgentEvent(
|
| 242 |
+
type="hypothesizing",
|
| 243 |
+
message=f"Hypothesis agent: {_truncate(msg_text)}",
|
| 244 |
+
iteration=iteration,
|
| 245 |
+
)
|
| 246 |
+
elif "judge" in agent_name.lower():
|
| 247 |
+
return AgentEvent(
|
| 248 |
+
type="judge_complete",
|
| 249 |
+
message=f"Judge agent: {_truncate(msg_text)}",
|
| 250 |
+
iteration=iteration,
|
| 251 |
+
)
|
| 252 |
+
return AgentEvent(
|
| 253 |
+
type="judging",
|
| 254 |
+
message=f"{agent_name}: {_truncate(msg_text)}",
|
| 255 |
+
iteration=iteration,
|
| 256 |
+
)
|
src/prompts/hypothesis.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prompts for Hypothesis Agent."""
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from src.services.embeddings import EmbeddingService
|
| 9 |
+
from src.utils.models import Evidence
|
| 10 |
+
|
| 11 |
+
SYSTEM_PROMPT = """You are a biomedical research scientist specializing in drug repurposing.
|
| 12 |
+
|
| 13 |
+
Your role is to generate mechanistic hypotheses based on evidence.
|
| 14 |
+
|
| 15 |
+
A good hypothesis:
|
| 16 |
+
1. Proposes a MECHANISM: Drug -> Target -> Pathway -> Effect
|
| 17 |
+
2. Is TESTABLE: Can be supported or refuted by literature search
|
| 18 |
+
3. Is SPECIFIC: Names actual molecular targets and pathways
|
| 19 |
+
4. Generates SEARCH QUERIES: Helps find more evidence
|
| 20 |
+
|
| 21 |
+
Example hypothesis format:
|
| 22 |
+
- Drug: Metformin
|
| 23 |
+
- Target: AMPK (AMP-activated protein kinase)
|
| 24 |
+
- Pathway: mTOR inhibition -> autophagy activation
|
| 25 |
+
- Effect: Enhanced clearance of amyloid-beta in Alzheimer's
|
| 26 |
+
- Confidence: 0.7
|
| 27 |
+
- Search suggestions: ["metformin AMPK brain", "autophagy amyloid clearance"]
|
| 28 |
+
|
| 29 |
+
Be specific. Use actual gene/protein names when possible."""
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def format_hypothesis_prompt(
    query: str, evidence: list["Evidence"], embeddings: "EmbeddingService | None" = None
) -> str:
    """Build the LLM prompt for hypothesis generation.

    Evidence is chosen by diversity-aware selection (MMR when an embedding
    service is supplied, relevance ordering otherwise) instead of simply
    taking the first N items, and each abstract is shortened at a sentence
    boundary so the prompt stays readable.

    Args:
        query: The research query
        evidence: All collected evidence
        embeddings: Optional EmbeddingService for diverse selection
    """
    # n=10 keeps the evidence section within a reasonable context budget.
    chosen = await select_diverse_evidence(evidence, n=10, query=query, embeddings=embeddings)

    # One markdown bullet per paper: bold title, source, sentence-truncated body.
    bullet_lines = []
    for item in chosen:
        summary = truncate_at_sentence(item.content, 300)
        bullet_lines.append(f"- **{item.citation.title}** ({item.citation.source}): {summary}")
    evidence_text = "\n".join(bullet_lines)

    return f"""Based on the following evidence about "{query}", generate mechanistic hypotheses.

## Evidence ({len(chosen)} papers selected for diversity)
{evidence_text}

## Task
1. Identify potential drug targets mentioned in the evidence
2. Propose mechanism hypotheses (Drug -> Target -> Pathway -> Effect)
3. Rate confidence based on evidence strength
4. Suggest searches to test each hypothesis

Generate 2-4 hypotheses, prioritized by confidence."""
|
src/utils/models.py
CHANGED
|
@@ -107,6 +107,7 @@ class AgentEvent(BaseModel):
|
|
| 107 |
"complete",
|
| 108 |
"error",
|
| 109 |
"streaming",
|
|
|
|
| 110 |
]
|
| 111 |
message: str
|
| 112 |
data: Any = None
|
|
@@ -126,11 +127,51 @@ class AgentEvent(BaseModel):
|
|
| 126 |
"complete": "🎉",
|
| 127 |
"error": "❌",
|
| 128 |
"streaming": "📡",
|
|
|
|
| 129 |
}
|
| 130 |
icon = icons.get(self.type, "•")
|
| 131 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
| 132 |
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
class OrchestratorConfig(BaseModel):
|
| 135 |
"""Configuration for the orchestrator."""
|
| 136 |
|
|
|
|
| 107 |
"complete",
|
| 108 |
"error",
|
| 109 |
"streaming",
|
| 110 |
+
"hypothesizing", # NEW for Phase 7
|
| 111 |
]
|
| 112 |
message: str
|
| 113 |
data: Any = None
|
|
|
|
| 127 |
"complete": "🎉",
|
| 128 |
"error": "❌",
|
| 129 |
"streaming": "📡",
|
| 130 |
+
"hypothesizing": "🔬", # NEW
|
| 131 |
}
|
| 132 |
icon = icons.get(self.type, "•")
|
| 133 |
return f"{icon} **{self.type.upper()}**: {self.message}"
|
| 134 |
|
| 135 |
|
| 136 |
+
class MechanismHypothesis(BaseModel):
    """A scientific hypothesis about drug mechanism.

    Encodes a causal chain Drug -> Target -> Pathway -> Effect with a
    confidence score and the evidence/search bookkeeping around it.
    """

    drug: str = Field(description="The drug being studied")
    target: str = Field(description="Molecular target (e.g., AMPK, mTOR)")
    pathway: str = Field(description="Biological pathway affected")
    effect: str = Field(description="Downstream effect on disease")
    confidence: float = Field(ge=0, le=1, description="Confidence in hypothesis")
    supporting_evidence: list[str] = Field(
        default_factory=list, description="PMIDs or URLs supporting this hypothesis"
    )
    contradicting_evidence: list[str] = Field(
        default_factory=list, description="PMIDs or URLs contradicting this hypothesis"
    )
    search_suggestions: list[str] = Field(
        default_factory=list, description="Suggested searches to test this hypothesis"
    )

    def to_search_queries(self) -> list[str]:
        """Generate search queries to test this hypothesis.

        Returns pairwise links of the mechanism chain followed by the
        model's own suggested searches.
        """
        queries = [
            f"{self.drug} {self.target}",
            f"{self.target} {self.pathway}",
            f"{self.pathway} {self.effect}",
        ]
        queries.extend(self.search_suggestions)
        return queries
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class HypothesisAssessment(BaseModel):
    """Assessment of evidence against hypotheses.

    This is the structured output type of the hypothesis agent's LLM call.
    """

    # All hypotheses generated this round; the prompt asks the model to
    # order them by confidence.
    hypotheses: list[MechanismHypothesis]
    primary_hypothesis: MechanismHypothesis | None = Field(
        default=None, description="Most promising hypothesis based on current evidence"
    )
    knowledge_gaps: list[str] = Field(description="What we don't know yet")
    recommended_searches: list[str] = Field(description="Searches to fill knowledge gaps")
|
| 173 |
+
|
| 174 |
+
|
| 175 |
class OrchestratorConfig(BaseModel):
|
| 176 |
"""Configuration for the orchestrator."""
|
| 177 |
|
src/utils/text_utils.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text processing utilities for evidence handling."""
|
| 2 |
+
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from src.services.embeddings import EmbeddingService
|
| 9 |
+
from src.utils.models import Evidence
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def truncate_at_sentence(text: str, max_chars: int = 300) -> str:
    """Shorten ``text`` to roughly ``max_chars`` characters at a natural break.

    Preference order: a real sentence terminator, then any period, then a word
    boundary (with a trailing ellipsis). Boundaries in the first half of the
    budget are ignored so the excerpt is not cut too aggressively.

    Args:
        text: The text to truncate.
        max_chars: Soft character budget (default 300).

    Returns:
        The original text when it fits, otherwise a readable prefix.
    """
    if len(text) <= max_chars:
        return text

    window = text[:max_chars]
    halfway = max_chars // 2

    # Prefer a sentence terminator, checked in fixed priority order. The
    # separator must sit at the end of a sentence inside the window.
    for terminator in (". ", "! ", "? ", ".\n", "!\n", "?\n"):
        cut = window.rfind(terminator)
        if cut > halfway:  # skip boundaries that would drop too much text
            return text[: cut + 1].strip()

    # Next best: any period at all (e.g. one at the very end of the window).
    cut = window.rfind(".")
    if cut > halfway:
        return text[: cut + 1].strip()

    # Otherwise break on the last word boundary and mark the truncation.
    cut = window.rfind(" ")
    if cut > 0:
        return text[:cut].strip() + "..."

    return window + "..."
| 46 |
+
|
| 47 |
+
|
| 48 |
+
async def select_diverse_evidence(
    evidence: list["Evidence"], n: int, query: str, embeddings: "EmbeddingService | None" = None
) -> list["Evidence"]:
    """Pick the ``n`` evidence items that best balance relevance and diversity.

    With an embedding service available, selection runs Maximal Marginal
    Relevance (MMR); without one, it falls back to sorting by the stored
    keyword relevance score.

    Args:
        evidence: All available evidence.
        n: Number of items to select.
        query: Original query used for semantic relevance scoring.
        embeddings: Optional EmbeddingService enabling semantic diversity.

    Returns:
        The selected evidence items, relevant yet mutually diverse.
    """
    if not evidence:
        return []

    if n >= len(evidence):
        return evidence

    if embeddings is None:
        # No semantic signal: rank purely by the model's relevance field.
        ranked = sorted(evidence, key=lambda item: item.relevance, reverse=True)
        return ranked[:n]

    # MMR trade-off: higher weight favors query relevance, lower favors diversity.
    tradeoff = 0.7

    query_vec = await embeddings.embed(query)
    item_vecs = await embeddings.embed_batch([item.content for item in evidence])

    def _cosine(u: list[float], v: list[float]) -> float:
        """Cosine similarity, returning 0.0 for zero-norm vectors."""
        vec_u, vec_v = np.array(u), np.array(v)
        norm_product = float(np.linalg.norm(vec_u) * np.linalg.norm(vec_v))
        if norm_product == 0:
            return 0.0
        return float(np.dot(vec_u, vec_v) / norm_product)

    # Semantic similarity of each item to the query (not the keyword score).
    query_sims = [_cosine(query_vec, vec) for vec in item_vecs]

    # Greedy MMR: score = tradeoff * relevance - (1 - tradeoff) * redundancy.
    chosen: list[int] = []
    candidates = set(range(len(evidence)))

    for _ in range(n):
        top_idx = -1
        top_score = float("-inf")

        for cand in candidates:
            # Redundancy: strongest similarity to anything already chosen.
            if chosen:
                redundancy = max(_cosine(item_vecs[cand], item_vecs[pick]) for pick in chosen)
            else:
                redundancy = 0

            score = tradeoff * query_sims[cand] - (1 - tradeoff) * redundancy
            if score > top_score:
                top_score = score
                top_idx = cand

        if top_idx >= 0:
            chosen.append(top_idx)
            candidates.remove(top_idx)

    return [evidence[i] for i in chosen]
|
tests/unit/agents/test_hypothesis_agent.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for HypothesisAgent."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from agent_framework import AgentRunResponse
|
| 7 |
+
|
| 8 |
+
from src.agents.hypothesis_agent import HypothesisAgent
|
| 9 |
+
from src.utils.models import Citation, Evidence, HypothesisAssessment, MechanismHypothesis
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture
def sample_evidence():
    """Single-item evidence list describing the metformin/AMPK mechanism."""
    pubmed_citation = Citation(
        source="pubmed",
        title="Metformin and AMPK",
        url="https://pubmed.ncbi.nlm.nih.gov/12345/",
        date="2023",
    )
    return [
        Evidence(
            content="Metformin activates AMPK, which inhibits mTOR signaling...",
            citation=pubmed_citation,
        )
    ]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture
def mock_assessment():
    """Canned assessment with one metformin hypothesis and a known gap."""
    metformin_hypothesis = MechanismHypothesis(
        drug="Metformin",
        target="AMPK",
        pathway="mTOR inhibition",
        effect="Reduced cancer cell proliferation",
        confidence=0.75,
        search_suggestions=["metformin AMPK cancer", "mTOR cancer therapy"],
    )
    return HypothesisAssessment(
        hypotheses=[metformin_hypothesis],
        primary_hypothesis=None,
        knowledge_gaps=["Clinical trial data needed"],
        recommended_searches=["metformin clinical trial cancer"],
    )
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@pytest.mark.asyncio
async def test_hypothesis_agent_generates_hypotheses(sample_evidence, mock_assessment):
    """HypothesisAgent should generate mechanistic hypotheses."""
    store = {"current": sample_evidence, "hypotheses": []}

    with patch("src.agents.hypothesis_agent.get_model") as fake_get_model, patch(
        "src.agents.hypothesis_agent.Agent"
    ) as fake_agent_cls:
        fake_get_model.return_value = MagicMock()  # stand-in model object
        llm_result = MagicMock()
        # pydantic-ai exposes structured output via the .output attribute.
        llm_result.output = mock_assessment
        fake_agent_cls.return_value.run = AsyncMock(return_value=llm_result)

        response = await HypothesisAgent(store).run("metformin cancer")

    assert isinstance(response, AgentRunResponse)
    assert "AMPK" in response.messages[0].text
    assert len(store["hypotheses"]) == 1
    assert store["hypotheses"][0].drug == "Metformin"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@pytest.mark.asyncio
async def test_hypothesis_agent_no_evidence():
    """HypothesisAgent should handle empty evidence gracefully."""
    store = {"current": [], "hypotheses": []}

    # Empty evidence short-circuits before any model call, so no mocks are needed.
    response = await HypothesisAgent(store).run("test query")

    assert "No evidence" in response.messages[0].text
    assert not store["hypotheses"]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.mark.asyncio
async def test_hypothesis_agent_uses_embeddings(sample_evidence, mock_assessment):
    """HypothesisAgent should pass embeddings to prompt formatter."""
    store = {"current": sample_evidence, "hypotheses": []}
    fake_embeddings = MagicMock()

    with patch("src.agents.hypothesis_agent.get_model") as fake_get_model, patch(
        "src.agents.hypothesis_agent.Agent"
    ) as fake_agent_cls, patch(
        "src.agents.hypothesis_agent.format_hypothesis_prompt"
    ) as fake_format:
        fake_get_model.return_value = MagicMock()  # stand-in model object
        fake_format.return_value = "Prompt"

        llm_result = MagicMock()
        llm_result.output = mock_assessment
        fake_agent_cls.return_value.run = AsyncMock(return_value=llm_result)

        await HypothesisAgent(store, embedding_service=fake_embeddings).run("query")

    # The formatter must receive (query, evidence, ..., embeddings=service).
    fake_format.assert_called_once()
    args, kwargs = fake_format.call_args
    assert kwargs["embeddings"] == fake_embeddings
    assert args[0] == "query"  # first positional arg is the query
    assert args[1] == sample_evidence  # second positional arg is the evidence
|
tests/unit/utils/test_text_utils.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for text utilities."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import AsyncMock, MagicMock
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from src.utils.models import Citation, Evidence
|
| 8 |
+
from src.utils.text_utils import select_diverse_evidence, truncate_at_sentence
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestTextUtils:
    """Behavioral tests for truncate_at_sentence and select_diverse_evidence."""

    def test_truncate_at_sentence_short(self):
        """Text already under the limit is returned untouched."""
        short = "This is a short sentence."
        assert truncate_at_sentence(short, 100) == short

    def test_truncate_at_sentence_boundary(self):
        """Truncation snaps back to the last complete sentence."""
        text = "First sentence. Second sentence. Third sentence."
        # Budget lands partway into the third sentence.
        cutoff = len("First sentence. Second sentence") + 5
        assert truncate_at_sentence(text, cutoff) == "First sentence. Second sentence."

    def test_truncate_at_sentence_fallback_period(self):
        """Abbreviation periods early in the text do not confuse the cut."""
        text = "Dr. Smith went to the store. He bought apples."
        # Budget lands inside "He bought".
        cutoff = len("Dr. Smith went to the store.") + 5
        assert truncate_at_sentence(text, cutoff) == "Dr. Smith went to the store."

    def test_truncate_at_sentence_fallback_word(self):
        """With no punctuation in range, break at a word and add an ellipsis."""
        text = "This is a very long sentence without any punctuation marks until the very end"
        clipped = truncate_at_sentence(text, 20)
        assert clipped == "This is a very long..."
        # The appended ellipsis may push slightly past the raw limit.
        assert len(clipped) <= 20 + 3

    @pytest.mark.asyncio
    async def test_select_diverse_evidence_no_embeddings(self):
        """Without an embedding service, selection is a plain relevance sort."""
        evidence = [
            self._make_evidence("A", 0.9),
            self._make_evidence("B", 0.1),
            self._make_evidence("C", 0.8),
        ]

        chosen = await select_diverse_evidence(evidence, n=2, query="test", embeddings=None)

        assert len(chosen) == 2
        # Descending relevance: A (0.9) then C (0.8); B (0.1) is dropped.
        assert [e.content for e in chosen] == ["A", "C"]

    @pytest.mark.asyncio
    async def test_select_diverse_evidence_mmr(self):
        """With embeddings, MMR drops near-duplicates in favor of diverse items."""
        # Geometry: A and B share one vector (clones), C is orthogonal to them,
        # and the query sits at 45 degrees -- equally similar (0.707) to both
        # directions. MMR should pick A first, then the diverse C, skipping B.
        query_vector = [0.707, 0.707]
        content_vectors = {"A": [1.0, 0.0], "B": [1.0, 0.0], "C": [0.0, 1.0]}

        async def fake_embed(text):
            return query_vector if text == "query" else [0.0, 0.0]

        async def fake_embed_batch(texts):
            return [content_vectors.get(t, [0.0, 0.0]) for t in texts]

        fake_embeddings = MagicMock()
        fake_embeddings.embed = AsyncMock(side_effect=fake_embed)
        fake_embeddings.embed_batch = AsyncMock(side_effect=fake_embed_batch)

        evidence = [
            self._make_evidence("A", 0.9),
            self._make_evidence("B", 0.9),
            self._make_evidence("C", 0.9),
        ]

        chosen = await select_diverse_evidence(
            evidence, n=2, query="query", embeddings=fake_embeddings
        )

        assert len(chosen) == 2
        assert [e.content for e in chosen] == ["A", "C"]

    @staticmethod
    def _make_evidence(content, relevance):
        """Build a minimal Evidence record whose citation mirrors its content."""
        return Evidence(
            content=content,
            relevance=relevance,
            citation=Citation(source="web", title=content, url=content.lower(), date="2023"),
        )
|