"""
Cross-encoder re-ranking for improved retrieval precision.

THE DIFFERENCE BETWEEN BI-ENCODER AND CROSS-ENCODER:

Bi-encoder (what BGE does):
    embed(query) → vector_q
    embed(chunk) → vector_c
    score = cosine(vector_q, vector_c)
    Query and chunk are embedded INDEPENDENTLY.
    Fast (chunk vectors can be pre-computed once), but the model never
    sees query and chunk together, so it loses their interaction signal.

Cross-encoder (what we use for re-ranking):
    score = model("[CLS] query [SEP] chunk [SEP]")
    Query and chunk are processed TOGETHER by the model.
    The model can see how query tokens relate to chunk tokens.
    Slower (nothing can be pre-computed: one forward pass per
    query-chunk pair), but typically more accurate.
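
A toy sketch of that difference, assuming hypothetical stand-ins for both models: toy_embed() hashes words into a small count vector in place of BGE, and toy_joint_score() counts query words that occur in the chunk in place of the learned cross-encoder. Neither is a real model; the sketch only shows where caching is and is not possible.

```python
import math
import zlib
from collections import Counter

DIM = 256  # hypothetical size of the toy embedding


def toy_embed(text: str) -> list[float]:
    # Toy bi-encoder: encodes ONE text on its own into a fixed-size vector
    # by hashing each lower-cased word into a bucket.
    vec = [0.0] * DIM
    for word in text.lower().split():
        vec[zlib.crc32(word.encode("utf-8")) % DIM] += 1.0
    return vec


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def toy_joint_score(query: str, chunk: str) -> float:
    # Toy cross-encoder: needs BOTH texts at once, so no per-chunk result
    # can be computed ahead of time.
    chunk_counts = Counter(chunk.lower().split())
    return float(sum(chunk_counts[w] for w in query.lower().split()))


chunks = [
    "the transformer uses scaled dot-product attention",
    "uav delivery systems require multi-agent coordination",
]
# Bi-encoder side: chunk vectors are computed once, offline, and reused for every query.
chunk_vectors = [toy_embed(c) for c in chunks]


def bi_encoder_scores(query: str) -> list[float]:
    q_vec = toy_embed(query)  # one encoding per query
    return [cosine(q_vec, c_vec) for c_vec in chunk_vectors]


def joint_scores(query: str) -> list[float]:
    # One scorer call per (query, chunk) pair, repeated for every new query.
    return [toy_joint_score(query, c) for c in chunks]


print(joint_scores("how does attention work"))  # [1.0, 0.0]
```
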

THE TWO-STAGE PATTERN:
    Stage 1 (Retrieval):  Bi-encoder    -> top-20 candidates (fast, approximate)
    Stage 2 (Re-ranking): Cross-encoder -> re-score those 20 (slow, accurate)
    We run the expensive cross-encoder on only 20 candidates per query,
    not on all 15,664 chunks. This buys cross-encoder precision on the
    shortlist while paying its cost for 20 pairs instead of 15,664.
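
A minimal sketch of the two-stage pattern, assuming both scorers are passed in as plain functions. two_stage_search, cheap_score and expensive_score are hypothetical names: in this module they would correspond to the first-stage retriever and to CrossEncoder.predict(), and the toy scorers below exist only to count how often the expensive one runs.

```python
import heapq
from typing import Callable

Scorer = Callable[[str, str], float]


def two_stage_search(
    query: str,
    corpus: list[str],
    cheap_score: Scorer,
    expensive_score: Scorer,
    n_candidates: int = 20,
    top_k: int = 5,
) -> list[tuple[str, float]]:
    # Stage 1: cheap score for every chunk; keep the n_candidates best.
    candidates = heapq.nlargest(n_candidates, corpus, key=lambda c: cheap_score(query, c))
    # Stage 2: expensive re-score of the shortlist only.
    rescored = [(c, expensive_score(query, c)) for c in candidates]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    return rescored[:top_k]


calls = {"expensive": 0}


def cheap(query: str, chunk: str) -> float:
    # Toy first stage: number of shared words.
    return float(len(set(query.split()) & set(chunk.split())))


def expensive(query: str, chunk: str) -> float:
    # Toy second stage: shared words, slightly preferring shorter chunks.
    calls["expensive"] += 1
    return cheap(query, chunk) - 0.01 * len(chunk)


corpus = [f"chunk {i} about attention" if i % 7 == 0 else f"chunk {i}" for i in range(15_664)]
top = two_stage_search("attention", corpus, cheap, expensive)
print(calls["expensive"])  # 20 expensive calls instead of 15,664
print(top[0][0])           # chunk 0 about attention
```

Raising n_candidates trades latency for recall: a relevant chunk that stage 1 ranks below the cutoff can never be recovered by stage 2.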

MODEL: cross-encoder/ms-marco-MiniLM-L-6-v2
    - Trained on the MS MARCO passage-ranking dataset (500K+ queries)
    - 6-layer MiniLM architecture: small enough to run fast on CPU
    - Output: unbounded relevance score (any real number; higher = more relevant)
    - Size: ~80MB
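
Because the score is unbounded, a bounded value can be handier, e.g. for a fixed relevance cutoff. A minimal sketch using the logistic sigmoid; this conversion is our assumption, not something this module does. Sigmoid is strictly increasing, so it never changes the re-ranked order.

```python
import math


def to_unit_interval(score: float) -> float:
    # Numerically stable logistic sigmoid: maps any real score into (0, 1);
    # float rounding can give exactly 0.0 or 1.0 for extreme inputs.
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    z = math.exp(score)  # safe: score < 0, so z is in (0, 1)
    return z / (1.0 + z)


raw = [8.3, -2.1]  # the example scores quoted in rerank()'s docstring
print([round(to_unit_interval(s), 4) for s in raw])  # [0.9998, 0.1091]
```
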
"""
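import logging

# Silence sentence_transformers' INFO/WARNING output. Set before the import
# below so messages emitted while the library loads are suppressed too.
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

from sentence_transformers import CrossEncoder

from src.utils.logger import get_logger

logger = get_logger(__name__)

RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"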


class CrossEncoderReranker:
    """
    Re-ranks retrieved chunks using a cross-encoder model.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        self._model = None
        self._model_name = model_name
        logger.info(f"CrossEncoderReranker initialized: {model_name}")

    @property
    def model(self) -> CrossEncoder:
        """Lazy-load cross-encoder model."""
        if self._model is None:
            logger.info(f"Loading cross-encoder: {self._model_name}")
            self._model = CrossEncoder(
                self._model_name,
                max_length=512,  # Max tokens for query+chunk combined
            )
            logger.info("Cross-encoder loaded")
        return self._model

    def rerank(
        self,
        query: str,
        results: list[dict],
        top_k: int = 5,
    ) -> list[dict]:
        """
        Re-rank a list of retrieved chunks using cross-encoder scoring.

        Args:
            query: Original user query
            results: List of retrieved chunk dicts (from the hybrid retriever).
                Each dict's "text" value is scored; a "ce_score" key is
                added to every dict in place.
            top_k: How many top results to return after re-ranking

        Returns:
            Top-k results sorted by cross-encoder relevance score (descending)

        WHAT THE CROSS-ENCODER SEES:
            Input:  "[CLS] how does attention work? [SEP] The transformer
                     architecture uses scaled dot-product attention where
                     queries, keys and values are computed... [SEP]"
            Output: 8.3 (high relevance)

            vs.

            Input:  "[CLS] how does attention work? [SEP] UAV delivery
                     systems require multi-agent coordination... [SEP]"
            Output: -2.1 (low relevance)
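
A toy sketch of how one pair is laid out and cut to the max_length=512 budget passed to CrossEncoder in the model property. Assumptions: whitespace-split words stand in for the model's real WordPiece tokens, and truncation follows a longest_first-style rule (repeatedly trim the longer segment), which is our understanding of what CrossEncoder asks the Hugging Face tokenizer to do. build_pair_tokens() is a hypothetical helper, not part of sentence-transformers.

```python
def build_pair_tokens(query: str, chunk: str, max_length: int = 512) -> list[str]:
    # Lay out one pair as [CLS] query [SEP] chunk [SEP], truncated to max_length.
    # Whitespace "tokens" are a stand-in for real WordPiece tokens.
    if max_length < 3:
        raise ValueError("max_length must leave room for [CLS] and two [SEP]")
    q, c = query.split(), chunk.split()
    budget = max_length - 3  # reserve space for the three special tokens
    # longest_first-style: drop the last token of the longer segment until it
    # fits (ties trim the chunk), so a short query normally survives intact.
    while len(q) + len(c) > budget:
        if len(c) >= len(q):
            c.pop()
        else:
            q.pop()
    return ["[CLS]", *q, "[SEP]", *c, "[SEP]"]


tokens = build_pair_tokens("how does attention work?", "word " * 1000)
print(len(tokens))  # 512
print(tokens[:6])   # ['[CLS]', 'how', 'does', 'attention', 'work?', '[SEP]']
```

With a short query, it is the tail of a long chunk that gets cut, so chunk text beyond the budget never influences the score.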

        The model learned these relevance patterns from 500K+
        human-labeled query-passage pairs in MS MARCO.
        """
        if not results:
            return []

        # Build (query, chunk_text) pairs for batch scoring
        pairs = [(query, r.get("text", "")) for r in results]

        # Score all pairs in one batch
        # predict() returns a numpy array of relevance scores
        scores = self.model.predict(
            pairs,
            show_progress_bar=False,
            batch_size=32,
        )

        # Attach cross-encoder score to each result
        for result, score in zip(results, scores):
            result["ce_score"] = round(float(score), 4)

        # Sort by cross-encoder score (descending)
        reranked = sorted(results, key=lambda x: x["ce_score"], reverse=True)

        logger.debug(
            f"Re-ranked {len(results)} -> top-{top_k}. "
            f"Score range: [{reranked[-1]['ce_score']:.2f}, "
            f"{reranked[0]['ce_score']:.2f}]"
        )
        return reranked[:top_k]


def diversity_filter(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    """
    Ensure no single paper dominates the results.

    In test_search.py, the same paper appeared twice in the top 3.
    This function keeps at most max_per_paper chunks from any single paper.

    Args:
        results: List of result dicts (sorted by relevance)
        max_per_paper: Maximum chunks allowed from the same paper

    Returns:
        Filtered list maintaining the original relevance order. Results
        without a "paper_id" key share one "unknown" bucket, so they are
        capped together.
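
A minimal sketch of one consequence of this design, on hypothetical data: because the filter only drops items, applying it to the already-truncated output of rerank() can leave fewer than top_k chunks. Over-fetching (a larger top_k for rerank()), filtering, and only then truncating avoids that; this calling pattern is our suggestion, not something the module enforces. cap_per_paper() is a hypothetical restatement of this function's rule so the example runs on its own.

```python
def cap_per_paper(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    # Same rule as diversity_filter(): keep at most max_per_paper results per
    # paper_id, preserving the incoming relevance order.
    seen: dict[str, int] = {}
    kept = []
    for r in results:
        pid = r.get("paper_id", "unknown")
        if seen.get(pid, 0) < max_per_paper:
            kept.append(r)
            seen[pid] = seen.get(pid, 0) + 1
    return kept


# Hypothetical re-ranked output, best first.
ranked = [
    {"paper_id": "A", "ce_score": 9.1},
    {"paper_id": "A", "ce_score": 8.7},
    {"paper_id": "A", "ce_score": 8.2},
    {"paper_id": "B", "ce_score": 7.5},
    {"paper_id": "C", "ce_score": 6.0},
]
top_k = 3

# Truncate first, then filter: the third paper-A chunk is dropped and nothing replaces it.
filter_after = cap_per_paper(ranked[:top_k])
# Over-fetch, filter, then truncate: the freed slot goes to the next paper instead.
filter_before = cap_per_paper(ranked)[:top_k]

print([r["paper_id"] for r in filter_after])   # ['A', 'A']
print([r["paper_id"] for r in filter_before])  # ['A', 'A', 'B']
```
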

    WHY THIS MATTERS FOR USER EXPERIENCE:
        User asks: "how does attention work?"
        Without diversity filter: 3 chunks from the same attention paper
        With diversity filter: 1-2 chunks each from 3 different papers

        The filtered response is richer: multiple perspectives from
        multiple research groups and more comprehensive coverage.
    """
    seen_papers: dict[str, int] = {}
    filtered = []

    for result in results:
        paper_id = result.get("paper_id", "unknown")
        count = seen_papers.get(paper_id, 0)
        if count < max_per_paper:
            filtered.append(result)
            seen_papers[paper_id] = count + 1

    return filtered