VibecoderMcSwaggins commited on
Commit
1bc9785
·
1 Parent(s): 1efef06

refactor: address CodeRabbit nitpicks for Phase 10

Browse files

- Centralize SourceName type alias in models.py
- Update search_handler.py to use SourceName (reduces cast brittleness)
- Update app.py intro text to mention ClinicalTrials.gov
- Update run_search.py docstring for multi-source search
- DRY test setup with mock_requests_get fixture
- Speed up error test by patching tenacity retry stop condition

examples/search_demo/run_search.py CHANGED
@@ -2,8 +2,9 @@
2
  """
3
  Demo: Search for drug repurposing evidence.
4
 
5
- This script demonstrates Phase 2 functionality:
6
  - PubMed search (biomedical literature)
 
7
  - SearchHandler (parallel scatter-gather orchestration)
8
 
9
  Usage:
 
2
  """
3
  Demo: Search for drug repurposing evidence.
4
 
5
+ This script demonstrates multi-source search functionality:
6
  - PubMed search (biomedical literature)
7
+ - ClinicalTrials.gov search (clinical trial evidence)
8
  - SearchHandler (parallel scatter-gather orchestration)
9
 
10
  Usage:
src/app.py CHANGED
@@ -128,7 +128,7 @@ def create_demo() -> Any:
128
  ## AI-Powered Drug Repurposing Research Agent
129
 
130
  Ask questions about potential drug repurposing opportunities.
131
- The agent will search PubMed, evaluate evidence, and provide recommendations.
132
 
133
  **Example questions:**
134
  - "What drugs could be repurposed for Alzheimer's disease?"
 
128
  ## AI-Powered Drug Repurposing Research Agent
129
 
130
  Ask questions about potential drug repurposing opportunities.
131
+ The agent searches PubMed & ClinicalTrials.gov to provide recommendations.
132
 
133
  **Example questions:**
134
  - "What drugs could be repurposed for Alzheimer's disease?"
src/tools/search_handler.py CHANGED
@@ -1,13 +1,13 @@
1
  """Search handler - orchestrates multiple search tools."""
2
 
3
  import asyncio
4
- from typing import Literal, cast
5
 
6
  import structlog
7
 
8
  from src.tools.base import SearchTool
9
  from src.utils.exceptions import SearchError
10
- from src.utils.models import Evidence, SearchResult
11
 
12
  logger = structlog.get_logger()
13
 
@@ -49,7 +49,7 @@ class SearchHandler:
49
 
50
  # Process results
51
  all_evidence: list[Evidence] = []
52
- sources_searched: list[Literal["pubmed", "clinicaltrials"]] = []
53
  errors: list[str] = []
54
 
55
  for tool, result in zip(self.tools, results, strict=True):
@@ -61,8 +61,8 @@ class SearchHandler:
61
  success_result = cast(list[Evidence], result)
62
  all_evidence.extend(success_result)
63
 
64
- # Cast tool.name to the expected Literal
65
- tool_name = cast(Literal["pubmed", "clinicaltrials"], tool.name)
66
  sources_searched.append(tool_name)
67
  logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
68
 
 
1
  """Search handler - orchestrates multiple search tools."""
2
 
3
  import asyncio
4
+ from typing import cast
5
 
6
  import structlog
7
 
8
  from src.tools.base import SearchTool
9
  from src.utils.exceptions import SearchError
10
+ from src.utils.models import Evidence, SearchResult, SourceName
11
 
12
  logger = structlog.get_logger()
13
 
 
49
 
50
  # Process results
51
  all_evidence: list[Evidence] = []
52
+ sources_searched: list[SourceName] = []
53
  errors: list[str] = []
54
 
55
  for tool, result in zip(self.tools, results, strict=True):
 
61
  success_result = cast(list[Evidence], result)
62
  all_evidence.extend(success_result)
63
 
64
+ # Cast tool.name to SourceName (centralized type from models)
65
+ tool_name = cast(SourceName, tool.name)
66
  sources_searched.append(tool_name)
67
  logger.info("Search tool succeeded", tool=tool.name, count=len(success_result))
68
 
src/utils/models.py CHANGED
@@ -5,11 +5,14 @@ from typing import Any, ClassVar, Literal
5
 
6
  from pydantic import BaseModel, Field
7
 
 
 
 
8
 
