"""
Cross-encoder re-ranking for improved retrieval precision.

THE DIFFERENCE BETWEEN BI-ENCODER AND CROSS-ENCODER:

Bi-encoder (what BGE does):
    embed(query) → vector_q
    embed(chunk) → vector_c
    score = cosine(vector_q, vector_c)
    
    Query and chunk are embedded INDEPENDENTLY.
    Fast (chunk vectors are pre-computed), but loses the token-level
    interaction between query and chunk.

Cross-encoder (what we use for re-ranking):
    score = model(query + [SEP] + chunk)
    
    Query and chunk are processed TOGETHER by the model.
    The model can see how query tokens relate to chunk tokens.
    Slower (scores cannot be pre-computed), but typically more accurate.

THE TWO-STAGE PATTERN:
    Stage 1 (Retrieval):   Bi-encoder -> top-20 candidates (fast, approximate)
    Stage 2 (Re-ranking):  Cross-encoder -> re-score top-20 (slow, accurate)
    
    We only run the expensive cross-encoder on 20 candidates,
    not all 15,664 chunks. This gives us accuracy without
    paying the full cost for every chunk.

MODEL: cross-encoder/ms-marco-MiniLM-L-6-v2
    - Trained on MS MARCO passage retrieval dataset (500K+ queries)
    - MiniLM architecture: fast on CPU
    - Output: raw relevance logit (unbounded, higher = more relevant)
    - Size: ~80MB
"""

import logging

# Silence sentence_transformers log output below ERROR level
logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

from sentence_transformers import CrossEncoder
from src.utils.logger import get_logger

logger = get_logger(__name__)

RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" 


class CrossEncoderReranker:
    """
    Re-ranks retrieved chunks using a cross-encoder model.

    The model is loaded lazily on first access to `model`, so constructing
    the reranker is cheap.
    """

    def __init__(self, model_name: str = RERANKER_MODEL):
        self._model      = None
        self._model_name = model_name
        logger.info(f"CrossEncoderReranker initialized: {model_name}")

    @property
    def model(self) -> CrossEncoder:
        """Lazy-load cross-encoder model.""" 
        if self._model is None:
            logger.info(f"Loading cross-encoder: {self._model_name}")
            self._model = CrossEncoder(
                self._model_name,
                # Max tokens for query+chunk combined; longer pairs are truncated
                max_length = 512
            )
            logger.info("Cross-encoder loaded")

        return self._model


    def rerank(
        self,
        query:      str,
        results:    list[dict],
        top_k:      int = 5
    ) -> list[dict]:
        """
        Re-rank a list of retrieved chunks using cross-encoder scoring.

        Args:
            query:   Original user query
            results: List of retrieved chunk dicts (from hybrid retriever)
            top_k:   How many top results to return after re-ranking

        Returns:
            Top-k results sorted by cross-encoder relevance score (descending).
            Each dict gains a "ce_score" key; the input dicts are updated in place.

        WHAT THE CROSS-ENCODER SEES:
            Input: "[CLS] how does attention work? [SEP] The transformer
                    architecture uses scaled dot-product attention where
                    queries, keys and values are computed... [SEP]"
            Output: 8.3  (high relevance; illustrative value)

            vs.

            Input: "[CLS] how does attention work? [SEP] UAV delivery
                    systems require multi-agent coordination... [SEP]"
            Output: -2.1  (low relevance; illustrative value)

        The model learned these relevance patterns from 500K+
        human-labeled query-passage pairs in MS MARCO.
        """

        if not results:
            return []

        # Build (query, chunk_text) pairs for batch scoring
        pairs = [
            (query, r.get("text", ""))
            for r in results
        ]

        # Score all pairs in one batch
        # predict() returns numpy array of relevance scores
        scores = self.model.predict(
            pairs,
            show_progress_bar = False,
            batch_size = 32,
        )

        # Attach cross_encoder score to each result
        for result, score in zip(results, scores):
            result["ce_score"] = round(float(score), 4)

        # Sort by cross-encoder score (descending)
        reranked = sorted(results, key = lambda x: x["ce_score"], reverse = True)

        logger.debug(
            f"Re-ranked {len(results)} -> top-{top_k}. "
            f"Score range: [{reranked[-1]['ce_score']:.2f}, "
            f"{reranked[0]['ce_score']:.2f}]"
        )

        return reranked[:top_k]
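

# --- Illustrative helper (not part of the original module) -------------------
# The cross-encoder returns unbounded logits, which are awkward to threshold
# or display. A minimal sketch, assuming a logistic squashing is acceptable
# for presentation: map each "ce_score" into the open interval (0, 1).
# The name `sigmoid_scores` and the "ce_prob" key are hypothetical, and the
# value is a monotone transform of the logit, not a calibrated probability.
import math


def sigmoid_scores(results: list[dict]) -> list[dict]:
    """Return copies of results with a "ce_prob" key derived from "ce_score"."""
    scored = []
    for result in results:
        # Assumption: a missing "ce_score" is treated as a neutral logit of 0.0
        logit = float(result.get("ce_score", 0.0))

        # Numerically stable logistic function: never exponentiates
        # a large positive number
        if logit >= 0:
            prob = 1.0 / (1.0 + math.exp(-logit))
        else:
            exp_logit = math.exp(logit)
            prob = exp_logit / (1.0 + exp_logit)

        # Copy the dict so the caller's results are left untouched
        scored.append({**result, "ce_prob": round(prob, 4)})

    return scored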



def diversity_filter(results: list[dict], max_per_paper: int = 2) -> list[dict]:
    """
    Ensure no single paper dominates the results.

    In test_search.py, the same paper appeared twice in the top-3
    results. This function keeps at most max_per_paper chunks from
    any single paper.

    Args:
        results:       List of result dicts (sorted by relevance)
        max_per_paper: Maximum chunks allowed from the same paper

    Returns:
        Filtered list maintaining original relevance order. Results
        without a "paper_id" key are grouped under "unknown" and share
        a single quota.

    WHY THIS MATTERS FOR USER EXPERIENCE:
        User asks: "how does attention work?"
        Without diversity filter: 3 chunks from same attention paper
        With diversity filter: 1-2 chunks each from 3 different papers

        The second response is richer: multiple perspectives,
        multiple research groups, more comprehensive coverage.
    """

    seen_papers: dict[str, int] = {}
    filtered = []

    for result in results:
        paper_id    = result.get("paper_id", "unknown")
        count       = seen_papers.get(paper_id, 0)

        if count < max_per_paper:
            filtered.append(result)
            seen_papers[paper_id] = count + 1

    
    return filtered
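

# --- Illustrative helper (not part of the original module) -------------------
# The two-stage pattern described in the module docstring can be sketched as
# one function: retrieve a candidate pool, re-score it, optionally filter it,
# then cut to top_k. This is a sketch under stated assumptions, not the
# project's actual pipeline. The name `two_stage_search` and the `retrieve`
# callable (query, k) -> list[dict] are hypothetical; `rerank` is expected to
# accept (query, results, top_k) like CrossEncoderReranker.rerank, and
# `post_filter` to accept a result list like diversity_filter.
from typing import Callable, Optional


def two_stage_search(
    query:       str,
    retrieve:    Callable[[str, int], list[dict]],
    rerank:      Callable[[str, list[dict], int], list[dict]],
    candidate_k: int = 20,
    top_k:       int = 5,
    post_filter: Optional[Callable[[list[dict]], list[dict]]] = None,
) -> list[dict]:
    """Run retrieval, re-ranking and an optional post-filter in sequence."""
    # Stage 1: cheap, approximate candidate retrieval
    candidates = retrieve(query, candidate_k)

    # Stage 2: re-score every candidate; keep the full ranking for now so
    # that a post-filter dropping items can still fill top_k slots
    ranked = rerank(query, candidates, len(candidates))

    # Optional: e.g. cap the number of chunks per paper
    if post_filter is not None:
        ranked = post_filter(ranked)

    return ranked[:top_k]


# Possible usage, assuming some `retriever.search(query, k)` exists:
#   reranker = CrossEncoderReranker()
#   hits = two_stage_search(
#       "how does attention work?",
#       retriever.search,
#       reranker.rerank,
#       post_filter = diversity_filter,
#   )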