Spaces:
Running
Running
Commit
·
1bc9785
1
Parent(s):
1efef06
refactor: address CodeRabbit nitpicks for Phase 10
Browse files- Centralize SourceName type alias in models.py
- Update search_handler.py to use SourceName (reduces cast brittleness)
- Update app.py intro text to mention ClinicalTrials.gov
- Update run_search.py docstring for multi-source search
- DRY test setup with mock_requests_get fixture
- Speed up error test by patching tenacity retry stop condition
- examples/search_demo/run_search.py +2 -1
- src/app.py +1 -1
- src/tools/search_handler.py +5 -5
- src/utils/models.py +5 -2
- tests/unit/tools/test_clinicaltrials.py +38 -36
examples/search_demo/run_search.py
CHANGED
|
@@ -2,8 +2,9 @@
|
|
| 2 |
"""
|
| 3 |
Demo: Search for drug repurposing evidence.
|
| 4 |
|
| 5 |
-
This script demonstrates
|
| 6 |
- PubMed search (biomedical literature)
|
|
|
|
| 7 |
- SearchHandler (parallel scatter-gather orchestration)
|
| 8 |
|
| 9 |
Usage:
|
|
|
|
| 2 |
"""
|
| 3 |
Demo: Search for drug repurposing evidence.
|
| 4 |
|
| 5 |
+
This script demonstrates multi-source search functionality:
|
| 6 |
- PubMed search (biomedical literature)
|
| 7 |
+
- ClinicalTrials.gov search (clinical trial evidence)
|
| 8 |
- SearchHandler (parallel scatter-gather orchestration)
|
| 9 |
|
| 10 |
Usage:
|
src/app.py
CHANGED
|
@@ -128,7 +128,7 @@ def create_demo() -> Any:
|
|
| 128 |
## AI-Powered Drug Repurposing Research Agent
|
| 129 |
|
| 130 |
Ask questions about potential drug repurposing opportunities.
|
| 131 |
-
The agent
|
| 132 |
|
| 133 |
**Example questions:**
|
| 134 |
- "What drugs could be repurposed for Alzheimer's disease?"
|
|
|
|
| 128 |
## AI-Powered Drug Repurposing Research Agent
|
| 129 |
|
| 130 |
Ask questions about potential drug repurposing opportunities.
|
| 131 |
+
The agent searches PubMed & ClinicalTrials.gov to provide recommendations.
|
| 132 |
|
| 133 |
**Example questions:**
|
| 134 |
- "What drugs could be repurposed for Alzheimer's disease?"
|
src/tools/search_handler.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
"""Search handler - orchestrates multiple search tools."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
-
from typing import
|
| 5 |
|
| 6 |
import structlog
|
| 7 |
|
| 8 |
from src.tools.base import SearchTool
|
| 9 |
from src.utils.exceptions import SearchError
|
| 10 |
-
from src.utils.models import Evidence, SearchResult
|
| 11 |
|
| 12 |
logger = structlog.get_logger()
|
| 13 |
|
|
@@ -49,7 +49,7 @@ class SearchHandler:
|
|
| 49 |
|
| 50 |
# Process results
|
| 51 |
all_evidence: list[Evidence] = []
|
| 52 |
-
sources_searched: list[
|
| 53 |
errors: list[str] = []
|
| 54 |
|
| 55 |
for tool, result in zip(self.tools, results, strict=True):
|
|
@@ -61,8 +61,8 @@ class SearchHandler:
|
|
| 61 |
success_result = cast(list[Evidence], result)
|
| 62 |
all_evidence.extend(success_result)
|
| 63 |
|
| 64 |
-
# Cast tool.name to
|
| 65 |
-
tool_name = cast(
|
| 66 |
sources_searched.append(tool_name)
|
| 67 |
logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
|
| 68 |
|
|
|
|
| 1 |
"""Search handler - orchestrates multiple search tools."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
+
from typing import cast
|
| 5 |
|
| 6 |
import structlog
|
| 7 |
|
| 8 |
from src.tools.base import SearchTool
|
| 9 |
from src.utils.exceptions import SearchError
|
| 10 |
+
from src.utils.models import Evidence, SearchResult, SourceName
|
| 11 |
|
| 12 |
logger = structlog.get_logger()
|
| 13 |
|
|
|
|
| 49 |
|
| 50 |
# Process results
|
| 51 |
all_evidence: list[Evidence] = []
|
| 52 |
+
sources_searched: list[SourceName] = []
|
| 53 |
errors: list[str] = []
|
| 54 |
|
| 55 |
for tool, result in zip(self.tools, results, strict=True):
|
|
|
|
| 61 |
success_result = cast(list[Evidence], result)
|
| 62 |
all_evidence.extend(success_result)
|
| 63 |
|
| 64 |
+
# Cast tool.name to SourceName (centralized type from models)
|
| 65 |
+
tool_name = cast(SourceName, tool.name)
|
| 66 |
sources_searched.append(tool_name)
|
| 67 |
logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
|
| 68 |
|
src/utils/models.py
CHANGED
|
@@ -5,11 +5,14 @@ from typing import Any, ClassVar, Literal
|
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class Citation(BaseModel):
|
| 10 |
"""A citation to a source document."""
|
| 11 |
|
| 12 |
-
source:
|
| 13 |
title: str = Field(min_length=1, max_length=500)
|
| 14 |
url: str = Field(description="URL to the source")
|
| 15 |
date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
|
|
@@ -41,7 +44,7 @@ class SearchResult(BaseModel):
|
|
| 41 |
|
| 42 |
query: str
|
| 43 |
evidence: list[Evidence]
|
| 44 |
-
sources_searched: list[
|
| 45 |
total_found: int
|
| 46 |
errors: list[str] = Field(default_factory=list)
|
| 47 |
|
|
|
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
| 8 |
+
# Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
|
| 9 |
+
SourceName = Literal["pubmed", "clinicaltrials"]
|
| 10 |
+
|
| 11 |
|
| 12 |
class Citation(BaseModel):
|
| 13 |
"""A citation to a source document."""
|
| 14 |
|
| 15 |
+
source: SourceName = Field(description="Where this came from")
|
| 16 |
title: str = Field(min_length=1, max_length=500)
|
| 17 |
url: str = Field(description="URL to the source")
|
| 18 |
date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
|
|
|
|
| 44 |
|
| 45 |
query: str
|
| 46 |
evidence: list[Evidence]
|
| 47 |
+
sources_searched: list[SourceName]
|
| 48 |
total_found: int
|
| 49 |
errors: list[str] = Field(default_factory=list)
|
| 50 |
|
tests/unit/tools/test_clinicaltrials.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
"""Unit tests for ClinicalTrials.gov tool."""
|
| 2 |
|
|
|
|
|
|
|
| 3 |
from unittest.mock import MagicMock, patch
|
| 4 |
|
| 5 |
import pytest
|
|
@@ -11,7 +13,7 @@ from src.utils.models import Evidence
|
|
| 11 |
|
| 12 |
|
| 13 |
@pytest.fixture
|
| 14 |
-
def mock_clinicaltrials_response() -> dict:
|
| 15 |
"""Mock ClinicalTrials.gov API response."""
|
| 16 |
return {
|
| 17 |
"studies": [
|
|
@@ -39,6 +41,19 @@ def mock_clinicaltrials_response() -> dict:
|
|
| 39 |
}
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
class TestClinicalTrialsTool:
|
| 43 |
"""Tests for ClinicalTrialsTool."""
|
| 44 |
|
|
@@ -48,50 +63,32 @@ class TestClinicalTrialsTool:
|
|
| 48 |
assert tool.name == "clinicaltrials"
|
| 49 |
|
| 50 |
@pytest.mark.asyncio
|
| 51 |
-
async def test_search_returns_evidence(self,
|
| 52 |
"""Search should return Evidence objects."""
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
mock_response.json.return_value = mock_clinicaltrials_response
|
| 56 |
-
mock_response.raise_for_status = MagicMock()
|
| 57 |
-
mock_get.return_value = mock_response
|
| 58 |
-
|
| 59 |
-
tool = ClinicalTrialsTool()
|
| 60 |
-
results = await tool.search("metformin alzheimer", max_results=5)
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
@pytest.mark.asyncio
|
| 69 |
-
async def test_search_extracts_phase(self,
|
| 70 |
"""Search should extract trial phase."""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
mock_response.json.return_value = mock_clinicaltrials_response
|
| 74 |
-
mock_response.raise_for_status = MagicMock()
|
| 75 |
-
mock_get.return_value = mock_response
|
| 76 |
-
|
| 77 |
-
tool = ClinicalTrialsTool()
|
| 78 |
-
results = await tool.search("metformin alzheimer")
|
| 79 |
|
| 80 |
-
|
| 81 |
|
| 82 |
@pytest.mark.asyncio
|
| 83 |
-
async def test_search_extracts_status(self,
|
| 84 |
"""Search should extract trial status."""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
mock_response.json.return_value = mock_clinicaltrials_response
|
| 88 |
-
mock_response.raise_for_status = MagicMock()
|
| 89 |
-
mock_get.return_value = mock_response
|
| 90 |
-
|
| 91 |
-
tool = ClinicalTrialsTool()
|
| 92 |
-
results = await tool.search("metformin alzheimer")
|
| 93 |
|
| 94 |
-
|
| 95 |
|
| 96 |
@pytest.mark.asyncio
|
| 97 |
async def test_search_empty_results(self) -> None:
|
|
@@ -109,13 +106,18 @@ class TestClinicalTrialsTool:
|
|
| 109 |
|
| 110 |
@pytest.mark.asyncio
|
| 111 |
async def test_search_api_error(self) -> None:
|
| 112 |
-
"""Search should raise SearchError on API failure.
|
|
|
|
|
|
|
|
|
|
| 113 |
with patch("src.tools.clinicaltrials.requests.get") as mock_get:
|
| 114 |
mock_response = MagicMock()
|
| 115 |
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
| 116 |
mock_get.return_value = mock_response
|
| 117 |
|
| 118 |
tool = ClinicalTrialsTool()
|
|
|
|
|
|
|
| 119 |
|
| 120 |
with pytest.raises(SearchError):
|
| 121 |
await tool.search("metformin alzheimer")
|
|
|
|
| 1 |
"""Unit tests for ClinicalTrials.gov tool."""
|
| 2 |
|
| 3 |
+
from collections.abc import Generator
|
| 4 |
+
from typing import Any
|
| 5 |
from unittest.mock import MagicMock, patch
|
| 6 |
|
| 7 |
import pytest
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
@pytest.fixture
|
| 16 |
+
def mock_clinicaltrials_response() -> dict[str, Any]:
|
| 17 |
"""Mock ClinicalTrials.gov API response."""
|
| 18 |
return {
|
| 19 |
"studies": [
|
|
|
|
| 41 |
}
|
| 42 |
|
| 43 |
|
| 44 |
+
@pytest.fixture
|
| 45 |
+
def mock_requests_get(
|
| 46 |
+
mock_clinicaltrials_response: dict[str, Any],
|
| 47 |
+
) -> Generator[MagicMock, None, None]:
|
| 48 |
+
"""Fixture to mock requests.get with a successful response."""
|
| 49 |
+
with patch("src.tools.clinicaltrials.requests.get") as mock_get:
|
| 50 |
+
mock_response = MagicMock()
|
| 51 |
+
mock_response.json.return_value = mock_clinicaltrials_response
|
| 52 |
+
mock_response.raise_for_status = MagicMock()
|
| 53 |
+
mock_get.return_value = mock_response
|
| 54 |
+
yield mock_get
|
| 55 |
+
|
| 56 |
+
|
| 57 |
class TestClinicalTrialsTool:
|
| 58 |
"""Tests for ClinicalTrialsTool."""
|
| 59 |
|
|
|
|
| 63 |
assert tool.name == "clinicaltrials"
|
| 64 |
|
| 65 |
@pytest.mark.asyncio
|
| 66 |
+
async def test_search_returns_evidence(self, mock_requests_get: MagicMock) -> None:
|
| 67 |
"""Search should return Evidence objects."""
|
| 68 |
+
tool = ClinicalTrialsTool()
|
| 69 |
+
results = await tool.search("metformin alzheimer", max_results=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
assert len(results) == 1
|
| 72 |
+
assert isinstance(results[0], Evidence)
|
| 73 |
+
assert results[0].citation.source == "clinicaltrials"
|
| 74 |
+
assert "NCT04098666" in results[0].citation.url
|
| 75 |
+
assert "Metformin" in results[0].citation.title
|
| 76 |
|
| 77 |
@pytest.mark.asyncio
|
| 78 |
+
async def test_search_extracts_phase(self, mock_requests_get: MagicMock) -> None:
|
| 79 |
"""Search should extract trial phase."""
|
| 80 |
+
tool = ClinicalTrialsTool()
|
| 81 |
+
results = await tool.search("metformin alzheimer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
assert "PHASE2" in results[0].content
|
| 84 |
|
| 85 |
@pytest.mark.asyncio
|
| 86 |
+
async def test_search_extracts_status(self, mock_requests_get: MagicMock) -> None:
|
| 87 |
"""Search should extract trial status."""
|
| 88 |
+
tool = ClinicalTrialsTool()
|
| 89 |
+
results = await tool.search("metformin alzheimer")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
assert "Recruiting" in results[0].content
|
| 92 |
|
| 93 |
@pytest.mark.asyncio
|
| 94 |
async def test_search_empty_results(self) -> None:
|
|
|
|
| 106 |
|
| 107 |
@pytest.mark.asyncio
|
| 108 |
async def test_search_api_error(self) -> None:
|
| 109 |
+
"""Search should raise SearchError on API failure.
|
| 110 |
+
|
| 111 |
+
Note: We patch the retry decorator to avoid 3x backoff delay in tests.
|
| 112 |
+
"""
|
| 113 |
with patch("src.tools.clinicaltrials.requests.get") as mock_get:
|
| 114 |
mock_response = MagicMock()
|
| 115 |
mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
|
| 116 |
mock_get.return_value = mock_response
|
| 117 |
|
| 118 |
tool = ClinicalTrialsTool()
|
| 119 |
+
# Patch the retry decorator's stop condition to fail immediately
|
| 120 |
+
tool.search.retry.stop = lambda _: True # type: ignore[attr-defined]
|
| 121 |
|
| 122 |
with pytest.raises(SearchError):
|
| 123 |
await tool.search("metformin alzheimer")
|