VibecoderMcSwaggins committed on
Commit
3bacbf8
·
1 Parent(s): 24a5878

feat(phase6): Embeddings & Semantic Search implementation

Browse files

- Added EmbeddingService with ChromaDB and SentenceTransformers
- Updated SearchAgent to support semantic deduplication and search
- Updated MagenticOrchestrator to use embeddings if available
- Added 100% async-safe implementation with run_in_executor
- Added comprehensive unit tests

Dockerfile CHANGED
@@ -19,7 +19,14 @@ COPY src/ src/
19
  COPY README.md .
20
 
21
  # Install dependencies
22
- RUN uv sync --frozen --no-dev
 
 
 
 
 
 
 
23
 
24
  # Create non-root user
25
  RUN useradd --create-home --shell /bin/bash appuser
 
19
COPY README.md .

# Install dependencies, including optional extras (embeddings, magentic, ...)
RUN uv sync --frozen --no-dev --all-extras

# Cache directory for HuggingFace models.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME; both are set for backward compatibility.
ENV HF_HOME=/app/.cache
ENV TRANSFORMERS_CACHE=/app/.cache

# Pre-download the embedding model during build to speed up cold starts
RUN uv run python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"

# Create non-root user and hand it the model cache so HF lock/refresh files
# can be written at runtime (the pre-download above ran as root).
RUN useradd --create-home --shell /bin/bash appuser \
    && chown -R appuser /app/.cache
