# Commit 11f2a7a (Subhadip007): fix: f-string syntax error in reranker block for python 3.10 compatibility
"""
Cross-encoder re-ranking for improved retrieval precision.

THE DIFFERENCE BETWEEN BI-ENCODER AND CROSS-ENCODER:

Bi-encoder (what BGE does):
    embed(query) -> vector_q
    embed(chunk) -> vector_c
    score = cosine(vector_q, vector_c)
    Query and chunk are embedded INDEPENDENTLY.
    Fast (chunk vectors are pre-computed), but the query and chunk
    never interact before scoring, so that interaction signal is lost.
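
A minimal sketch of the bi-encoder scoring path, assuming toy
3-dimensional vectors in place of real BGE embeddings (the vectors,
the chunk names and the `cosine` helper are hypothetical). Chunk
vectors are computed once, ahead of time; at query time only the
query is embedded and compared with cheap dot products.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    # Cosine similarity: dot(a, b) / (|a| * |b|)
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical chunk vectors, pre-computed offline once per corpus.
chunk_vectors = {
    "attention_chunk": [0.9, 0.1, 0.0],
    "uav_chunk": [0.0, 0.2, 0.9],
}

# At query time only the query is embedded (toy vector here).
vector_q = [1.0, 0.0, 0.0]

scores = {cid: cosine(vector_q, vec) for cid, vec in chunk_vectors.items()}
best = max(scores, key=scores.get)
print(best)  # attention_chunk
```

Because vector_q is fixed before any chunk is looked at, no query
word can influence how a chunk is encoded; the two sides meet only
in the final dot product.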

Cross-encoder (what we use for re-ranking):
    score = model(query + [SEP] + chunk)
    Query and chunk are processed TOGETHER by the model.
    The model can see how query tokens relate to chunk tokens.
    Slower (every query-chunk pair needs its own forward pass, so
    nothing can be pre-computed), but typically much more accurate.
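
The joint input can be sketched as one token sequence, assuming a
hypothetical whitespace tokenizer (real models use subword tokenizers
and their own truncation strategy). `build_pair_input` and its
"truncate only the chunk" policy are illustrative assumptions;
`max_tokens` plays the role of the `max_length = 512` passed to
CrossEncoder in the model property.

```python
def build_pair_input(query: str, chunk: str, max_tokens: int = 512) -> list[str]:
    # Hypothetical whitespace "tokenizer"; real models use subword tokens.
    q_tokens = query.lower().split()
    c_tokens = chunk.lower().split()
    # Reserve room for 3 special tokens: [CLS] and one [SEP] after each segment.
    budget = max_tokens - len(q_tokens) - 3
    # Simplifying assumption: keep the query whole, truncate the chunk.
    c_tokens = c_tokens[:max(budget, 0)]
    return ["[CLS]", *q_tokens, "[SEP]", *c_tokens, "[SEP]"]


tokens = build_pair_input(
    "how does attention work?",
    "The transformer uses scaled dot-product attention",
    max_tokens=10,
)
print(tokens)
```

The sequence contains the query, so it has to be rebuilt and
re-scored for every new query; this is why cross-encoder scores
cannot be pre-computed the way chunk vectors can.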

THE TWO-STAGE PATTERN:
    Stage 1 (Retrieval):  Bi-encoder -> top-20 candidates (fast, approximate)
    Stage 2 (Re-ranking): Cross-encoder -> re-score top-20 (slow, accurate)
    We run the expensive cross-encoder on only 20 candidates, not on
    all 15,664 chunks. This gives us cross-encoder accuracy without
    paying its cost for every chunk in the corpus.
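
The two-stage pattern can be sketched in plain Python, assuming two
hypothetical scorers: word overlap stands in for the cheap bi-encoder
stage, and a call-counting wrapper stands in for the cross-encoder.
The toy corpus mirrors the 15,664-chunk size; the point is the call
count, not the scores.

```python
import heapq
from typing import Callable

Scorer = Callable[[str, str], float]


def two_stage_search(
    query: str,
    corpus: list[str],
    cheap_score: Scorer,
    expensive_score: Scorer,
    candidates: int = 20,
    top_k: int = 5,
) -> list[str]:
    # Stage 1: score every chunk cheaply and keep the best `candidates`.
    shortlist = heapq.nlargest(candidates, corpus, key=lambda c: cheap_score(query, c))
    # Stage 2: re-score only the shortlist with the expensive scorer.
    reranked = sorted(shortlist, key=lambda c: expensive_score(query, c), reverse=True)
    return reranked[:top_k]


def overlap(q: str, c: str) -> float:
    # Hypothetical cheap scorer: number of shared whitespace tokens.
    return float(len(set(q.split()) & set(c.split())))


calls = {"expensive": 0}


def counted_expensive(q: str, c: str) -> float:
    # Hypothetical stand-in for the cross-encoder that counts its own calls.
    calls["expensive"] += 1
    return overlap(q, c) - 0.001 * len(c)  # tie-break: prefer shorter chunks


corpus = [f"chunk {i} about topic {i % 7}" for i in range(15_664)]
top = two_stage_search("topic 3", corpus, overlap, counted_expensive)
print(calls["expensive"])  # 20
```

Stage 1 touches every chunk but is cheap; stage 2 runs exactly
`candidates` times (20 here), regardless of corpus size.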

MODEL: cross-encoder/ms-marco-MiniLM-L-6-v2
    - Trained on MS MARCO passage retrieval dataset (500K+ queries)
    - MiniLM architecture: fast on CPU
    - Output: relevance score (-inf to +inf, higher = more relevant)
    - Size: ~80MB
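
Because these scores are unbounded, only their order matters for
re-ranking. If a bounded 0-1 value is ever wanted (for example, for a
display threshold), one common option, assumed here rather than taken
from this module, is the logistic sigmoid; it is monotonic, so it never
changes the ranking. The sketch reuses the example scores 8.3 and -2.1
quoted in the rerank() docstring.

```python
import math


def sigmoid(logit: float) -> float:
    # Logistic function: maps (-inf, +inf) onto (0, 1) and preserves order.
    # Two branches avoid math.exp overflow for large-magnitude inputs.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# Example scores quoted in the rerank() docstring.
logits = {"attention chunk": 8.3, "UAV chunk": -2.1}
probs = {name: sigmoid(score) for name, score in logits.items()}

rank_by_logit = sorted(logits, key=logits.get, reverse=True)
rank_by_prob = sorted(probs, key=probs.get, reverse=True)
print(rank_by_logit == rank_by_prob)  # True
```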
"""
import logging

# Show only ERROR-level messages from the sentence-transformers library.
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

from sentence_transformers import CrossEncoder

from src.utils.logger import get_logger

logger = get_logger(__name__)

RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class CrossEncoderReranker:
    """
    Re-ranks retrieved chunks using a cross-encoder model.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        self._model = None
        self._model_name = model_name
        logger.info(f"CrossEncoderReranker initialized: {model_name}")

    @property
    def model(self) -> CrossEncoder:
        """Lazy-load cross-encoder model."""
        if self._model is None:
            logger.info(f"Loading cross-encoder: {self._model_name}")
            self._model = CrossEncoder(
                self._model_name,
                max_length = 512  # Max tokens for query+chunk combined
            )
            logger.info("Cross-encoder loaded")
        return self._model

    def rerank(
        self,
        query: str,
        results: list[dict],
        top_k: int = 5
    ) -> list[dict]:
        """
        Re-rank a list of retrieved chunks using cross-encoder scoring.

        Args:
            query: Original user query
            results: List of retrieved chunk dicts (from hybrid retriever).
                Modified in place: each dict gains a "ce_score" key.
            top_k: How many top results to return after re-ranking

        Returns:
            Top-k results sorted by cross-encoder relevance score,
            highest first
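
The scoring, sorting and slicing steps of this method can be exercised
without downloading a model. A minimal sketch, assuming a hypothetical
`fake_predict` in place of `self.model.predict` (it counts query words
that occur as substrings of the chunk, so it is not a relevance model);
`rerank_with` is an illustrative standalone copy of the method's logic.

```python
from typing import Callable


def rerank_with(
    score_fn: Callable[[list[tuple[str, str]]], list[float]],
    query: str,
    results: list[dict],
    top_k: int = 5,
) -> list[dict]:
    # Same steps as rerank(), with score_fn standing in for model.predict.
    if not results:
        return []
    pairs = [(query, r.get("text", "")) for r in results]
    for result, score in zip(results, score_fn(pairs)):
        result["ce_score"] = round(float(score), 4)  # mutates the input dict
    return sorted(results, key=lambda r: r["ce_score"], reverse=True)[:top_k]


def fake_predict(pairs: list[tuple[str, str]]) -> list[float]:
    # Hypothetical scorer: +1 per query word found as a substring of the chunk.
    return [sum(w in text.lower() for w in q.lower().split()) for q, text in pairs]


hits = [
    {"paper_id": "uav", "text": "UAV delivery systems require multi-agent coordination"},
    {"paper_id": "attn", "text": "Scaled dot-product attention: how queries and keys work"},
]
top = rerank_with(fake_predict, "how does attention work", hits, top_k=1)
print(top[0]["paper_id"], top[0]["ce_score"])  # attn 3.0
```

Note that the input dicts are mutated: after the call, hits[0] also
carries a ce_score of 0.0 even though it was cut by top_k.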

        WHAT THE CROSS-ENCODER SEES:
            Input:  "[CLS] how does attention work? [SEP] The transformer
                    architecture uses scaled dot-product attention where
                    queries, keys and values are computed... [SEP]"
            Output: 8.3 (high relevance)

            vs.

            Input:  "[CLS] how does attention work? [SEP] UAV delivery
                    systems require multi-agent coordination... [SEP]"
            Output: -2.1 (low relevance)

            The model learned these relevance patterns from 500K+
            human-labeled query-passage pairs in MS MARCO.
        """
        if not results:
            return []

        # Build (query, chunk_text) pairs for batch scoring
        pairs = [
            (query, r.get("text", ""))
            for r in results
        ]

        # Score all pairs with one predict() call (batches of 32 pairs);
        # predict() returns a numpy array of relevance scores
        scores = self.model.predict(
            pairs,
            show_progress_bar = False,
            batch_size = 32,
        )

        # Attach the cross-encoder score to each result (in place)
        for result, score in zip(results, scores):
            result["ce_score"] = round(float(score), 4)

        # Sort by cross-encoder score (descending)
        reranked = sorted(results, key = lambda x: x["ce_score"], reverse = True)

        logger.debug(
            f"Re-ranked {len(results)} -> top-{top_k}. "
            f"Score range: [{reranked[-1]['ce_score']:.2f}, "
            f"{reranked[0]['ce_score']:.2f}]"
        )

        return reranked[:top_k]


def diversity_filter(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    """
    Ensure no single paper dominates the results.

    In test_search.py, the same paper appeared twice in the top 3.
    This function keeps at most max_per_paper chunks from any single
    paper. Chunks without a paper_id are all counted under "unknown".

    Args:
        results: List of result dicts (sorted by relevance)
        max_per_paper: Maximum chunks allowed from the same paper

    Returns:
        Filtered list maintaining original relevance order
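
A worked example of the capping rule, using a hypothetical standalone
copy named `cap_per_paper` (same counting logic as diversity_filter)
and made-up paper ids. With max_per_paper=2, the third chunk from the
same paper is dropped and everything else keeps its relative order.

```python
def cap_per_paper(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    # Keep a chunk only while its paper is still under the cap.
    seen: dict[str, int] = {}
    kept = []
    for r in results:
        pid = r.get("paper_id", "unknown")
        if seen.get(pid, 0) < max_per_paper:
            kept.append(r)
            seen[pid] = seen.get(pid, 0) + 1
    return kept


ranked = [
    {"paper_id": "attention-paper", "rank": 1},
    {"paper_id": "attention-paper", "rank": 2},
    {"paper_id": "attention-paper", "rank": 3},
    {"paper_id": "bert-paper", "rank": 4},
    {"rank": 5},  # no paper_id -> counted under "unknown"
]
kept = cap_per_paper(ranked, max_per_paper=2)
print([r["rank"] for r in kept])  # [1, 2, 4, 5]
```

Since every chunk without a paper_id shares the "unknown" bucket,
such chunks are capped together as if they came from one paper.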

    WHY THIS MATTERS FOR USER EXPERIENCE:
        User asks: "how does attention work?"
        Without diversity filter: 3 chunks from the same attention paper
        With diversity filter: at most max_per_paper chunks from that
        paper, so lower-ranked chunks from other papers move up
        The second response is richer: multiple perspectives,
        multiple research groups, more comprehensive coverage.
    """
    seen_papers: dict[str, int] = {}
    filtered = []

    for result in results:
        paper_id = result.get("paper_id", "unknown")
        count = seen_papers.get(paper_id, 0)
        if count < max_per_paper:
            filtered.append(result)
            seen_papers[paper_id] = count + 1

    return filtered