9
  class Citation(BaseModel):
10
  """A citation to a source document."""
11
 
12
- source: Literal["pubmed", "clinicaltrials"] = Field(description="Where this came from")
13
  title: str = Field(min_length=1, max_length=500)
14
  url: str = Field(description="URL to the source")
15
  date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
@@ -41,7 +44,7 @@ class SearchResult(BaseModel):
41
 
42
  query: str
43
  evidence: list[Evidence]
44
- sources_searched: list[Literal["pubmed", "clinicaltrials"]]
45
  total_found: int
46
  errors: list[str] = Field(default_factory=list)
47
 
 
5
 
6
  from pydantic import BaseModel, Field
7
 
8
+ # Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
9
+ SourceName = Literal["pubmed", "clinicaltrials"]
10
+
11
 
12
  class Citation(BaseModel):
13
  """A citation to a source document."""
14
 
15
+ source: SourceName = Field(description="Where this came from")
16
  title: str = Field(min_length=1, max_length=500)
17
  url: str = Field(description="URL to the source")
18
  date: str = Field(description="Publication date (YYYY-MM-DD or 'Unknown')")
 
44
 
45
  query: str
46
  evidence: list[Evidence]
47
+ sources_searched: list[SourceName]
48
  total_found: int
49
  errors: list[str] = Field(default_factory=list)
50
 
tests/unit/tools/test_clinicaltrials.py CHANGED
@@ -1,5 +1,7 @@
1
  """Unit tests for ClinicalTrials.gov tool."""
2
 
 
 
3
  from unittest.mock import MagicMock, patch
4
 
5
  import pytest
@@ -11,7 +13,7 @@ from src.utils.models import Evidence
11
 
12
 
13
  @pytest.fixture
14
- def mock_clinicaltrials_response() -> dict:
15
  """Mock ClinicalTrials.gov API response."""
