VibecoderMcSwaggins commited on
Commit
e502f0d
·
1 Parent(s): 2f8ae1f

feat(search): implement PubMed query preprocessing (Phase 02)

Browse files

Added query preprocessing to strip question words and expand medical synonyms (e.g., 'Long COVID' -> 'PASC'). This fixes poor search results for natural language queries.

src/tools/pubmed.py CHANGED
@@ -7,6 +7,7 @@ import httpx
7
  import xmltodict
8
  from tenacity import retry, stop_after_attempt, wait_exponential
9
 
 
10
  from src.utils.config import settings
11
  from src.utils.exceptions import RateLimitError, SearchError
12
  from src.utils.models import Citation, Evidence
@@ -61,11 +62,15 @@ class PubMedTool:
61
  """
62
  await self._rate_limit()
63
 
 
 
 
 
64
  async with httpx.AsyncClient(timeout=30.0) as client:
65
  # Step 1: Search for PMIDs
66
  search_params = self._build_params(
67
  db="pubmed",
68
- term=query,
69
  retmax=max_results,
70
  sort="relevance",
71
  )
 
7
  import xmltodict
8
  from tenacity import retry, stop_after_attempt, wait_exponential
9
 
10
+ from src.tools.query_utils import preprocess_query
11
  from src.utils.config import settings
12
  from src.utils.exceptions import RateLimitError, SearchError
13
  from src.utils.models import Citation, Evidence
 
62
  """
63
  await self._rate_limit()
64
 
65
+ # Preprocess query to remove noise and expand synonyms
66
+ clean_query = preprocess_query(query)
67
+ final_query = clean_query if clean_query else query
68
+
69
  async with httpx.AsyncClient(timeout=30.0) as client:
70
  # Step 1: Search for PMIDs
71
  search_params = self._build_params(
72
  db="pubmed",
73
+ term=final_query,
74
  retmax=max_results,
75
  sort="relevance",
76
  )
