Buckets:
| """ | |
| OncoRAG SOTA Retriever — State-of-the-Art Medical Retrieval Pipeline. | |
| Implements a multi-stage retrieval architecture: | |
| 1. Bi-Encoder (PubMedBERT) → fast top-K candidates from ChromaDB | |
| 2. Distance Threshold Gate → anti-hallucination confidence filter | |
| 3. Cross-Encoder Re-Ranking → precision-optimised ordering | |
| 4. HyDE Query Expansion → hypothetical document embedding for recall | |
| 5. Token Trimming → context window budget control for Llama 3.1 | |
| Architecture inspired by: | |
| - Nogueira et al. (2019) "Multi-Stage Document Ranking with BERT" | |
| - Gao et al. (2023) "Precise Zero-Shot Dense Retrieval without Relevance Labels" (HyDE) | |
| """ | |
| import logging | |
| import os | |
| import re | |
| from typing import List, Dict, Optional, Tuple | |
| import chromadb | |
| import chromadb.utils.embedding_functions as embedding_functions | |
| import networkx as nx | |
| from .api_clients import CivicAPIClient, ClinicalTrialsClient | |
| logger = logging.getLogger(__name__) | |
| class OncoRAGRetriever: | |
| """ | |
| SOTA Retriever connecting LangGraph agents to ChromaDB. | |
| Pipeline: Query → (optional HyDE) → Bi-Encoder → Distance Gate | |
| → Cross-Encoder Re-Rank → Token Trim → LLM-ready context. | |
| Args: | |
| db_path: Path to the persistent ChromaDB directory. | |
| collection_name: Name of the ChromaDB collection to query. | |
| bi_encoder_model: Sentence-Transformer model for embedding queries. | |
| cross_encoder_model: Cross-Encoder model for re-ranking candidates. | |
| n_candidates: Number of candidates fetched by the bi-encoder (wide net). | |
| n_results: Number of final results returned after re-ranking. | |
| distance_threshold: Maximum cosine distance to accept a result. | |
| Results above this threshold are considered irrelevant. | |
| max_context_chars: Maximum total character budget for LLM context. | |
| """ | |
| # ------------------------------------------------------------------ init | |
    def __init__(
        self,
        db_path: str = "data/chroma_db",
        collection_name: str = "clinical_guidelines",
        bi_encoder_model: str = "pritamdeka/S-PubMedBert-MS-MARCO",
        cross_encoder_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        n_candidates: int = 15,
        n_results: int = 5,
        distance_threshold: float = 0.10,
        max_context_chars: int = 6000,
        graph_path: str = "data/processed/knowledge_graph.gml",
    ):
        """
        Open the ChromaDB collection and prepare all pipeline components.

        Side effects: instantiates a persistent ChromaDB client, loads the
        bi-encoder embedding model, and logs the collection size. The
        cross-encoder and knowledge graph are NOT loaded here — they are
        lazily initialised on first use (see _get_cross_encoder and
        _get_graph), keeping construction fast.

        Raises:
            Exception: propagated from ChromaDB when `collection_name`
                does not exist at `db_path`.
        """
        self.db_path = db_path
        self.n_candidates = n_candidates
        self.n_results = n_results
        self.distance_threshold = distance_threshold
        self.max_context_chars = max_context_chars
        self.graph_path = graph_path
        # --- Bi-Encoder (Stage 1: recall) --- loaded eagerly, since every
        # query needs it and get_collection requires the embedding function.
        self._client = chromadb.PersistentClient(path=db_path)
        self._emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=bi_encoder_model
        )
        self._collection = self._client.get_collection(
            name=collection_name,
            embedding_function=self._emb_fn,
        )
        logger.info(
            "Bi-Encoder loaded: %s | Collection: %s (%d docs)",
            bi_encoder_model,
            collection_name,
            self._collection.count(),
        )
        # --- Cross-Encoder (Stage 2: precision) --- lazy cache slot; stays
        # None until _get_cross_encoder() succeeds (or permanently on failure).
        self._cross_encoder = None
        self._cross_encoder_model_name = cross_encoder_model
        # --- SOTA Components (APIs & Graph) ---
        self._civic_api = CivicAPIClient()
        self._clinical_trials_api = ClinicalTrialsClient()
        # Knowledge-graph cache; populated lazily by _get_graph().
        self._graph: Optional[nx.Graph] = None