16
  return {
17
  "studies": [
@@ -39,6 +41,19 @@ def mock_clinicaltrials_response() -> dict:
39
  }
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class TestClinicalTrialsTool:
43
  """Tests for ClinicalTrialsTool."""
44
 
@@ -48,50 +63,32 @@ class TestClinicalTrialsTool:
48
  assert tool.name == "clinicaltrials"
49
 
50
  @pytest.mark.asyncio
51
- async def test_search_returns_evidence(self, mock_clinicaltrials_response: dict) -> None:
52
  """Search should return Evidence objects."""
53
- with patch("src.tools.clinicaltrials.requests.get") as mock_get:
54
- mock_response = MagicMock()
55
- mock_response.json.return_value = mock_clinicaltrials_response
56
- mock_response.raise_for_status = MagicMock()
57
- mock_get.return_value = mock_response
58
-
59
- tool = ClinicalTrialsTool()
60
- results = await tool.search("metformin alzheimer", max_results=5)
61
 
62
- assert len(results) == 1
63
- assert isinstance(results[0], Evidence)
64
- assert results[0].citation.source == "clinicaltrials"
65
- assert "NCT04098666" in results[0].citation.url
66
- assert "Metformin" in results[0].citation.title
67
 
68
  @pytest.mark.asyncio
69
- async def test_search_extracts_phase(self, mock_clinicaltrials_response: dict) -> None:
70
  """Search should extract trial phase."""
71
- with patch("src.tools.clinicaltrials.requests.get") as mock_get:
72
- mock_response = MagicMock()
73
- mock_response.json.return_value = mock_clinicaltrials_response
74
- mock_response.raise_for_status = MagicMock()
75
- mock_get.return_value = mock_response
76
-
77
- tool = ClinicalTrialsTool()
78
- results = await tool.search("metformin alzheimer")
79
 
80
- assert "PHASE2" in results[0].content
81
 
82
  @pytest.mark.asyncio
83
- async def test_search_extracts_status(self, mock_clinicaltrials_response: dict) -> None:
84
  """Search should extract trial status."""
85
- with patch("src.tools.clinicaltrials.requests.get") as mock_get:
86
- mock_response = MagicMock()
87
- mock_response.json.return_value = mock_clinicaltrials_response
88
- mock_response.raise_for_status = MagicMock()
89
- mock_get.return_value = mock_response
90
-
91
- tool = ClinicalTrialsTool()
92
- results = await tool.search("metformin alzheimer")
93
 
94
- assert "Recruiting" in results[0].content
95
 
96
  @pytest.mark.asyncio
97
  async def test_search_empty_results(self) -> None:
@@ -109,13 +106,18 @@ class TestClinicalTrialsTool:
109
 
110
  @pytest.mark.asyncio
111
  async def test_search_api_error(self) -> None:
112
- """Search should raise SearchError on API failure."""
 
 
 
113
  with patch("src.tools.clinicaltrials.requests.get") as mock_get:
114
  mock_response = MagicMock()
115
  mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
116
  mock_get.return_value = mock_response
117
 
118
  tool = ClinicalTrialsTool()
 
 
119
 
120
  with pytest.raises(SearchError):
121
  await tool.search("metformin alzheimer")
 
1
  """Unit tests for ClinicalTrials.gov tool."""
2
 
3
+ from collections.abc import Generator
4
+ from typing import Any
5
  from unittest.mock import MagicMock, patch
6
 
7
  import pytest
 
13
 
14
 
15
  @pytest.fixture
16
+ def mock_clinicaltrials_response() -> dict[str, Any]:
17
  """Mock ClinicalTrials.gov API response."""
18
  return {
19
  "studies": [
 
41
  }
42
 
43
 
44
+ @pytest.fixture
45
+ def mock_requests_get(
46
+ mock_clinicaltrials_response: dict[str, Any],
47
+ ) -> Generator[MagicMock, None, None]:
48
+ """Fixture to mock requests.get with a successful response."""
49
+ with patch("src.tools.clinicaltrials.requests.get") as mock_get:
50
+ mock_response = MagicMock()
51
+ mock_response.json.return_value = mock_clinicaltrials_response
52
+ mock_response.raise_for_status = MagicMock()
53
+ mock_get.return_value = mock_response
54
+ yield mock_get
55
+
56
+
57
  class TestClinicalTrialsTool:
58
  """Tests for ClinicalTrialsTool."""
59
 
 
63
  assert tool.name == "clinicaltrials"
64
 
65
  @pytest.mark.asyncio
66
+ async def test_search_returns_evidence(self, mock_requests_get: MagicMock) -> None:
67
  """Search should return Evidence objects."""
68
+ tool = ClinicalTrialsTool()
69
+ results = await tool.search("metformin alzheimer", max_results=5)
 
 
 
 
 
 
70
 
71
+ assert len(results) == 1
72
+ assert isinstance(results[0], Evidence)
73
+ assert results[0].citation.source == "clinicaltrials"
74
+ assert "NCT04098666" in results[0].citation.url
75
+ assert "Metformin" in results[0].citation.title
76
 
77
  @pytest.mark.asyncio
78
+ async def test_search_extracts_phase(self, mock_requests_get: MagicMock) -> None:
79
  """Search should extract trial phase."""
80
+ tool = ClinicalTrialsTool()
81
+ results = await tool.search("metformin alzheimer")
 
 
 
 
 
 
82
 
83
+ assert "PHASE2" in results[0].content
84
 
85
  @pytest.mark.asyncio
86
+ async def test_search_extracts_status(self, mock_requests_get: MagicMock) -> None:
87
  """Search should extract trial status."""
88
+ tool = ClinicalTrialsTool()
89
+ results = await tool.search("metformin alzheimer")
 
 
 
 
 
 
90
 
91
+ assert "Recruiting" in results[0].content
92
 
93
  @pytest.mark.asyncio
94
  async def test_search_empty_results(self) -> None:
 
106
 
107
  @pytest.mark.asyncio
108
  async def test_search_api_error(self) -> None:
109
+ """Search should raise SearchError on API failure.
110
+
111
+ Note: We patch the retry decorator to avoid 3x backoff delay in tests.
112
+ """
113
  with patch("src.tools.clinicaltrials.requests.get") as mock_get:
114
  mock_response = MagicMock()
115
  mock_response.raise_for_status.side_effect = requests.HTTPError("500 Server Error")
116
  mock_get.return_value = mock_response
117
 
118
  tool = ClinicalTrialsTool()
119
+ # Patch the retry decorator's stop condition to fail immediately
120
+ tool.search.retry.stop = lambda _: True # type: ignore[attr-defined]
121
 
122
  with pytest.raises(SearchError):
123
  await tool.search("metformin alzheimer")