src/tools/query_utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query preprocessing utilities for biomedical search."""
2
+
3
+ import re
4
+
5
+ # Question words and filler words to remove
6
+ QUESTION_WORDS: set[str] = {
7
+ # Question starters
8
+ "what",
9
+ "which",
10
+ "how",
11
+ "why",
12
+ "when",
13
+ "where",
14
+ "who",
15
+ "whom",
16
+ # Auxiliary verbs in questions
17
+ "is",
18
+ "are",
19
+ "was",
20
+ "were",
21
+ "do",
22
+ "does",
23
+ "did",
24
+ "can",
25
+ "could",
26
+ "would",
27
+ "should",
28
+ "will",
29
+ "shall",
30
+ "may",
31
+ "might",
32
+ # Filler words in natural questions
33
+ "show",
34
+ "promise",
35
+ "help",
36
+ "believe",
37
+ "think",
38
+ "suggest",
39
+ "possible",
40
+ "potential",
41
+ "effective",
42
+ "useful",
43
+ "good",
44
+ # Articles (remove but less aggressively)
45
+ "the",
46
+ "a",
47
+ "an",
48
+ }
49
+
50
+ # Medical synonym expansions
51
+ SYNONYMS: dict[str, list[str]] = {
52
+ "long covid": [
53
+ "long COVID",
54
+ "PASC",
55
+ "post-acute sequelae of SARS-CoV-2",
56
+ "post-COVID syndrome",
57
+ "post-COVID-19 condition",
58
+ ],
59
+ "alzheimer": [
60
+ "Alzheimer's disease",
61
+ "Alzheimer disease",
62
+ "AD",
63
+ "Alzheimer dementia",
64
+ ],
65
+ "parkinson": [
66
+ "Parkinson's disease",
67
+ "Parkinson disease",
68
+ "PD",
69
+ ],
70
+ "diabetes": [
71
+ "diabetes mellitus",
72
+ "type 2 diabetes",
73
+ "T2DM",
74
+ "diabetic",
75
+ ],
76
+ "cancer": [
77
+ "cancer",
78
+ "neoplasm",
79
+ "tumor",
80
+ "malignancy",
81
+ "carcinoma",
82
+ ],
83
+ "heart disease": [
84
+ "cardiovascular disease",
85
+ "CVD",
86
+ "coronary artery disease",
87
+ "heart failure",
88
+ ],
89
+ }
90
+
91
+
92
+ def strip_question_words(query: str) -> str:
93
+ """
94
+ Remove question words and filler terms from query.
95
+
96
+ Args:
97
+ query: Raw query string
98
+
99
+ Returns:
100
+ Query with question words removed
101
+ """
102
+ words = query.lower().split()
103
+ filtered = [w for w in words if w not in QUESTION_WORDS]
104
+ return " ".join(filtered)
105
+
106
+
107
+ def expand_synonyms(query: str) -> str:
108
+ """
109
+ Expand medical terms to include synonyms.
110
+
111
+ Args:
112
+ query: Query string
113
+
114
+ Returns:
115
+ Query with synonym expansions in OR groups
116
+ """
117
+ result = query.lower()
118
+
119
+ for term, expansions in SYNONYMS.items():
120
+ if term in result:
121
+ # Create OR group: ("term1" OR "term2" OR "term3")
122
+ or_group = " OR ".join([f'"{exp}"' for exp in expansions])
123
+ # Case insensitive replacement is tricky with simple replace
124
+ # But we lowercased result already.
125
+ # However, this replaces ALL instances.
126
+ # Also, result is lowercased, so we lose original casing if any.
127
+ # But search engines are usually case-insensitive.
128
+ result = result.replace(term, f"({or_group})")
129
+
130
+ return result
131
+
132
+
133
+ def preprocess_query(raw_query: str) -> str:
134
+ """
135
+ Full preprocessing pipeline for PubMed queries.
136
+
137
+ Pipeline:
138
+ 1. Strip whitespace and punctuation
139
+ 2. Remove question words
140
+ 3. Expand medical synonyms
141
+
142
+ Args:
143
+ raw_query: Natural language query from user
144
+
145
+ Returns:
146
+ Optimized query for PubMed
147
+ """
148
+ if not raw_query or not raw_query.strip():
149
+ return ""
150
+
151
+ # Remove question marks and extra whitespace
152
+ query = raw_query.replace("?", "").strip()
153
+ query = re.sub(r"\s+", " ", query)
154
+
155
+ # Strip question words
156
+ query = strip_question_words(query)
157
+
158
+ # Expand synonyms
159
+ query = expand_synonyms(query)
160
+
161
+ return query.strip()
tests/unit/tools/test_pubmed.py CHANGED
@@ -97,3 +97,31 @@ class TestPubMedTool:
97
  assert len(results) == 1
98
  assert results[0].citation.source == "pubmed"
99
  assert "Smith John" in results[0].citation.authors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  assert len(results) == 1
98
  assert results[0].citation.source == "pubmed"
99
  assert "Smith John" in results[0].citation.authors
100
+
101
+ @pytest.mark.asyncio
102
+ async def test_search_preprocesses_query(self, mocker):
103
+ """Test that queries are preprocessed before search."""
104
+ mock_search_response = MagicMock()
105
+ mock_search_response.json.return_value = {"esearchresult": {"idlist": []}}
106
+ mock_search_response.raise_for_status = MagicMock()
107
+
108
+ mock_client = AsyncMock()
109
+ mock_client.get = AsyncMock(return_value=mock_search_response)
110
+ mock_client.__aenter__ = AsyncMock(return_value=mock_client)
111
+ mock_client.__aexit__ = AsyncMock(return_value=None)
112
+
113
+ mocker.patch("httpx.AsyncClient", return_value=mock_client)
114
+
115
+ tool = PubMedTool()
116
+ await tool.search("What drugs help with Long COVID?")
117
+
118
+ # Verify call args
119
+ call_args = mock_client.get.call_args
120
+ params = call_args[1]["params"]
121
+ term = params["term"]
122
+
123
+ # "what" and "help" should be stripped
124
+ assert "what" not in term.lower()
125
+ assert "help" not in term.lower()
126
+ # "long covid" should be expanded
127
+ assert "PASC" in term or "post-COVID" in term
tests/unit/tools/test_query_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for query preprocessing utilities."""
2
+
3
+ import pytest
4
+
5
+ from src.tools.query_utils import expand_synonyms, preprocess_query, strip_question_words
6
+
7
+
8
+ @pytest.mark.unit
9
+ class TestQueryPreprocessing:
10
+ """Tests for query preprocessing."""
11
+
12
+ def test_strip_question_words(self):
13
+ """Test removal of question words."""
14
+ assert strip_question_words("What drugs treat cancer") == "drugs treat cancer"
15
+ assert strip_question_words("Which medications help diabetes") == "medications diabetes"
16
+ assert strip_question_words("How can we cure alzheimer") == "we cure alzheimer"
17
+ assert strip_question_words("Is metformin effective") == "metformin"
18
+
19
+ def test_strip_preserves_medical_terms(self):
20
+ """Test that medical terms are preserved."""
21
+ result = strip_question_words("What is the mechanism of metformin")
22
+ assert "metformin" in result
23
+ assert "mechanism" in result
24
+
25
+ def test_expand_synonyms_long_covid(self):
26
+ """Test Long COVID synonym expansion."""
27
+ result = expand_synonyms("long covid treatment")
28
+ assert "PASC" in result or "post-COVID" in result
29
+
30
+ def test_expand_synonyms_alzheimer(self):
31
+ """Test Alzheimer's synonym expansion."""
32
+ result = expand_synonyms("alzheimer drug")
33
+ assert "Alzheimer" in result
34
+
35
+ def test_expand_synonyms_preserves_unknown(self):
36
+ """Test that unknown terms are preserved."""
37
+ result = expand_synonyms("metformin diabetes")
38
+ assert "metformin" in result
39
+ assert "diabetes" in result
40
+
41
+ def test_preprocess_query_full_pipeline(self):
42
+ """Test complete preprocessing pipeline."""
43
+ raw = "What medications show promise for Long COVID?"
44
+ result = preprocess_query(raw)
45
+
46
+ # Should not contain question words
47
+ assert "what" not in result.lower()
48
+ assert "show" not in result.lower()
49
+ assert "promise" not in result.lower()
50
+
51
+ # Should contain expanded terms
52
+ assert "PASC" in result or "post-COVID" in result or "long covid" in result.lower()
53
+ assert "medications" in result.lower() or "drug" in result.lower()
54
+
55
+ def test_preprocess_query_removes_punctuation(self):
56
+ """Test that question marks are removed."""
57
+ result = preprocess_query("Is metformin safe?")
58
+ assert "?" not in result
59
+
60
+ def test_preprocess_query_handles_empty(self):
61
+ """Test handling of empty/whitespace queries."""
62
+ assert preprocess_query("") == ""
63
+ assert preprocess_query(" ") == ""
64
+
65
+ def test_preprocess_query_already_clean(self):
66
+ """Test that clean queries pass through."""
67
+ clean = "metformin diabetes mechanism"
68
+ result = preprocess_query(clean)
69
+ assert "metformin" in result
70
+ assert "diabetes" in result
71
+ assert "mechanism" in result