Commit 3bacbf8 · Parent: 24a5878
feat(phase6): Embeddings & Semantic Search implementation
- Added EmbeddingService with ChromaDB and SentenceTransformers
- Updated SearchAgent to support semantic deduplication and search
- Updated MagenticOrchestrator to use embeddings if available
- Added 100% async-safe implementation with run_in_executor
- Added comprehensive unit tests
- Dockerfile +8 -1
- pyproject.toml +4 -0
- src/agents/search_agent.py +79 -12
- src/orchestrator.py +2 -2
- src/orchestrator_magentic.py +25 -2
- src/services/__init__.py +1 -0
- src/services/embeddings.py +132 -0
- tests/unit/agents/test_search_agent.py +44 -1
- tests/unit/services/test_embeddings.py +148 -0
- uv.lock +0 -0
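How the pieces fit together, as a minimal usage sketch based on the diffs below (not part of the commit itself; the Citation fields mirror those used in the new tests, and the optional "embeddings" extra must be installed):

import asyncio

from src.services.embeddings import get_embedding_service
from src.utils.models import Citation, Evidence


async def demo() -> None:
    # Singleton: loads the all-MiniLM-L6-v2 model once and reuses it
    service = get_embedding_service()

    evidence = [
        Evidence(
            content="Metformin activates AMPK pathway",
            citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
        )
    ]

    # Drop semantic near-duplicates; anything kept is also indexed in ChromaDB
    unique = await service.deduplicate(evidence, threshold=0.9)

    # Query everything indexed so far for semantically related evidence
    related = await service.search_similar("AMPK activation drugs", n_results=5)
    print(len(unique), [r["id"] for r in related])


asyncio.run(demo())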
Dockerfile
CHANGED

@@ -19,7 +19,14 @@ COPY src/ src/
 COPY README.md .
 
 # Install dependencies
-RUN uv sync --frozen --no-dev
+RUN uv sync --frozen --no-dev --all-extras
+
+# Set cache directory for HuggingFace models
+ENV HF_HOME=/app/.cache
+ENV TRANSFORMERS_CACHE=/app/.cache
+
+# Pre-download the embedding model during build to speed up cold starts
+RUN uv run python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
 
 # Create non-root user
 RUN useradd --create-home --shell /bin/bash appuser
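The pre-download step works because HF_HOME points the HuggingFace cache at /app/.cache inside the image, so the model weights are baked in at build time. A quick runtime check (a sketch; the 384 dimension is a known property of all-MiniLM-L6-v2, not something this commit asserts):

from sentence_transformers import SentenceTransformer

# With HF_HOME=/app/.cache set at build time, this resolves from the local
# cache instead of downloading on first use.
model = SentenceTransformer("all-MiniLM-L6-v2")
print(model.get_sentence_embedding_dimension())  # 384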
pyproject.toml
CHANGED

@@ -49,6 +49,10 @@ dev = [
 magentic = [
     "agent-framework-core",
 ]
+embeddings = [
+    "chromadb>=0.4.0",
+    "sentence-transformers>=2.2.0",
+]
 
 [build-system]
 requires = ["hatchling"]
src/agents/search_agent.py
CHANGED

@@ -1,5 +1,5 @@
 from collections.abc import AsyncIterable
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from agent_framework import (
     AgentRunResponse,
@@ -11,7 +11,10 @@ from agent_framework import (
 )
 
 from src.orchestrator import SearchHandlerProtocol
-from src.utils.models import Evidence, SearchResult
+from src.utils.models import Citation, Evidence, SearchResult
+
+if TYPE_CHECKING:
+    from src.services.embeddings import EmbeddingService
 
 
 class SearchAgent(BaseAgent):  # type: ignore[misc]
@@ -21,6 +24,7 @@ class SearchAgent(BaseAgent):  # type: ignore[misc]
         self,
         search_handler: SearchHandlerProtocol,
         evidence_store: dict[str, list[Evidence]],
+        embedding_service: "EmbeddingService | None" = None,
     ) -> None:
         super().__init__(
             name="SearchAgent",
@@ -28,6 +32,7 @@ class SearchAgent(BaseAgent):  # type: ignore[misc]
         )
         self._handler = search_handler
         self._evidence_store = evidence_store
+        self._embeddings = embedding_service
 
     async def run(
         self,
@@ -62,30 +67,92 @@ class SearchAgent(BaseAgent):  # type: ignore[misc]
         result: SearchResult = await self._handler.execute(query, max_results_per_tool=10)
 
         # Update shared evidence store
-        […seven removed lines did not render in the diff view…]
+        if self._embeddings:
+            # Deduplicate by semantic similarity (async-safe)
+            unique_evidence = await self._embeddings.deduplicate(result.evidence)
+
+            # Also search for semantically related evidence (async-safe)
+            related = await self._embeddings.search_similar(query, n_results=5)
+
+            # Merge related evidence not already in the results; Evidence objects
+            # must be reconstructed from the stored data.
+            existing_urls = {e.citation.url for e in unique_evidence}
+
+            # Also check what's already in the global store to avoid re-adding.
+            # The logic here is a bit subtle: deduplicate() returns what is unique
+            # in the *new search*, but we also want related hits from *previous searches*.
+
+            related_evidence = []
+            for item in related:
+                if item["id"] not in existing_urls:
+                    # Create Evidence from stored metadata, checking required fields
+                    meta = item.get("metadata", {})
+                    # Fall back if the date is missing
+                    date = meta.get("date") or "n.d."
+                    authors = meta.get("authors")  # May be a list or a string depending on how it was stored
+                    if isinstance(authors, str):
+                        authors = [authors]
+                    if not authors:
+                        authors = ["Unknown"]
+
+                    ev = Evidence(
+                        content=item["content"],
+                        citation=Citation(
+                            title=meta.get("title", "Untitled"),
+                            url=item["id"],
+                            source=meta.get("source", "vector_db"),
+                            date=date,
+                            authors=authors,
+                        ),
+                        relevance=item.get("distance", 0.0),  # Use distance/similarity as a proxy
+                    )
+                    related_evidence.append(ev)
+
+            # Combine
+            final_new_evidence = unique_evidence + related_evidence
+
+            # Add to the global store (deduping against it)
+            global_urls = {e.citation.url for e in self._evidence_store["current"]}
+            really_new = [e for e in final_new_evidence if e.citation.url not in global_urls]
+            self._evidence_store["current"].extend(really_new)
+
+            # Update the count for reporting
+            total_new = len(really_new)
+
+        else:
+            # Fall back to URL-based deduplication
+            existing_urls = {e.citation.url for e in self._evidence_store["current"]}
+            new_unique = [e for e in result.evidence if e.citation.url not in existing_urls]
+            self._evidence_store["current"].extend(new_unique)
+            total_new = len(new_unique)
 
         # Format response
+        # Show what was found in this run plus the related hits
+        evidence_to_show = (
+            (unique_evidence + related_evidence)
+            if self._embeddings and "unique_evidence" in locals()
+            else result.evidence
+        )
+
         evidence_text = "\n".join(
             [
                 f"- [{e.citation.title}]({e.citation.url}): {e.content[:200]}..."
-                for e in …
+                for e in evidence_to_show[:5]
             ]
         )
 
         response_text = (
-            f"Found {result.total_found} sources ({…
+            f"Found {result.total_found} sources ({total_new} new added to context):\n\n"
+            f"{evidence_text}"
         )
 
         return AgentRunResponse(
             messages=[ChatMessage(role=Role.ASSISTANT, text=response_text)],
             response_id=f"search-{result.total_found}",
-            additional_properties={"evidence": [e.model_dump() for e in …
+            additional_properties={"evidence": [e.model_dump() for e in evidence_to_show]},
         )
 
     async def run_stream(
src/orchestrator.py
CHANGED

@@ -263,7 +263,7 @@ class Orchestrator:
 
         citations = "\n".join(
             [
-                f"{i+1}. [{e.citation.title}]({e.citation.url}) "
+                f"{i + 1}. [{e.citation.title}]({e.citation.url}) "
                 f"({e.citation.source.upper()}, {e.citation.date})"
                 for i, e in enumerate(evidence[:10])  # Limit to 10 citations
             ]
@@ -312,7 +312,7 @@ class Orchestrator:
         """
         citations = "\n".join(
             [
-                f"{i+1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
+                f"{i + 1}. [{e.citation.title}]({e.citation.url}) ({e.citation.source.upper()})"
                 for i, e in enumerate(evidence[:10])
             ]
         )
src/orchestrator_magentic.py
CHANGED

@@ -54,8 +54,22 @@ class MagenticOrchestrator:
             iteration=0,
         )
 
+        # Initialize embedding service (optional)
+        embedding_service = None
+        try:
+            from src.services.embeddings import get_embedding_service
+
+            embedding_service = get_embedding_service()
+            logger.info("Embedding service enabled")
+        except ImportError:
+            logger.info("Embedding service not available (dependencies missing)")
+        except Exception as e:
+            logger.warning("Failed to initialize embedding service", error=str(e))
+
         # Create agent wrappers
-        search_agent = SearchAgent(…
+        search_agent = SearchAgent(
+            self._search_handler, self._evidence_store, embedding_service=embedding_service
+        )
         judge_agent = JudgeAgent(self._judge_handler, self._evidence_store)
 
         # Build Magentic workflow
@@ -78,8 +92,17 @@ class MagenticOrchestrator:
         )
 
         # Task instruction for the manager
-        […removed line did not render in the diff view…]
+        semantic_note = ""
+        if embedding_service:
+            semantic_note = """
+The system has semantic search enabled. When evidence is found:
+1. Related concepts will be automatically surfaced
+2. Duplicates are removed by meaning, not just URL
+3. Use the surfaced related concepts to refine searches
+"""
 
+        task = f"""Research drug repurposing opportunities for: {query}
+{semantic_note}
 Instructions:
 1. Use SearcherAgent to find evidence. SEND ONLY A SIMPLE KEYWORD QUERY (e.g. "metformin aging")
 as the instruction. Complex queries fail.
src/services/__init__.py
ADDED

@@ -0,0 +1 @@
+"""Services for DeepCritical."""
src/services/embeddings.py
ADDED

@@ -0,0 +1,132 @@
+"""Embedding service for semantic search.
+
+IMPORTANT: All public methods are async to avoid blocking the event loop.
+The sentence-transformers model is CPU-bound, so we use run_in_executor().
+"""
+
+import asyncio
+from typing import Any
+
+import chromadb
+from sentence_transformers import SentenceTransformer
+
+from src.utils.models import Evidence
+
+
+class EmbeddingService:
+    """Handles text embedding and vector storage.
+
+    All embedding operations run in a thread pool to avoid blocking
+    the async event loop. See src/tools/websearch.py for the pattern.
+    """
+
+    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+        self._model = SentenceTransformer(model_name)
+        self._client = chromadb.Client()  # In-memory for hackathon
+        self._collection = self._client.create_collection(
+            name="evidence", metadata={"hnsw:space": "cosine"}
+        )
+
+    # ─────────────────────────────────────────────────────────────────
+    # Sync internal methods (run in thread pool)
+    # ─────────────────────────────────────────────────────────────────
+
+    def _sync_embed(self, text: str) -> list[float]:
+        """Synchronous embedding - DO NOT call directly from async code."""
+        return list(self._model.encode(text).tolist())
+
+    def _sync_batch_embed(self, texts: list[str]) -> list[list[float]]:
+        """Batch embedding for efficiency - DO NOT call directly from async code."""
+        return [list(e.tolist()) for e in self._model.encode(texts)]
+
+    # ─────────────────────────────────────────────────────────────────
+    # Async public methods (safe for event loop)
+    # ─────────────────────────────────────────────────────────────────
+
+    async def embed(self, text: str) -> list[float]:
+        """Embed a single text (async-safe).
+
+        Uses run_in_executor to avoid blocking the event loop.
+        """
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, self._sync_embed, text)
+
+    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Batch embed multiple texts (async-safe, more efficient)."""
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, self._sync_batch_embed, texts)
+
+    async def add_evidence(self, evidence_id: str, content: str, metadata: dict[str, Any]) -> None:
+        """Add evidence to vector store (async-safe)."""
+        embedding = await self.embed(content)
+        # ChromaDB operations are fast, but wrap for consistency
+        loop = asyncio.get_running_loop()
+        await loop.run_in_executor(
+            None,
+            lambda: self._collection.add(
+                ids=[evidence_id],
+                embeddings=[embedding],  # type: ignore[arg-type]
+                metadatas=[metadata],
+                documents=[content],
+            ),
+        )
+
+    async def search_similar(self, query: str, n_results: int = 5) -> list[dict[str, Any]]:
+        """Find semantically similar evidence (async-safe)."""
+        query_embedding = await self.embed(query)
+
+        loop = asyncio.get_running_loop()
+        results = await loop.run_in_executor(
+            None,
+            lambda: self._collection.query(
+                query_embeddings=[query_embedding],  # type: ignore[arg-type]
+                n_results=n_results,
+            ),
+        )
+
+        # Handle empty results gracefully
+        ids = results.get("ids")
+        docs = results.get("documents")
+        metas = results.get("metadatas")
+        dists = results.get("distances")
+
+        if not ids or not ids[0] or not docs or not metas or not dists:
+            return []
+
+        return [
+            {"id": id, "content": doc, "metadata": meta, "distance": dist}
+            for id, doc, meta, dist in zip(
+                ids[0],
+                docs[0],
+                metas[0],
+                dists[0],
+                strict=False,
+            )
+        ]
+
+    async def deduplicate(
+        self, new_evidence: list[Evidence], threshold: float = 0.9
+    ) -> list[Evidence]:
+        """Remove semantically duplicate evidence (async-safe)."""
+        unique = []
+        for evidence in new_evidence:
+            similar = await self.search_similar(evidence.content, n_results=1)
+            if not similar or similar[0]["distance"] > (1 - threshold):
+                unique.append(evidence)
+                await self.add_evidence(
+                    evidence_id=evidence.citation.url,
+                    content=evidence.content,
+                    metadata={"source": evidence.citation.source},
+                )
+        return unique
+
+
+_embedding_service: EmbeddingService | None = None
+
+
+def get_embedding_service() -> EmbeddingService:
+    """Get singleton instance of EmbeddingService."""
+    global _embedding_service  # noqa: PLW0603
+    if _embedding_service is None:
+        _embedding_service = EmbeddingService()
+    return _embedding_service
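A note on the deduplicate() threshold: the collection is created with hnsw:space set to cosine, so ChromaDB reports distance = 1 - cosine similarity, and an item counts as unique only when its nearest neighbour is farther away than 1 - threshold. A tiny self-contained check of that arithmetic (illustration only, not part of the commit):

def is_unique(nearest_distance: float, threshold: float = 0.9) -> bool:
    # threshold=0.9 means distance <= 0.1 (similarity >= 0.9) marks a duplicate
    return nearest_distance > (1 - threshold)


assert is_unique(0.05) is False  # near-duplicate: dropped
assert is_unique(0.25) is True   # different enough: kept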
tests/unit/agents/test_search_agent.py
CHANGED

@@ -81,5 +81,48 @@ async def test_run_handles_list_input(mock_handler: AsyncMock) -> None:
         ChatMessage(role=Role.USER, text="test query"),
     ]
     await agent.run(messages)
-
     mock_handler.execute.assert_awaited_once_with("test query", max_results_per_tool=10)
+
+
+@pytest.mark.asyncio
+async def test_run_uses_embeddings(mock_handler: AsyncMock) -> None:
+    """Test that run uses embedding service if provided."""
+    store: dict = {"current": []}
+
+    # Mock embedding service
+    mock_embeddings = AsyncMock()
+    # Mock deduplicate to return the evidence as is (or filtered)
+    mock_embeddings.deduplicate.return_value = [
+        Evidence(
+            content="unique content",
+            citation=Citation(source="pubmed", url="u1", title="t1", date="2024"),
+        )
+    ]
+    # Mock search_similar to return related items
+    mock_embeddings.search_similar.return_value = [
+        {
+            "id": "u2",
+            "content": "related content",
+            "metadata": {"source": "web", "title": "related", "date": "2024"},
+            "distance": 0.1,
+        }
+    ]
+
+    agent = SearchAgent(mock_handler, store, embedding_service=mock_embeddings)
+
+    await agent.run("test query")
+
+    # Verify deduplicate was called
+    mock_embeddings.deduplicate.assert_awaited_once()
+
+    # Verify semantic search was called
+    mock_embeddings.search_similar.assert_awaited_once_with("test query", n_results=5)
+
+    # Verify the store contains the related evidence; the spec says:
+    # "Merge related evidence not already in results"
+
+    # Check that u1 (deduplicated result) is in the store
+    assert any(e.citation.url == "u1" for e in store["current"])
+    # Check that u2 (related result) is in the store
+    assert any(e.citation.url == "u2" for e in store["current"])
tests/unit/services/test_embeddings.py
ADDED

@@ -0,0 +1,148 @@
+"""Unit tests for EmbeddingService."""
+
+from unittest.mock import patch
+
+import numpy as np
+import pytest
+
+# Skip if embeddings dependencies are not installed
+pytest.importorskip("chromadb")
+pytest.importorskip("sentence_transformers")
+
+from src.services.embeddings import EmbeddingService
+
+
+class TestEmbeddingService:
+    @pytest.fixture
+    def mock_sentence_transformer(self):
+        with patch("src.services.embeddings.SentenceTransformer") as mock_st_class:
+            mock_model = mock_st_class.return_value
+            # Mock encode to return a numpy array
+            mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
+            yield mock_model
+
+    @pytest.fixture
+    def mock_chroma_client(self):
+        with patch("src.services.embeddings.chromadb.Client") as mock_client_class:
+            mock_client = mock_client_class.return_value
+            mock_collection = mock_client.create_collection.return_value
+            # Mock query return structure
+            mock_collection.query.return_value = {
+                "ids": [["id1"]],
+                "documents": [["doc1"]],
+                "metadatas": [[{"source": "pubmed"}]],
+                "distances": [[0.1]],
+            }
+            yield mock_client
+
+    @pytest.mark.asyncio
+    async def test_embed_returns_vector(self, mock_sentence_transformer, mock_chroma_client):
+        """Embedding should return a float vector (async check)."""
+        service = EmbeddingService()
+        embedding = await service.embed("metformin diabetes")
+
+        assert isinstance(embedding, list)
+        assert len(embedding) == 3  # noqa: PLR2004
+        assert all(isinstance(x, float) for x in embedding)
+        # Ensure it ran in the executor (mock encode called)
+        mock_sentence_transformer.encode.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_batch_embed_efficient(self, mock_sentence_transformer, mock_chroma_client):
+        """Batch embedding should call encode with a list."""
+        # Set up the mock for a batch return (list of arrays)
+        import numpy as np
+
+        mock_sentence_transformer.encode.return_value = np.array([[0.1, 0.2], [0.3, 0.4]])
+
+        service = EmbeddingService()
+        texts = ["text one", "text two"]
+
+        batch_results = await service.embed_batch(texts)
+
+        assert len(batch_results) == 2  # noqa: PLR2004
+        assert isinstance(batch_results[0], list)
+        mock_sentence_transformer.encode.assert_called_with(texts)
+
+    @pytest.mark.asyncio
+    async def test_add_and_search(self, mock_sentence_transformer, mock_chroma_client):
+        """Should be able to add evidence and search for similar."""
+        service = EmbeddingService()
+        await service.add_evidence(
+            evidence_id="test1",
+            content="Metformin activates AMPK pathway",
+            metadata={"source": "pubmed"},
+        )
+
+        # Verify add was called
+        mock_collection = mock_chroma_client.create_collection.return_value
+        mock_collection.add.assert_called_once()
+
+        results = await service.search_similar("AMPK activation drugs", n_results=1)
+
+        # Verify query was called
+        mock_collection.query.assert_called_once()
+        assert len(results) == 1
+        assert results[0]["id"] == "id1"
+
+    @pytest.mark.asyncio
+    async def test_search_similar_empty_collection(
+        self, mock_sentence_transformer, mock_chroma_client
+    ):
+        """Search on an empty collection should return an empty list, not error."""
+        mock_collection = mock_chroma_client.create_collection.return_value
+        mock_collection.query.return_value = {
+            "ids": [[]],
+            "documents": [[]],
+            "metadatas": [[]],
+            "distances": [[]],
+        }
+
+        service = EmbeddingService()
+        results = await service.search_similar("anything", n_results=5)
+        assert results == []
+
+    @pytest.mark.asyncio
+    async def test_deduplicate(self, mock_sentence_transformer, mock_chroma_client):
+        """Deduplicate should remove similar items."""
+        from src.utils.models import Citation, Evidence
+
+        service = EmbeddingService()
+
+        # Mock search to return a match for the first item (duplicate)
+        # and no match for the second (unique)
+        mock_collection = mock_chroma_client.create_collection.return_value
+
+        # First call returns a match (distance 0.05 < threshold);
+        # second call returns no match
+        mock_collection.query.side_effect = [
+            {
+                "ids": [["existing_id"]],
+                "documents": [["doc"]],
+                "metadatas": [[{}]],
+                "distances": [[0.05]],  # Very similar
+            },
+            {
+                "ids": [[]],  # No match
+                "documents": [[]],
+                "metadatas": [[]],
+                "distances": [[]],
+            },
+        ]
+
+        evidence = [
+            Evidence(
+                content="Duplicate content",
+                citation=Citation(source="web", url="u1", title="t1", date="2024"),
+            ),
+            Evidence(
+                content="Unique content",
+                citation=Citation(source="web", url="u2", title="t2", date="2024"),
+            ),
+        ]
+
+        unique = await service.deduplicate(evidence, threshold=0.9)
+
+        # Only the unique one should remain
+        assert len(unique) == 1
+        assert unique[0].citation.url == "u2"
uv.lock
CHANGED

The diff for this file is too large to render.