| # Lazy-load the cross encoder to avoid blocking import time | |
| def _get_cross_encoder(self): | |
| """Return a cached CrossEncoder instance (lazy init).""" | |
| if self._cross_encoder is None: | |
| try: | |
| from sentence_transformers import CrossEncoder | |
| self._cross_encoder = CrossEncoder( | |
| self._cross_encoder_model_name | |
| ) | |
| logger.info( | |
| "Cross-Encoder loaded: %s", | |
| self._cross_encoder_model_name, | |
| ) | |
| except ImportError: | |
| logger.warning( | |
| "sentence-transformers CrossEncoder not available. " | |
| "Falling back to bi-encoder ordering only." | |
| ) | |
| except Exception as exc: | |
| logger.error("Failed to load Cross-Encoder: %s", exc) | |
| return self._cross_encoder | |
| def _get_graph(self) -> Optional[nx.Graph]: | |
| """Return the Knowledge Graph (lazy init).""" | |
| if self._graph is None: | |
| if os.path.exists(self.graph_path): | |
| try: | |
| self._graph = nx.read_gml(self.graph_path) | |
| logger.info("Knowledge Graph loaded from %s", self.graph_path) | |
| except Exception as e: | |
| logger.error("Failed to load Knowledge Graph: %s", e) | |
| else: | |
| logger.warning("Knowledge Graph file not found at %s", self.graph_path) | |
| return self._graph | |
| def _graph_search(self, query_text: str) -> List[Dict]: | |
| """ | |
| Search the Knowledge Graph for clinical relationships. | |
| Matches keywords from query to graph nodes. | |
| """ | |
| graph = self._get_graph() | |
| if not graph: | |
| return [] | |
| query_lower = query_text.lower() | |
| findings = [] | |
| # Simple keyword matching for graph nodes | |
| for node in graph.nodes: | |
| if str(node).lower() in query_lower: | |
| # Find neighbors (related entities) | |
| neighbors = list(graph.neighbors(node)) | |
| for neighbor in neighbors: | |
| edge_data = graph.get_edge_data(node, neighbor) | |
| relation = edge_data.get("relation", "connected_to") | |
| source = edge_data.get("source", "Knowledge Graph") | |
| findings.append({ | |
| "text": f"Graph Finding: {node} {relation} {neighbor}.", | |
| "source": source, | |
| "type": "graph_relation" | |
| }) | |
| return findings | |
| def _external_api_search(self, query_text: str) -> List[Dict]: | |
| """ | |
| Search external clinical APIs (CIViC and ClinicalTrials.gov). | |
| """ | |
| results = [] | |
| # 1. CIViC Search (if query contains gene/variant-like patterns) | |
| # For simplicity, we search for common genes in the query | |
| genes = ["BRAF", "EGFR", "ALK", "KRAS", "NRAS", "HER2", "BRCA1", "BRCA2"] | |
| found_genes = [g for g in genes if g in query_text.upper()] | |
| for gene in found_genes: | |
| # We look for a variant pattern like V600E, T790M | |
| variant_match = re.search(r"[A-Z]\d+[A-Z]", query_text.upper()) | |
| variant = variant_match.group(0) if variant_match else "" | |
| civic_evidence = self._civic_api.search_variant_evidence(gene, variant) | |
| for item in civic_evidence[:2]: # Limit to top 2 | |
| results.append({ | |
| "text": f"CIViC Evidence: Gene {gene} Variant {item.get('variant', {}).get('name')}. Evidence: {item.get('description', 'No description available.')}", | |
| "source": "CIViC Database", | |
| "type": "genomic_evidence" | |
| }) | |
| # 2. ClinicalTrials.gov Search | |
| # We look for condition keywords | |
| conditions = ["Lung Cancer", "Breast Cancer", "Colorectal Cancer", "Hepatocellular Carcinoma", "Melanoma"] | |
| found_conditions = [c for c in conditions if c.lower() in query_text.lower()] | |
| for cond in found_conditions: | |
| trials = self._clinical_trials_api.search_trials(cond) | |
| for trial in trials[:2]: # Limit to top 2 | |
| results.append({ | |
| "text": f"Active Clinical Trial ({trial['nctId']}): {trial['title']}. Summary: {trial['briefSummary']}", | |
| "source": "ClinicalTrials.gov", | |
| "type": "clinical_trial" | |
| }) | |
| return results | |
| # ------------------------------------------------- stage 1: bi-encoder | |
| def _bi_encoder_retrieve( | |
| self, | |
| query_text: str, | |
| n: int, | |
| cancer_type_filter: Optional[str] = None, | |
| ) -> Tuple[List[Dict], List[float]]: | |
| """ | |
| Fetch top-N candidates from ChromaDB using PubMedBERT bi-encoder. | |
| Args: | |
| query_text: The natural-language clinical question. | |
| n: Number of candidate documents to retrieve. | |
| cancer_type_filter: Optional source filename filter. | |
| Returns: | |
| Tuple of (list of result dicts, list of distances). | |
| """ | |
| where_filter = None | |
| if cancer_type_filter: | |
| where_filter = {"source": cancer_type_filter} | |
| results = self._collection.query( | |
| query_texts=[query_text], | |
| n_results=n, | |
| where=where_filter, | |
| ) | |
| candidates: List[Dict] = [] | |
| distances: List[float] = [] | |
| if results and results["documents"]: | |
| for i, doc in enumerate(results["documents"][0]): | |
| meta = results["metadatas"][0][i] if results["metadatas"] else {} | |
| dist = results["distances"][0][i] if results["distances"] else 999.0 | |
| candidates.append({ | |
| "text": doc, | |
| "source": meta.get("source", "Unknown"), | |
| "page": str(meta.get("page", "?")), | |
| "header": meta.get("header", "Unknown"), | |
| }) | |
| distances.append(dist) | |
| return candidates, distances | |
| # ------------------------------------------------- stage 2: cross-encoder | |
| def _cross_encoder_rerank( | |
| self, | |
| query_text: str, | |
| candidates: List[Dict], | |
| ) -> List[Tuple[Dict, float]]: | |
| """ | |
| Re-rank candidates using a Cross-Encoder for precise relevance scoring. | |
| The Cross-Encoder reads (query, document) pairs jointly, producing | |
| far more accurate relevance scores than bi-encoder cosine distance. | |
| Args: | |
| query_text: The original query string. | |
| candidates: List of candidate result dicts from bi-encoder. | |
| Returns: | |
| List of (result_dict, cross_encoder_score) sorted by relevance desc. | |
| """ | |
| cross_enc = self._get_cross_encoder() | |
| if cross_enc is None or not candidates: | |
| # Fallback: return candidates in original order with dummy scores | |
| return [(c, 0.0) for c in candidates] | |
| pairs = [(query_text, c["text"]) for c in candidates] | |
| try: | |
| scores = cross_enc.predict(pairs) | |
| except Exception as exc: | |
| logger.error("Cross-Encoder scoring failed: %s", exc) | |
| return [(c, 0.0) for c in candidates] | |
| scored = list(zip(candidates, scores)) | |
| scored.sort(key=lambda x: x[1], reverse=True) | |
| return scored | |
| # ------------------------------------------------- stage 3: distance gate | |
| def _apply_distance_gate( | |
| self, | |
| candidates: List[Dict], | |
| distances: List[float], | |
| ) -> List[Dict]: | |
| """ | |
| Filter out candidates whose bi-encoder distance exceeds the threshold. | |
| This implements the Anti-Hallucination Distance Gate (Rule #8): | |
| if all results are too far from the query embedding, it is safer | |
| to return nothing than to hallucinate from irrelevant context. | |
| Args: | |
| candidates: List of result dicts. | |
| distances: Corresponding distances from bi-encoder. | |
| Returns: | |
| Filtered list of candidates that pass the gate. | |
| """ | |
| passed: List[Dict] = [] | |
| for cand, dist in zip(candidates, distances): | |
| if dist <= self.distance_threshold: | |
| cand["bi_encoder_distance"] = round(dist, 4) | |
| passed.append(cand) | |
| else: | |
| logger.debug( | |
| "Distance gate rejected (%.4f > %.4f): %s", | |
| dist, | |
| self.distance_threshold, | |
| cand.get("header", "?"), | |
| ) | |
| return passed | |
| # ------------------------------------------------- stage 4: token trim | |
| def _trim_to_budget(self, results: List[Dict]) -> List[Dict]: | |
| """ | |
| Trim the final result list so the total text stays within the | |
| character budget for the LLM context window. | |
| This prevents overflowing Llama 3.1 8B's context when many | |
| long guideline sections are retrieved. | |
| Args: | |
| results: Ordered list of result dicts (best first). | |
| Returns: | |
| Subset of results fitting within max_context_chars. | |
| """ | |
| trimmed: List[Dict] = [] | |
| char_count = 0 | |
| for r in results: | |
| text_len = len(r["text"]) | |
| if char_count + text_len > self.max_context_chars: | |
| # Try to include a truncated version of the next result | |
| remaining = self.max_context_chars - char_count | |
| if remaining > 200: # Only include if meaningful | |
| truncated = r.copy() | |
| truncated["text"] = r["text"][:remaining] + "… [truncated]" | |
| trimmed.append(truncated) | |
| break | |
| trimmed.append(r) | |
| char_count += text_len | |
| return trimmed | |
| # ------------------------------------------------- public: main query | |
| def query( | |
| self, | |
| query_text: str, | |
| n_results: Optional[int] = None, | |
| cancer_type_filter: Optional[str] = None, | |
| use_reranking: bool = True, | |
| ) -> List[Dict[str, str]]: | |
| """ | |
| Full SOTA retrieval pipeline. | |
| Stage 1 — Bi-Encoder: Cast a wide net (n_candidates) via PubMedBERT. | |
| Stage 2 — Distance Gate: Reject low-confidence results. | |
| Stage 3 — Cross-Encoder: Re-rank survivors for precision. | |
| Stage 4 — Token Trim: Fit within LLM context budget. | |
| Args: | |
| query_text: The natural-language clinical question. | |
| n_results: Override the default number of final results. | |
| cancer_type_filter: Optional source filename filter. | |
| use_reranking: Whether to apply cross-encoder re-ranking. | |
| Returns: | |
| A list of dicts with 'text', 'source', 'page', 'header', | |
| and optionally 'cross_encoder_score' / 'bi_encoder_distance'. | |
| """ | |
| k = n_results or self.n_results | |
| # Stage 1: Bi-Encoder wide recall | |
| candidates, distances = self._bi_encoder_retrieve( | |
| query_text, self.n_candidates, cancer_type_filter | |
| ) | |
| logger.info( | |
| "Bi-Encoder returned %d candidates for query: '%s'", | |
| len(candidates), | |
| query_text[:80], | |
| ) | |
| if not candidates: | |
| return [] | |
| # Stage 2: Distance Gate (anti-hallucination) | |
| gated = self._apply_distance_gate(candidates, distances) | |
| logger.info( | |
| "Distance gate passed: %d / %d (threshold=%.2f)", | |
| len(gated), | |
| len(candidates), | |
| self.distance_threshold, | |
| ) | |
| if not gated: | |
| logger.warning( | |
| "All candidates rejected by distance gate — " | |
| "query likely outside guideline coverage." | |
| ) | |
| return [] | |
| # Stage 3: Cross-Encoder Re-ranking | |
| if use_reranking and len(gated) > 1: | |
| scored = self._cross_encoder_rerank(query_text, gated) | |
| # Take top-k after re-ranking | |
| final = [] | |
| for cand, score in scored[:k]: | |
| cand["cross_encoder_score"] = round(float(score), 4) | |
| final.append(cand) | |
| else: | |
| final = gated[:k] | |
| # Stage 4: Token trimming for LLM context budget | |
| final = self._trim_to_budget(final) | |
| # Stage 5: SOTA Expansion (Graph + APIs) | |
| # We append these as high-priority evidence at the top | |
| sota_evidence = [] | |
| # Graph Search | |
| graph_findings = self._graph_search(query_text) | |
| sota_evidence.extend(graph_findings) | |
| # API Search | |
| api_findings = self._external_api_search(query_text) | |
| sota_evidence.extend(api_findings) | |
| # Combine: SOTA evidence comes first as it's often more specific/recent | |
| final = sota_evidence + final | |
| logger.info( | |
| "Final retrieval: %d results (%d SOTA) | (total chars: %d / %d budget)", | |
| len(final), | |
| len(sota_evidence), | |
| sum(len(r["text"]) for r in final), | |
| self.max_context_chars, | |
| ) | |
| return final | |
| # ------------------------------------------------- public: HyDE query | |
| def query_with_hyde( | |
| self, | |
| original_query: str, | |
| hypothetical_answer: str, | |
| n_results: Optional[int] = None, | |
| cancer_type_filter: Optional[str] = None, | |
| ) -> List[Dict[str, str]]: | |
| """ | |
| HyDE (Hypothetical Document Embeddings) retrieval. | |
| Instead of embedding the user's question, we embed a hypothetical | |
| answer generated by the LLM. This dramatically improves recall | |
| for medical synonym matching (e.g. "neoplasia pulmonar" vs | |
| "lung carcinoma"). | |
| The LLM generates a plausible (but unverified) answer, which is | |
| then used as the query for bi-encoder search. The Cross-Encoder | |
| then re-ranks against the ORIGINAL query for precision. | |
| Args: | |
| original_query: The actual clinical question (used for re-ranking). | |
| hypothetical_answer: LLM-generated hypothetical answer (used for embedding). | |
| n_results: Override the default number of final results. | |
| cancer_type_filter: Optional source filename filter. | |
| Returns: | |
| A list of result dicts, same format as query(). | |
| """ | |
| k = n_results or self.n_results | |
| # Stage 1: Bi-Encoder using the hypothetical answer as query | |
| candidates, distances = self._bi_encoder_retrieve( | |
| hypothetical_answer, self.n_candidates, cancer_type_filter | |
| ) | |
| if not candidates: | |
| return [] | |
| # Stage 2: Distance gate | |
| gated = self._apply_distance_gate(candidates, distances) | |
| if not gated: | |
| return [] | |
| # Stage 3: Cross-Encoder re-rank against ORIGINAL query (not HyDE) | |
| if len(gated) > 1: | |
| scored = self._cross_encoder_rerank(original_query, gated) | |
| final = [] | |
| for cand, score in scored[:k]: | |
| cand["cross_encoder_score"] = round(float(score), 4) | |
| final.append(cand) | |
| else: | |
| final = gated[:k] | |
| # Stage 4: Token trim | |
| final = self._trim_to_budget(final) | |
| # Stage 5: SOTA Expansion (Graph + APIs) | |
| # Re-ranking is against ORIGINAL query, so we do expansion here too. | |
| sota_evidence = [] | |
| graph_findings = self._graph_search(original_query) | |
| sota_evidence.extend(graph_findings) | |
| api_findings = self._external_api_search(original_query) | |
| sota_evidence.extend(api_findings) | |
| # Combine: SOTA evidence comes first | |
| final = sota_evidence + final | |
| return final | |
| # ------------------------------------------------- public: format for LLM | |
| def format_context_for_llm(self, results: List[Dict[str, str]]) -> str: | |
| """ | |
| Format retrieval results into a single string suitable for | |
| injection into an LLM prompt as grounding context. | |
| Includes confidence metadata when available. | |
| Args: | |
| results: The list of dicts returned by self.query(). | |
| Returns: | |
| A formatted multi-section string ready for LLM consumption. | |
| """ | |
| if not results: | |
| return "No relevant clinical guidelines found for this query." | |
| sections: List[str] = [] | |
| for i, r in enumerate(results, 1): | |
| header_line = ( | |
| f"[Source {i}] {r['source']} — Page {r['page']} " | |
| f"— Section: {r['header']}" | |
| ) | |
| # Add confidence metadata if present | |
| meta_parts: List[str] = [] | |
| if "cross_encoder_score" in r: | |
| meta_parts.append(f"Relevance: {r['cross_encoder_score']:.2f}") | |
| if "bi_encoder_distance" in r: | |
| meta_parts.append(f"Distance: {r['bi_encoder_distance']:.4f}") | |
| if meta_parts: | |
| header_line += f" | {' | '.join(meta_parts)}" | |
| sections.append(f"{header_line}\n{r['text']}") | |
| return "\n\n---\n\n".join(sections) | |
| # ------------------------------------------------- public: diagnostics | |
| def get_collection_stats(self) -> Dict: | |
| """ | |
| Return basic stats about the underlying ChromaDB collection. | |
| Returns: | |
| Dict with 'count', 'name', and 'db_path'. | |
| """ | |
| return { | |
| "count": self._collection.count(), | |
| "name": self._collection.name, | |
| "db_path": self.db_path, | |
| "distance_threshold": self.distance_threshold, | |
| "max_context_chars": self.max_context_chars, | |
| } | |
Xet Storage Details
- Size:
- 21.5 kB
- Xet hash:
- c5cc4f3894803cca5763abaf4694a136fe1c6e97e9ee00274c2ce1b87cd5a649
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.