"""
Cross-encoder re-ranking for improved retrieval precision.

THE DIFFERENCE BETWEEN BI-ENCODER AND CROSS-ENCODER:

Bi-encoder (what BGE does):
    embed(query) → vector_q
    embed(chunk) → vector_c
    score = cosine(vector_q, vector_c)
    Query and chunk are embedded INDEPENDENTLY.
    Fast (chunk vectors are pre-computed), but loses the interaction signal
    between query tokens and chunk tokens.
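
A minimal bi-encoder scoring sketch in NumPy, assuming the query and chunk
vectors were already produced by an embedding model (the 3-dimensional
vectors below are made-up toy values; `cosine_scores` is a hypothetical
helper, not part of this module). Because the chunk matrix does not depend
on the query, it can be computed once and stored; only the query is
embedded at search time.

```python
import numpy as np


def cosine_scores(query_vec: np.ndarray, chunk_matrix: np.ndarray) -> np.ndarray:
    # Normalise to unit length so the dot product equals cosine similarity.
    q = query_vec / np.linalg.norm(query_vec)
    c = chunk_matrix / np.linalg.norm(chunk_matrix, axis=1, keepdims=True)
    return c @ q


# Pre-computed chunk vectors (toy values), one row per chunk.
chunk_matrix = np.array([
    [1.0, 0.0, 0.0],
    [0.6, 0.8, 0.0],
    [0.0, 0.0, 1.0],
])
query_vec = np.array([2.0, 0.0, 0.0])  # embedded at query time

scores = cosine_scores(query_vec, chunk_matrix)  # ≈ [1.0, 0.6, 0.0]
ranking = np.argsort(-scores)                    # chunk indices, best first: [0, 1, 2]
```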

Cross-encoder (what we use for re-ranking):
    score = model(query + [SEP] + chunk)
    Query and chunk are processed TOGETHER by the model.
    The model can see how query tokens relate to chunk tokens.
    Slower (nothing can be pre-computed: every query-chunk pair needs its
    own forward pass), but typically much more accurate.

THE TWO-STAGE PATTERN:
    Stage 1 (Retrieval):  Bi-encoder -> top-20 candidates (fast, approximate)
    Stage 2 (Re-ranking): Cross-encoder -> re-score top-20 (slow, accurate)

We only run the expensive cross-encoder on the 20 candidates, not on
all 15,664 chunks. This gives us cross-encoder accuracy without paying
its full cost for every chunk.
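
A minimal sketch of the two-stage pattern, assuming unit-normalised
pre-computed chunk vectors and a caller-supplied `pair_scorer(query, text)`
standing in for the cross-encoder. `two_stage_search`, `pair_scorer` and
`toy_pair_scorer` are hypothetical names; the word-overlap toy scorer only
illustrates that stage 2 reads the query and the chunk together. It is not
how the MiniLM model computes relevance.

```python
from typing import Callable

import numpy as np


def two_stage_search(
    query: str,
    query_vec: np.ndarray,
    chunk_vecs: np.ndarray,  # shape (n_chunks, dim), rows unit-normalised
    chunk_texts: list[str],
    pair_scorer: Callable[[str, str], float],  # stand-in for the cross-encoder
    n_candidates: int = 20,
    top_k: int = 5,
) -> list[int]:
    # Stage 1: cheap vector scores for every chunk; keep the best n_candidates.
    q = query_vec / np.linalg.norm(query_vec)
    candidates = np.argsort(-(chunk_vecs @ q))[:n_candidates]
    # Stage 2: expensive joint scoring, run only on the shortlisted candidates.
    rescored = sorted(
        candidates,
        key=lambda i: pair_scorer(query, chunk_texts[i]),
        reverse=True,
    )
    return [int(i) for i in rescored[:top_k]]


def toy_pair_scorer(query: str, text: str) -> float:
    # Toy interaction signal: share of query words that also occur in the text.
    q_words = set(query.lower().split())
    return len(q_words & set(text.lower().split())) / max(len(q_words), 1)


texts = [
    "attention uses queries keys and values",
    "uav delivery needs coordination",
    "self attention work explained",
]
vecs = np.array([[1.0, 0.0], [0.0, 1.0], [0.8, 0.6]])  # toy embeddings
best = two_stage_search(
    "how does attention work", np.array([1.0, 0.0]), vecs, texts,
    toy_pair_scorer, n_candidates=2, top_k=2,
)
# best == [2, 0]: stage 1 shortlists chunks 0 and 2, stage 2 reorders them;
# chunk 1 is never passed to the expensive scorer.
```

With 15,664 chunks and n_candidates=20, stage 2 would make 20 scorer
calls per query instead of 15,664.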

MODEL: cross-encoder/ms-marco-MiniLM-L-6-v2
    - Trained on the MS MARCO passage retrieval dataset (500K+ queries)
    - MiniLM architecture (6 layers): fast on CPU
    - Output: unbounded raw relevance score (-inf to +inf, higher = more relevant)
    - Size: ~80MB
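
If a bounded score is needed (for example, to apply a fixed cut-off),
one option is to pass the raw output through a sigmoid. A minimal sketch
with made-up scores (`sigmoid` is a hypothetical helper, not part of this
module): the sigmoid is strictly increasing, so it maps scores into (0, 1)
without changing their order, and re-ranking by either value gives the
same result. Any particular threshold would still need checking on real
queries.

```python
import math


def sigmoid(x: float) -> float:
    # Strictly increasing map from (-inf, +inf) to (0, 1).
    return 1.0 / (1.0 + math.exp(-x))


raw_scores = [8.3, -2.1, 0.0]  # made-up raw cross-encoder outputs
probs = [sigmoid(s) for s in raw_scores]

order_raw = sorted(range(3), key=lambda i: raw_scores[i], reverse=True)
order_prob = sorted(range(3), key=lambda i: probs[i], reverse=True)
# Both orders are [0, 2, 1]; sigmoid(0.0) is exactly 0.5.
```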
"""
import logging

# Only log ERROR and above from sentence_transformers (set before importing it).
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

from sentence_transformers import CrossEncoder

from src.utils.logger import get_logger

logger = get_logger(__name__)

RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class CrossEncoderReranker:
    """
    Re-ranks retrieved chunks using a cross-encoder model.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        self._model = None  # loaded on first access of self.model
        self._model_name = model_name
        logger.info(f"CrossEncoderReranker initialized: {model_name}")

    @property
    def model(self) -> CrossEncoder:
        """Lazy-load the cross-encoder model on first access."""
        if self._model is None:
            logger.info(f"Loading cross-encoder: {self._model_name}")
            self._model = CrossEncoder(
                self._model_name,
                max_length=512,  # max tokens for query + chunk combined
            )
            logger.info("Cross-encoder loaded")
        return self._model

    def rerank(
        self,
        query: str,
        results: list[dict],
        top_k: int = 5,
    ) -> list[dict]:
        """
        Re-rank a list of retrieved chunks using cross-encoder scoring.

        Args:
            query: Original user query
            results: List of retrieved chunk dicts (from the hybrid retriever)
            top_k: How many top results to return after re-ranking

        Returns:
            Top-k results sorted by cross-encoder relevance score
            (descending), each with a "ce_score" key added

        WHAT THE CROSS-ENCODER SEES:
            Input:  "[CLS] how does attention work? [SEP] The transformer
                    architecture uses scaled dot-product attention where
                    queries, keys and values are computed... [SEP]"
            Output: 8.3 (high relevance)

            vs.

            Input:  "[CLS] how does attention work? [SEP] UAV delivery
                    systems require multi-agent coordination... [SEP]"
            Output: -2.1 (low relevance)

        The model learned these relevance patterns from 500K+
        human-labeled query-passage pairs in MS MARCO.
        """
        if not results:
            return []

        # Build (query, chunk_text) pairs for batch scoring
        pairs = [(query, r.get("text", "")) for r in results]

        # Score all pairs in batches of 32; predict() returns a numpy
        # array with one relevance score per pair.
        scores = self.model.predict(
            pairs,
            batch_size=32,
            show_progress_bar=False,
        )

        # Attach the cross-encoder score to each result
        # (this modifies the caller's dicts in place).
        for result, score in zip(results, scores):
            result["ce_score"] = round(float(score), 4)

        # Sort by cross-encoder score (descending)
        reranked = sorted(results, key=lambda x: x["ce_score"], reverse=True)

        logger.debug(
            f"Re-ranked {len(results)} -> top-{top_k}. "
            f"Score range: [{reranked[-1]['ce_score']:.2f}, "
            f"{reranked[0]['ce_score']:.2f}]"
        )
        return reranked[:top_k]


def diversity_filter(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    """
    Ensure no single paper dominates the results.

    As seen in test_search.py, the same paper appeared twice in the
    top 3. This function keeps at most max_per_paper chunks from any
    single paper.

    Args:
        results: List of result dicts (sorted by relevance)
        max_per_paper: Maximum chunks allowed from the same paper

    Returns:
        Filtered list maintaining the original relevance order

    WHY THIS MATTERS FOR USER EXPERIENCE:
        User asks: "how does attention work?"
        Without diversity filter: 3 chunks from the same attention paper
        With diversity filter:    at most max_per_paper chunks from that
                                  paper, leaving room for chunks from
                                  other papers
        The second response is richer: multiple perspectives, multiple
        research groups, more comprehensive coverage.
    """
    seen_papers: dict[str, int] = {}
    filtered = []
    for result in results:
        # Results without a paper_id all share the "unknown" bucket.
        paper_id = result.get("paper_id", "unknown")
        count = seen_papers.get(paper_id, 0)
        if count < max_per_paper:
            filtered.append(result)
            seen_papers[paper_id] = count + 1
    return filtered