pyproject.toml CHANGED
@@ -49,6 +49,10 @@ dev = [
49
  magentic = [
50
  "agent-framework-core",
51
  ]
 
 
 
 
52
 
53
  [build-system]
54
  requires = ["hatchling"]
 
49
  magentic = [
50
  "agent-framework-core",
51
  ]
52
+ embeddings = [
53
+ "chromadb>=0.4.0",
54
+ "sentence-transformers>=2.2.0",
55
+ ]
56
 
57
  [build-system]
58
  requires = ["hatchling"]
src/agents/search_agent.py CHANGED
@@ -1,5 +1,5 @@
1
  from collections.abc import AsyncIterable
2
- from typing import Any
3
 
4
  from agent_framework import (
5
  AgentRunResponse,
@@ -11,7 +11,10 @@ from agent_framework import (
11
  )
12
 
13
  from src.orchestrator import SearchHandlerProtocol
14
- from src.utils.models import Evidence, SearchResult
 
 
 
15
 
16
 
17
  class SearchAgent(BaseAgent): # type: ignore[misc]
@@ -21,6 +24,7 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
21
  self,
22
  search_handler: SearchHandlerProtocol,
23
  evidence_store: dict[str, list[Evidence]],
 
24
  ) -> None:
25
  super().__init__(
26
  name="SearchAgent",
@@ -28,6 +32,7 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
28
  )
29
  self._handler = search_handler
30
  self._evidence_store = evidence_store
 
31
 
32
  async def run(
33
  self,
@@ -62,30 +67,92 @@ class SearchAgent(BaseAgent): # type: ignore[misc]
62
  result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
63
 
64
  # Update shared evidence store
65
- # We append new evidence, deduplicating by URL is handled in Orchestrator usually,
66
- # but here we should probably add to the list.
67
- # For simplicity in this MVP phase, we just extend the list.
68
- # Ideally, we should dedupe.
69
- existing_urls = {e.citation.url for e in self._evidence_store["current"]}
70
- new_unique = [e for e in result.evidence if e.citation.url not in existing_urls]
71
- self._evidence_store["current"].extend(new_unique)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Format response
 
 
 
 
 
 
 
 
 
74
  evidence_text = "\n".join(
75
  [
76
  f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
77
- for e in result.evidence[:5]
78
  ]
79
  )
80
 
81
  response_text = (
82
- f"Found {result.total_found} sources ({len(new_unique)} new):\n\n{evidence_text}"
 
83
  )
84
 
85
  return AgentRunResponse(
86
  messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
87
  response_id=f"search-{result.total_found}",
88
- additional_properties={"evidence": [e.model_dump() for e in result.evidence]},
89
  )
90
 
91
  async def run_stream(
 
1
  from collections.abc import AsyncIterable
2
+ from typing import TYPE_CHECKING, Any
3
 
4
  from agent_framework import (
5
  AgentRunResponse,
 
11
  )
12
 
13
  from src.orchestrator import SearchHandlerProtocol
14
+ from src.utils.models import Citation, Evidence, SearchResult
15
+
16
+ if TYPE_CHECKING:
17
+ from src.services.embeddings import EmbeddingService
18
 
19
 
20
  class SearchAgent(BaseAgent): # type: ignore[misc]
 
24
  self,
25
  search_handler: SearchHandlerProtocol,
26
  evidence_store: dict[str, list[Evidence]],
27
+ embedding_service: "EmbeddingService | None" = None,
28
  ) -> None:
29
  super().__init__(
30
  name="SearchAgent",
 
32
  )
33
  self._handler = search_handler
34
  self._evidence_store = evidence_store
35
+ self._embeddings = embedding_service
36
 
37
  async def run(
38
  self,
 
67
  result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
68
 
69
  # Update shared evidence store
70
+ if self._embeddings:
71
+ # Deduplicate by semantic similarity (async-safe)
72
+ unique_evidence = await self._embeddings.deduplicate(result.evidence)
73
+
74
+ # Also search for semantically related evidence (async-safe)
75
+ related = await self._embeddings.search_similar(query, n_results=5)
76
+
77
+ # Merge related evidence not already in results
78
+ # We need to reconstruct Evidence objects from stored data
79
+ existing_urls = {e.citation.url for e in unique_evidence}
80
+
81
+ # Also check what's already in the global store to avoid re-adding
82
+ # logic here is a bit complex: deduplicate returned unique from *new search*
83
+ # but we also want related from *previous searches*
84
+
85
+ related_evidence = []
86
+ for item in related:
87
+ if item["id"] not in existing_urls:
88
+ # Create Evidence from stored metadata
89
+ # Check if metadata has required fields
90
+ meta = item.get("metadata", {})
91
+ # Fallback if date missing
92
+ date = meta.get("date") or "n.d."
93
+ authors = meta.get("authors") # Might be list or string depending on how stored
94
+ if isinstance(authors, str):
95
+ authors = [authors]
96
+ if not authors:
97
+ authors = ["Unknown"]
98
+
99
+ ev = Evidence(
100
+ content=item["content"],
101
+ citation=Citation(
102
+ title=meta.get("title", "Untitled"),
103
+ url=item["id"],
104
+ source=meta.get("source", "vector_db"),
105
+ date=date,
106
+ authors=authors,
107
+ ),
108
+ relevance=item.get("distance", 0.0), # Use distance/similarity as proxy
109
+ )
110
+ related_evidence.append(ev)
111
+
112
+ # Combine
113
+ final_new_evidence = unique_evidence + related_evidence
114
+
115
+ # Add to global store (deduping against global store)
116
+ global_urls = {e.citation.url for e in self._evidence_store["current"]}
117
+ really_new = [e for e in final_new_evidence if e.citation.url not in global_urls]
118
+ self._evidence_store["current"].extend(really_new)
119
+
120
+ # Update result for reporting
121
+ total_new = len(really_new)
122
+
123
+ else:
124
+ # Fallback to URL-based deduplication
125
+ existing_urls = {e.citation.url for e in self._evidence_store["current"]}
126
+ new_unique = [e for e in result.evidence if e.citation.url not in existing_urls]
127
+ self._evidence_store["current"].extend(new_unique)
128
+ total_new = len(new_unique)
129
 
130
  # Format response
131
+ # Get latest N items from store or just the new ones
132
+ # Let's show what was found in this run + related
133
+
134
+ evidence_to_show = (
135
+ (unique_evidence + related_evidence)
136
+ if self._embeddings and "unique_evidence" in locals()
137
+ else result.evidence
138
+ )
139
+
140
  evidence_text = "\n".join(
141
  [
142
  f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
143
+ for e in evidence_to_show[:5]
144
  ]
145
  )
146
 
147
  response_text = (
148
+ f"Found {result.total_found} sources ({total_new} new added to context):\n\n"
149
+ f"{evidence_text}"
150
  )
151
 
152
  return AgentRunResponse(
153
  messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
154
  response_id=f"search-{result.total_found}",
155
+ additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]},
156
  )
157
 
158
  async def run_stream(
src/orchestrator.py CHANGED
@@ -263,7 +263,7 @@ class Orchestrator:
263
 
264
  citations = "\n".join(
265
  [
266
- f"{i+1}. [{e.citation.title}]({e.citation.url}) "
267
  f"({e.citation.source.upper()}, {e.citation.date})"
268
  for i, e in enumerate(evidence[:10]) # Limit to 10 citations
269
  ]
@@ -312,7 +312,7 @@ class Orchestrator:
312
  """
313
  citations = "\n".join(
314
  [
315
- f"{i+1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
316
  for i, e in enumerate(evidence[:10])
317
  ]
318
  )
 
263
 
264
  citations = "\n".join(
265
  [
266
+ f"{i + 1}. [{e.citation.title}]({e.citation.url}) "
267
  f"({e.citation.source.upper()}, {e.citation.date})"
268
  for i, e in enumerate(evidence[:10]) # Limit to 10 citations
269
  ]
 
312
  """
313
  citations = "\n".join(
314
  [
315
+ f"{i + 1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
316
  for i, e in enumerate(evidence[:10])
317
  ]
318
  )
src/orchestrator_magentic.py CHANGED
@@ -54,8 +54,22 @@ class MagenticOrchestrator:
54
  iteration=0,
55
  )
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Create agent wrappers
58
- search_agent = SearchAgent(self._search_handler, self._evidence_store)
 
 
59
  judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
60
 
61
  # Build Magentic workflow
@@ -78,8 +92,17 @@ class MagenticOrchestrator:
78
  )
79
 
80
  # Task instruction for the manager
81
- task = f"""Research drug repurposing opportunities for: {query}
 
 
 
 
 
 
 
82
 
 
 
83
  Instructions:
84
  1. Use SearcherAgent to find evidence. SEND ONLY A SIMPLE KEYWORD QUERY (e.g. "metformin aging")
85
  as the instruction. Complex queries fail.
 
54
  iteration=0,
55
  )
56
 
57
+ # Initialize embedding service (optional)
58
+ embedding_service = None
59
+ try:
60
+ from src.services.embeddings import get_embedding_service
61
+
62
+ embedding_service = get_embedding_service()
63
+ logger.info("Embedding service enabled")
64
+ except ImportError:
65
+ logger.info("Embedding service not available (dependencies missing)")
66
+ except Exception as e:
67
+ logger.warning("Failed to initialize embedding service", error=str(e))
68
+
69
  # Create agent wrappers
70
+ search_agent = SearchAgent(
71
+ self._search_handler, self._evidence_store, embedding_service=embedding_service
72
+ )
73
  judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
74
 
75
  # Build Magentic workflow
 
92
  )
93
 
94
  # Task instruction for the manager
95
+ semantic_note = ""
96
+ if embedding_service:
97
+ semantic_note = """
98
+ The system has semantic search enabled. When evidence is found:
99
+ 1. Related concepts will be automatically surfaced
100
+ 2. Duplicates are removed by meaning, not just URL
101
+ 3. Use the surfaced related concepts to refine searches
102
+ """
103
 
104
+ task = f"""Research drug repurposing opportunities for: {query}
105
+ {semantic_note}
106
  Instructions:
107
  1. Use SearcherAgent to find evidence. SEND ONLY A SIMPLE KEYWORD QUERY (e.g. "metformin aging")
108
  as the instruction. Complex queries fail.
src/services/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Services for DeepCritical."""
src/services/embeddings.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Embedding service for semantic search.
2
+
3
+ IMPORTANT: All public methods are async to avoid blocking the event loop.
4
+ The sentence-transformers model is CPU-bound, so we use run_in_executor().
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any
9
+
10
+ import chromadb
11
+ from sentence_transformers import SentenceTransformer
12
+
13
+ from src.utils.models import Evidence
14
+
15
+
16
+ class EmbeddingService:
17
+ """Handles text embedding and vector storage.
18
+
19
+ All embedding operations run in a thread pool to avoid blocking
20
+ the async event loop. See src/tools/websearch.py for the pattern.
21
+ """
22
+
23
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
24
+ self._model = SentenceTransformer(model_name)
25
+ self._client = chromadb.Client() # In-memory for hackathon
26
+ self._collection = self._client.create_collection(
27
+ name="evidence", metadata={"hnsw:space": "cosine"}
28
+ )
29
+
30
+ # ─────────────────────────────────────────────────────────────────
31
+ # Sync internal methods (run in thread pool)
32
+ # ─────────────────────────────────────────────────────────────────
33
+
34
+ def _sync_embed(self, text: str) -> list[float]:
35
+ """Synchronous embedding - DO NOT call directly from async code."""
36
+ return list(self._model.encode(text).tolist())
37
+
38
+ def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
39
+ """Batch embedding for efficiency - DO NOT call directly from async code."""
40
+ return [list(e.tolist()) for e in self._model.encode(texts)]
41
+
42
+ # ─────────────────────────────────────────────────────────────────
43
+ # Async public methods (safe for event loop)
44
+ # ─────────────────────────────────────────────────────────────────
45
+
46
+ async def embed(self, text: str) -> list[float]:
47
+ """Embed a single text (async-safe).
48
+
49
+ Uses run_in_executor to avoid blocking the event loop.
50
+ """
51
+ loop = asyncio.get_running_loop()
52
+ return await loop.run_in_executor(None, self._sync_embed, text)
53
+
54
+ async def embed_batch(self, texts: list[str]) -> list[list[float]]:
55
+ """Batch embed multiple texts (async-safe, more efficient)."""
56
+ loop = asyncio.get_running_loop()
57
+ return await loop.run_in_executor(None, self._sync_batch_embed, texts)
58
+
59
+ async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
60
+ """Add evidence to vector store (async-safe)."""
61
+ embedding = await self.embed(content)
62
+ # ChromaDB operations are fast, but wrap for consistency
63
+ loop = asyncio.get_running_loop()
64
+ await loop.run_in_executor(
65
+ None,
66
+ lambda: self._collection.add(
67
+ ids=[evidence_id],
68
+ embeddings=[embedding], # type: ignore[arg-type]
69
+ metadatas=[metadata],
70
+ documents=[content],
71
+ ),
72
+ )
73
+
74
+ async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
75
+ """Find semantically similar evidence (async-safe)."""
76
+ query_embedding = await self.embed(query)
77
+
78
+ loop = asyncio.get_running_loop()
79
+ results = await loop.run_in_executor(
80
+ None,
81
+ lambda: self._collection.query(
82
+ query_embeddings=[query_embedding], # type: ignore[arg-type]
83
+ n_results=n_results,
84
+ ),
85
+ )
86
+
87
+ # Handle empty results gracefully
88
+ ids = results.get("ids")
89
+ docs = results.get("documents")
90
+ metas = results.get("metadatas")
91
+ dists = results.get("distances")
92
+
93
+ if not ids or not ids[0] or not docs or not metas or not dists:
94
+ return []
95
+
96
+ return [
97
+ {"id": id, "content": doc, "metadata": meta, "distance": dist}
98
+ for id, doc, meta, dist in zip(
99
+ ids[0],
100
+ docs[0],
101
+ metas[0],
102
+ dists[0],
103
+ strict=False,
104
+ )
105
+ ]
106
+
107
+ async def deduplicate(
108
+ self, new_evidence: list[Evidence], threshold: float = 0.9
109
+ ) -> list[Evidence]:
110
+ """Remove semantically duplicate evidence (async-safe)."""
111
+ unique = []
112
+ for evidence in new_evidence:
113
+ similar = await self.search_similar(evidence.content, n_results=1)
114
+ if not similar or similar[0]["distance"] > (1 - threshold):
115
+ unique.append(evidence)
116
+ await self.add_evidence(
117
+ evidence_id=evidence.citation.url,
118
+ content=evidence.content,
119
+ metadata={"source": evidence.citation.source},
120
+ )
121
+ return unique
122
+
123
+
124
+ _embedding_service: EmbeddingService | None = None
125
+
126
+
127
+ def get_embedding_service() -> EmbeddingService:
128
+ """Get singleton instance of EmbeddingService."""
129
+ global _embedding_service # noqa: PLW0603
130
+ if _embedding_service is None:
131
+ _embedding_service = EmbeddingService()
132
+ return _embedding_service
tests/unit/agents/test_search_agent.py CHANGED
@@ -81,5 +81,48 @@ async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
81
  ChatMessage(role=Role.USER, text="test query"),
82
  ]
83
  await agent.run(messages)
84
-
85
  mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  ChatMessage(role=Role.USER, text="test query"),
82
  ]
83
  await agent.run(messages)
 
84
  mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
85
+
86
+
87
@pytest.mark.asyncio
async def test_run_uses_embeddings(mock_handler: AsyncMock) -> None:
    """run() should dedupe via embeddings and merge related vector-store hits."""
    store: dict = {"current": []}

    # Mock embedding service.
    mock_embeddings = AsyncMock()
    # deduplicate() passes through a single unique item.
    mock_embeddings.deduplicate.return_value = [
        Evidence(
            content="unique content",
            citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
        )
    ]
    # search_similar() surfaces one related item from the vector store.
    mock_embeddings.search_similar.return_value = [
        {
            "id": "u2",
            "content": "related content",
            "metadata": {"source": "web", "title": "related", "date": "2024"},
            "distance": 0.1,
        }
    ]

    agent = SearchAgent(mock_handler, store, embedding_service=mock_embeddings)

    await agent.run("test query")

    # Semantic deduplication was used.
    mock_embeddings.deduplicate.assert_awaited_once()

    # Semantic search was queried for related evidence.
    mock_embeddings.search_similar.assert_awaited_once_with("test query", n_results=5)

    # Both the deduplicated result and the related hit land in the store.
    assert any(e.citation.url == "u1" for e in store["current"])
    assert any(e.citation.url == "u2" for e in store["current"])
tests/unit/services/test_embeddings.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unit tests for EmbeddingService."""
2
+
3
+ from unittest.mock import patch
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ # Skip if embeddings dependencies are not installed
9
+ pytest.importorskip("chromadb")
10
+ pytest.importorskip("sentence_transformers")
11
+
12
+ from src.services.embeddings import EmbeddingService
13
+
14
+
15
+ class TestEmbeddingService:
16
+ @pytest.fixture
17
+ def mock_sentence_transformer(self):
18
+ with patch("src.services.embeddings.SentenceTransformer") as mock_st_class:
19
+ mock_model = mock_st_class.return_value
20
+ # Mock encode to return a numpy array
21
+ mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
22
+ yield mock_model
23
+
24
+ @pytest.fixture
25
+ def mock_chroma_client(self):
26
+ with patch("src.services.embeddings.chromadb.Client") as mock_client_class:
27
+ mock_client = mock_client_class.return_value
28
+ mock_collection = mock_client.create_collection.return_value
29
+ # Mock query return structure
30
+ mock_collection.query.return_value = {
31
+ "ids": [["id1"]],
32
+ "documents": [["doc1"]],
33
+ "metadatas": [[{"source": "pubmed"}]],
34
+ "distances": [[0.1]],
35
+ }
36
+ yield mock_client
37
+
38
+ @pytest.mark.asyncio
39
+ async def test_embed_returns_vector(self, mock_sentence_transformer, mock_chroma_client):
40
+ """Embedding should return a float vector (async check)."""
41
+ service = EmbeddingService()
42
+ embedding = await service.embed("metformin diabetes")
43
+
44
+ assert isinstance(embedding, list)
45
+ assert len(embedding) == 3 # noqa: PLR2004
46
+ assert all(isinstance(x, float) for x in embedding)
47
+ # Ensure it ran in executor (mock encode called)
48
+ mock_sentence_transformer.encode.assert_called_once()
49
+
50
+ @pytest.mark.asyncio
51
+ async def test_batch_embed_efficient(self, mock_sentence_transformer, mock_chroma_client):
52
+ """Batch embedding should call encode with list."""
53
+ # Setup mock for batch return (list of arrays)
54
+ import numpy as np
55
+
56
+ mock_sentence_transformer.encode.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])
57
+
58
+ service = EmbeddingService()
59
+ texts = ["text one", "text two"]
60
+
61
+ batch_results = await service.embed_batch(texts)
62
+
63
+ assert len(batch_results) == 2 # noqa: PLR2004
64
+ assert isinstance(batch_results[0], list)
65
+ mock_sentence_transformer.encode.assert_called_with(texts)
66
+
67
+ @pytest.mark.asyncio
68
+ async def test_add_and_search(self, mock_sentence_transformer, mock_chroma_client):
69
+ """Should be able to add evidence and search for similar."""
70
+ service = EmbeddingService()
71
+ await service.add_evidence(
72
+ evidence_id="test1",
73
+ content="Metformin activates AMPK pathway",
74
+ metadata={"source": "pubmed"},
75
+ )
76
+
77
+ # Verify add was called
78
+ mock_collection = mock_chroma_client.create_collection.return_value
79
+ mock_collection.add.assert_called_once()
80
+
81
+ results = await service.search_similar("AMPK activation drugs", n_results=1)
82
+
83
+ # Verify query was called
84
+ mock_collection.query.assert_called_once()
85
+ assert len(results) == 1
86
+ assert results[0]["id"] == "id1"
87
+
88
+ @pytest.mark.asyncio
89
+ async def test_search_similar_empty_collection(
90
+ self, mock_sentence_transformer, mock_chroma_client
91
+ ):
92
+ """Search on empty collection should return empty list, not error."""
93
+ mock_collection = mock_chroma_client.create_collection.return_value
94
+ mock_collection.query.return_value = {
95
+ "ids": [[]],
96
+ "documents": [[]],
97
+ "metadatas": [[]],
98
+ "distances": [[]],
99
+ }
100
+
101
+ service = EmbeddingService()
102
+ results = await service.search_similar("anything", n_results=5)
103
+ assert results == []
104
+
105
+ @pytest.mark.asyncio
106
+ async def test_deduplicate(self, mock_sentence_transformer, mock_chroma_client):
107
+ """Deduplicate should remove similar items."""
108
+ from src.utils.models import Citation, Evidence
109
+
110
+ service = EmbeddingService()
111
+
112
+ # Mock search to return a match for the first item (duplicate)
113
+ # and no match for the second (unique)
114
+ mock_collection = mock_chroma_client.create_collection.return_value
115
+
116
+ # First call returns match (distance 0.05 < threshold)
117
+ # Second call returns no match or high distance
118
+ mock_collection.query.side_effect = [
119
+ {
120
+ "ids": [["existing_id"]],
121
+ "documents": [["doc"]],
122
+ "metadatas": [[{}]],
123
+ "distances": [[0.05]], # Very similar
124
+ },
125
+ {
126
+ "ids": [[]], # No match
127
+ "documents": [[]],
128
+ "metadatas": [[]],
129
+ "distances": [[]],
130
+ },
131
+ ]
132
+
133
+ evidence = [
134
+ Evidence(
135
+ content="Duplicate content",
136
+ citation=Citation(source="web", url="u1", title="t1", date="2024"),
137
+ ),
138
+ Evidence(
139
+ content="Unique content",
140
+ citation=Citation(source="web", url="u2", title="t2", date="2024"),
141
+ ),
142
+ ]
143
+
144
+ unique = await service.deduplicate(evidence, threshold=0.9)
145
+
146
+ # Only the unique one should remain
147
+ assert len(unique) == 1
148
+ assert unique[0].citation.url == "u2"
uv.lock CHANGED
The diff for this file is too large to render. See raw diff