Sentence Similarity
sentence-transformers
Safetensors
Hebrew
hebrew
semantic-retrieval
information-retrieval
dense-retrieval
reranking
rrf
competition
Instructions to use HebArabNlpProject/Semantic-Retrieval-2nd-place with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use HebArabNlpProject/Semantic-Retrieval-2nd-place with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("HebArabNlpProject/Semantic-Retrieval-2nd-place")

sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium."
]
embeddings = model.encode(sentences)

similarities = model.similarity(embeddings, embeddings)
print(similarities.shape)
# [3, 3]
```
- Notebooks
- Google Colab
- Kaggle
| import os, re, math, unicodedata, time, json, hashlib, importlib.util | |
| from collections import defaultdict, Counter | |
| from typing import List, Tuple, Dict, Optional | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification | |
| from sentence_transformers import SentenceTransformer | |
| import sys, pathlib | |
# Absolute directory of this file. It is prepended to sys.path (once) so that
# sibling helper modules can be imported by bare name.
HERE = pathlib.Path(__file__).resolve().parent
if str(HERE) not in sys.path:
    sys.path[:0] = [str(HERE)]
# ======================= Tunables =======================
# BM25 (Okapi) term-frequency saturation (k1) and length normalization (b)
BM25_K1 = 1.3
BM25_B = 0.7
# RRF constant: a list contributes weight / (RRF_K + rank + 1) per document
RRF_K = 35

# Cross-encoder (reranker) max tokens per (query, passage) pair and batch size
CE_MAXLEN = 640
CE_BATCH = 128

# Depth of each stage-1 list, and size of the candidate pool handed to the reranker
TOP_BM25 = 190
TOP_E5 = 190
TOP_GEMMA = 190
CE_POOL = 190

# Weighted RRF stage-1 fusion (BM25 + E5 + Gemma)
WRRF_BM25_W = 1.0
WRRF_E5_W = 1.2
WRRF_GEMMA_W = 1.4

# Weighting for the final (reranker) fusion; the RRF boost is scaled by 1 - this value
FINAL_SCORE_BGE_WEIGHT = 0.07

# Model & cache dirs
USE_CACHE = True
BGE_DIR = "models/bge-reranker-hsrc-pairwise-rrf-V1.4"
E5_DIR = "models/e5-large-ft_v6"
# Corpus-embedding cache directories: "" disables caching.
# Put a directory path in the first "" to enable caching while USE_CACHE is True.
E5_EVAL_CACHE_DIR = "" if USE_CACHE else ""
# NOTE(review): despite the GEMMA_* naming this points at an E5 checkpoint
# directory, not EmbeddingGemma — confirm this is intended.
GEMMA_DIR = "models/multilingual-e5-large"
GEMMA_EVAL_CACHE_DIR = "" if USE_CACHE else ""

# Gemma dtype preference & max length (SentenceTransformers truncation)
PREFER_BF16_GEMMA = True
GEMMA_MAX_TOK = 512
# ======================= Silence under eval_std ============================
# True when the environment variable EVAL_STD_MODE is exactly "1" (after stripping).
_EVAL_SILENT = os.environ.get("EVAL_STD_MODE", "").strip() == "1"


def _log(msg: str):
    """Print *msg* with an immediate flush, unless running in silent eval mode."""
    if _EVAL_SILENT:
        return
    print(msg, flush=True)
# ======================= Normalization / Tokenization =======================
# Resolution order for the project's text_utils module:
#   1) relative import, 2) sys.path import, 3) load HERE/text_utils.py by file path,
#   4) built-in generic Hebrew-friendly normalizers defined below.
try:
    from .text_utils import (  # type: ignore
        tok_he, norm_bm25,
        norm_e5_query, norm_e5_passage,
        norm_gemma_query, norm_gemma_passage,
        norm_bge_query, norm_bge_passage
    )
    _log("[Init] Loaded text_utils (relative import).")
except (ImportError, ModuleNotFoundError):
    try:
        from text_utils import (
            tok_he, norm_bm25,
            norm_e5_query, norm_e5_passage,
            norm_gemma_query, norm_gemma_passage,
            norm_bge_query, norm_bge_passage
        )
        _log("[Init] Loaded text_utils (sys.path import).")
    except (ImportError, ModuleNotFoundError):
        try:
            spec_path = HERE / "text_utils.py"
            if not spec_path.is_file():
                raise FileNotFoundError(f"{spec_path} not found.")
            spec = importlib.util.spec_from_file_location("text_utils", spec_path)
            text_utils_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(text_utils_module)
            tok_he = text_utils_module.tok_he
            norm_bm25 = text_utils_module.norm_bm25
            norm_e5_query = text_utils_module.norm_e5_query
            norm_e5_passage = text_utils_module.norm_e5_passage
            # Gemma-specific normalizers if present; fall back to the E5 ones
            norm_gemma_query = getattr(text_utils_module, "norm_gemma_query", text_utils_module.norm_e5_query)
            norm_gemma_passage = getattr(text_utils_module, "norm_gemma_passage", text_utils_module.norm_e5_passage)
            norm_bge_query = text_utils_module.norm_bge_query
            norm_bge_passage = text_utils_module.norm_bge_passage
            _log("[Init] Loaded text_utils (dynamic import).")
        except Exception as text_utils_err:
            # A present-but-broken text_utils.py (syntax error, missing function) also
            # lands here, so report the actual cause instead of assuming "not found".
            _log(f"[Init] `text_utils` unavailable ({text_utils_err}). "
                 "Using generic Hebrew-friendly normalizers for all components.")
            # Single-letter Hebrew prefixes (vav, he, bet, lamed, kaf, mem, shin)
            HEB_PREFIXES = ("ו", "ה", "ב", "ל", "כ", "מ", "ש")
            # Tokens dropped from BM25 term lists; empty by default.
            STOPWORDS = frozenset()

            def _generic_norm_he(s: str) -> str:
                """
                Normalize Hebrew text.

                Applies NFKC, strips Hebrew points/accents in U+0591-U+05C7 (keeping
                maqaf U+05BE), maps gershayim/geresh and curly double quotes to ASCII
                quotes and en/em dashes to '-', then collapses whitespace.
                Empty or None input returns "".
                """
                if not s:
                    return ""
                s = unicodedata.normalize("NFKC", s)
                s = re.sub(r"[\u0591-\u05BD\u05BF-\u05C7]", "", s)
                s = (s.replace("״", '"').replace("׳", "'")
                     .replace("”", '"').replace("“", '"')
                     .replace("–", "-").replace("—", "-"))
                return re.sub(r"\s+", " ", s).strip()

            norm_bm25 = norm_e5_query = norm_e5_passage = _generic_norm_he
            norm_gemma_query = norm_gemma_passage = _generic_norm_he
            norm_bge_query = norm_bge_passage = _generic_norm_he

            def tok_he(text: str) -> List[str]:
                """
                Tokenize for BM25.

                Extracts runs of ASCII letters/digits and Hebrew-block characters from the
                normalized text. For a token longer than 3 characters that starts with a
                Hebrew prefix letter, the prefix-stripped form is emitted before the
                original token. Tokens in STOPWORDS are removed.
                """
                s = norm_bm25(text)
                toks = re.findall(r"[A-Za-z0-9\u0590-\u05FF]+", s)
                out = []
                for t in toks:
                    if len(t) > 3 and t[0] in HEB_PREFIXES:
                        out.append(t[1:])
                    out.append(t)
                return [t for t in out if t not in STOPWORDS]
# =========================== BM25 Backends ================================
# Optional external BM25 factory `get_bm25_backend` from bm25_backends.py.
# Resolution order: relative import -> sys.path import -> load HERE/bm25_backends.py
# by file path. `_HAS_BM25_BACKENDS` records whether any route succeeded;
# BM25Index.build consults it and otherwise uses the local backends below.
get_bm25_backend = None
_HAS_BM25_BACKENDS = False
try:
    from .bm25_backends import get_bm25_backend
    _HAS_BM25_BACKENDS = True
    _log("[Init] Loaded bm25_backends (relative import).")
except (ImportError, ModuleNotFoundError):
    try:
        from bm25_backends import get_bm25_backend
        _HAS_BM25_BACKENDS = True
        _log("[Init] Loaded bm25_backends (sys.path import).")
    except (ImportError, ModuleNotFoundError):
        try:
            spec_path = HERE / "bm25_backends.py"
            if not spec_path.is_file():
                raise FileNotFoundError(f"{spec_path} not found.")
            spec = importlib.util.spec_from_file_location("bm25_backends", spec_path)
            bm25_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(bm25_module)
            get_bm25_backend = bm25_module.get_bm25_backend
            _HAS_BM25_BACKENDS = True
            _log("[Init] Loaded bm25_backends (dynamic import).")
        except Exception as e:
            # Best effort: any failure (missing file, error while executing the module,
            # missing get_bm25_backend) is logged and the flag stays False.
            _log(f"[Init] Could not load bm25_backends.py ({e}). Will use built-in fallbacks.")
            pass
class _LocalBM25SBackend:
    """
    Minimal local wrapper around the ``bm25s`` package, used when bm25_backends.py is missing.

    ``bm25s`` is imported in ``__init__``, so constructing this class raises
    ImportError when the package is not installed.
    """
    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7, logger=_log):
        import bm25s
        self._BM25 = bm25s.BM25
        self.tokenizer = tokenizer          # callable: text -> list of terms
        self.k1, self.b = k1, b
        self._bm25 = None                   # built bm25s index, None until build()
        self.doc_ids: List[str] = []        # position i -> external doc id
        self._logger = logger

    def name(self) -> str:
        return f"LocalBM25S(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        """Tokenize *texts* with ``self.tokenizer`` and index them; *ids* are aligned with *texts*."""
        t0 = time.time()
        self.doc_ids = list(ids)
        tokenized = [self.tokenizer(t) for t in texts]
        self._bm25 = self._BM25(k1=self.k1, b=self.b)
        self._bm25.index(tokenized)
        if self._logger:
            # `name` is a method: call it so the log shows the label, not a bound-method repr.
            self._logger(f"[{self.name()}] Indexed {len(self.doc_ids):,} docs in {time.time()-t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        """
        Return up to *topk* doc ids, best first.

        Only hits with a finite, positive score are kept. Equal scores are ordered
        by ascending corpus position, so results are deterministic.
        """
        terms = self.tokenizer(query)
        if not terms or self._bm25 is None:
            return []
        k = min(topk, len(self.doc_ids))
        if k <= 0:  # empty corpus or non-positive topk: nothing to retrieve
            return []
        idxs, scores = self._bm25.retrieve([terms], k=k)
        idxs, scores = np.asarray(idxs[0]), np.asarray(scores[0])
        mask = np.isfinite(scores) & (scores > 0)
        idxs, scores = idxs[mask], scores[mask]
        if idxs.size == 0:
            return []
        # Primary key: descending score; secondary key: ascending position.
        order = np.lexsort((idxs, -scores))
        return [self.doc_ids[int(i)] for i in idxs[order]]
class _DeterministicBM25Backend:
    """
    Embedded NumPy/pure-Python Okapi BM25, the guaranteed last-resort backend.

    Uses idf(t) = ln((N - df + 0.5) / (df + 0.5) + 1). Rankings are deterministic:
    equal scores are ordered by ascending corpus position.
    """
    def __init__(self, tokenizer, k1: float = 1.3, b: float = 0.7, logger=_log):
        self.tokenizer = tokenizer
        self.k1 = k1
        self.b = b
        self.doc_ids: List[str] = []
        self.N = 0
        self.avgdl = 0.0
        self.doc_lens = None                # int32 term count per document
        self.vocab: Dict[str, int] = {}     # term -> term id
        # term id -> (doc positions as int32, term frequencies as float32)
        self.postings: Dict[int, Tuple[np.ndarray, np.ndarray]] = {}
        self.idf = None                     # float32 idf per term id
        self._logger = logger

    def name(self) -> str:
        return f"DeterministicBM25(k1={self.k1}, b={self.b})"

    def build(self, ids: List[str], texts: List[str]):
        """Index *texts* (aligned with *ids*). Calling build again replaces the previous index."""
        t0 = time.time()
        self.doc_ids = list(ids)
        self.N = len(self.doc_ids)
        # Reset the vocabulary. Otherwise a rebuild leaves stale term ids that have no
        # postings, and search() then raises KeyError on them.
        self.vocab = {}
        lens = np.zeros(self.N, dtype=np.int32)
        tmp = defaultdict(list)
        for i, text in enumerate(texts):
            terms = self.tokenizer(text)
            lens[i] = len(terms)
            if not terms:
                continue
            for t, tf in Counter(terms).items():
                tid = self.vocab.setdefault(t, len(self.vocab))
                tmp[tid].append((i, tf))
        self.doc_lens = lens
        # Lengths are floored at 1 for the average. An empty corpus gets 1.0 instead of NaN.
        self.avgdl = float(np.maximum(1, lens).mean()) if self.N else 1.0
        self.idf = np.zeros(len(self.vocab), dtype=np.float32)
        self.postings = {}
        for tid, pairs in tmp.items():
            docs = np.array([d for d, _ in pairs], dtype=np.int32)
            tfs = np.array([tf for _, tf in pairs], dtype=np.float32)
            df = float(len(docs))
            self.idf[tid] = math.log((self.N - df + 0.5) / (df + 0.5) + 1.0)
            self.postings[tid] = (docs, tfs)
        if self._logger:
            # `name` is a method: call it so the log shows the label, not a bound-method repr.
            self._logger(f"[{self.name()}] Indexed {self.N:,} docs in {time.time()-t0:.2f}s")

    def search(self, query: str, topk: int = 300) -> List[str]:
        """Return up to *topk* doc ids matching at least one query term, best first."""
        terms = self.tokenizer(query)
        if not terms or self.doc_lens is None:
            return []
        # Dense float64 accumulators, filled in query-term order. Each posting lists a
        # document at most once, so fancy-index += adds exactly once per (term, doc).
        acc = np.zeros(self.N, dtype=np.float64)
        hit = np.zeros(self.N, dtype=bool)
        for t in terms:
            tid = self.vocab.get(t)
            if tid is None:
                continue
            idf = float(self.idf[tid])
            docs, tfs = self.postings[tid]
            denom = tfs + self.k1 * (1 - self.b + self.b * (self.doc_lens[docs] / self.avgdl))
            contrib = idf * (tfs * (self.k1 + 1)) / denom
            acc[docs] += contrib
            hit[docs] = True
        idx = np.flatnonzero(hit).astype(np.int32)
        if idx.size == 0:
            return []
        scs = acc[idx].astype(np.float32)
        k = max(0, min(topk, idx.size))
        # Primary key: descending score; secondary key: ascending position.
        order = np.lexsort((idx, -scs))[:k]
        return [self.doc_ids[i] for i in idx[order]]
class BM25Index:
    """
    Unified BM25 wrapper; ``search`` returns a List[str] of doc IDs.

    ``build`` tries, in order: the factory from bm25_backends.py (when loaded),
    a direct bm25s wrapper, then the embedded pure-Python backend.
    """
    def __init__(self, k1=1.3, b=0.70, logger=_log):
        self.k1, self.b = k1, b
        self.doc_ids: List[str] = []
        self._be = None
        self._backend_name = "unset"
        self._logger = logger

    @staticmethod
    def _label(be) -> str:
        """Return a readable backend name; accepts ``name`` as a method, property or attribute."""
        name = getattr(be, "name", None)
        if callable(name):
            name = name()
        return str(name) if name is not None else type(be).__name__

    def _activate(self, be, ids: List[str], texts_norm: List[str], origin: str):
        """Build *be* and, on success, make it the active backend (raises if build fails)."""
        be.build(ids, texts_norm)
        self._be = be
        self.doc_ids = list(be.doc_ids)
        self._backend_name = f"{self._label(be)} ({origin})"
        if self._logger:
            self._logger(f"[BM25] Using backend: {self._backend_name}")

    def build(self, ids: List[str], texts_norm: List[str]):
        if _HAS_BM25_BACKENDS and callable(get_bm25_backend):
            try:
                be = get_bm25_backend(use_bm25s=True, tokenizer=tok_he, k1=self.k1, b=self.b, logger=self._logger)
                self._activate(be, ids, texts_norm, "bm25_backends.py")
                return
            except Exception as e:
                if self._logger:
                    self._logger(f"[BM25] bm25_backends failed ({e}). Trying direct bm25s...")
        try:
            # Construction itself raises ImportError when bm25s is not installed.
            be = _LocalBM25SBackend(tok_he, k1=self.k1, b=self.b, logger=self._logger)
            self._activate(be, ids, texts_norm, "direct")
            return
        except Exception as e:
            if self._logger:
                self._logger(f"[BM25] bm25s unavailable ({e}). Falling back to pure-Python.")
        be = _DeterministicBM25Backend(tok_he, k1=self.k1, b=self.b, logger=self._logger)
        self._activate(be, ids, texts_norm, "embedded")

    def search(self, query: str, topk: int = 200) -> List[str]:
        """Delegate to the active backend; returns [] before build()."""
        if self._be is None:
            return []
        return self._be.search(query, topk=topk)
# ======================= Model Path Resolution =======================
def _resolve_model_path(primary_path: str, fallback_names: List[str]) -> str:
    """
    Resolve a model location.

    Returns ``primary_path`` if it is an existing directory. Otherwise, for each base
    in (HERE/models, HERE, CWD, CWD/models), returns the first name in
    ``fallback_names`` that exists as a directory under that base. If nothing is
    found, returns ``fallback_names[0]`` unchanged (meant to be a Hub id or path),
    or ``primary_path`` when ``fallback_names`` is empty.
    """
    if primary_path and pathlib.Path(primary_path).is_dir():
        return primary_path
    if not fallback_names:
        # Nothing else to try: hand back the primary path instead of failing on [0].
        return primary_path
    cwd = pathlib.Path.cwd()
    for base in (HERE / "models", HERE, cwd, cwd / "models"):
        for name in fallback_names:
            candidate = base / name
            if candidate.is_dir():
                return str(candidate)
    return fallback_names[0]
def model_name_key(s: str) -> str:
    """
    Return the lower-cased last path component of a model path or Hub id.

    Surrounding whitespace and trailing '/' or '\\' are ignored; empty input gives "".
    """
    if not s:
        return ""
    trimmed = s.strip().rstrip("/\\")
    tail = re.split(r"[\\/]+", trimmed)[-1]
    chosen = tail if tail else trimmed
    return chosen.lower()
# ======================= E5 embedder =============================
class E5Embedder:
    """
    Transformer encoder producing E5-style sentence embeddings.

    Each text is prefixed with "query: " or "passage: ", truncated to 512 tokens,
    mean-pooled over the attention mask and L2-normalized. ``encode`` returns a
    float32 numpy array of shape (len(texts), dim).
    """
    def __init__(self, device=None):
        fallback_names = ["e5-large-ft_v4", "multilingual-e5-large"]
        all_fallbacks = [pathlib.Path(E5_DIR).name] + fallback_names if E5_DIR else fallback_names
        self.model_path = _resolve_model_path(E5_DIR, all_fallbacks)
        self.model_name = model_name_key(self.model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        _log(f"[E5] Loading encoder from: {self.model_path} (device={self.device})")
        self.tok = AutoTokenizer.from_pretrained(self.model_path)
        dtype = None  # CPU: library default precision
        if str(self.device).startswith("cuda"):  # also matches explicit "cuda:N"
            # Prefer BF16; use FP16 on GPUs that lack BF16 support.
            try:
                bf16_ok = torch.cuda.is_bf16_supported()
            except Exception:
                bf16_ok = torch.cuda.get_device_capability()[0] >= 8
            dtype = torch.bfloat16 if bf16_ok else torch.float16
        self.mdl = AutoModel.from_pretrained(self.model_path, torch_dtype=dtype).to(self.device)
        self.mdl.eval()
        # Embedding width (hidden size), used for the shape of empty results.
        self.dim = int(getattr(self.mdl.config, "hidden_size", 768))

    def encode(self, texts: List[str], is_query=False, batch=64, progress_desc="E5 encode"):
        """
        Encode already-normalized *texts*.

        Returns a float32 array of shape (len(texts), dim). Progress is printed on the
        first batch, every 50th batch and the last batch, unless _EVAL_SILENT is set.
        """
        pref = "query: " if is_query else "passage: "
        n = len(texts)
        if n == 0:
            # Match the model's real width (1024 for e5-large) instead of a fixed 768.
            return np.zeros((0, getattr(self, "dim", 768)), dtype=np.float32)
        out = []
        total_batches = (n + batch - 1) // batch
        t0 = time.time()
        with torch.no_grad():  # inference only: do not build autograd graphs
            for bi in range(total_batches):
                i = bi * batch
                chunk = texts[i:i + batch]
                enc = self.tok([pref + t.strip() for t in chunk], padding=True, truncation=True,
                               max_length=512, return_tensors="pt").to(self.device)
                hs = self.mdl(**enc).last_hidden_state
                # Mean pooling over real (non-padding) tokens.
                mask = enc["attention_mask"].unsqueeze(-1).expand(hs.size()).float()
                embs = (hs * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
                embs = torch.nn.functional.normalize(embs, p=2, dim=1)
                out.append(embs.cpu().to(dtype=torch.float32))
                if not _EVAL_SILENT:
                    if (bi + 1) % 50 == 0 or bi == 0 or (bi + 1) == total_batches:
                        pct = 100.0 * (bi + 1) / total_batches
                        elapsed = time.time() - t0
                        ips = (i + len(chunk)) / max(elapsed, 1e-6)
                        print(f"[{progress_desc}] batch {bi+1}/{total_batches} ({pct:.1f}%) ~{ips:.0f} items/s")
                del enc, hs, mask, embs
        # Release cached GPU blocks once per call rather than after every batch.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return torch.cat(out, dim=0).numpy()
# ======================= EmbeddingGemma embedder =====================
class GemmaEmbedder:
    """
    SentenceTransformer-based encoder for the "Gemma" retrieval channel.

    Loads in BF16 when preferred and supported, otherwise FP16 on CUDA or FP32 on CPU.
    ``encode`` returns L2-normalized float32 numpy arrays of shape (n, dim).
    No manual prompt prefixing: prefers encode_query / encode_document and lets
    SentenceTransformers handle prompting.
    """
    def __init__(self, device=None):
        fallback_names = ["google/embeddinggemma-300m", "embeddinggemma-300m"]
        # NOTE(review): GEMMA_DIR's directory name is tried first. When nothing local
        # matches, _resolve_model_path returns that bare name rather than the Hub id
        # 'google/embeddinggemma-300m' — confirm intended.
        all_fallbacks = [pathlib.Path(GEMMA_DIR).name] + fallback_names if GEMMA_DIR else fallback_names
        self.model_path = _resolve_model_path(GEMMA_DIR, all_fallbacks)
        self.model_name = model_name_key(self.model_path)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        on_cuda = str(self.device).startswith("cuda")  # also matches explicit "cuda:N"
        # dtype selection
        use_bf16 = bool(PREFER_BF16_GEMMA)
        if on_cuda:
            try:
                use_bf16 = use_bf16 and torch.cuda.is_bf16_supported()
            except Exception:
                major, _ = torch.cuda.get_device_capability()
                use_bf16 = use_bf16 and (major >= 8)
        if use_bf16:
            dtype, dtype_label = torch.bfloat16, "bf16"
        elif on_cuda:
            dtype, dtype_label = torch.float16, "fp16"
        else:
            # FP16 on CPU is typically slow/poorly supported; use full precision.
            dtype, dtype_label = torch.float32, "fp32"
        _log(f"[Gemma] Loading encoder from: {self.model_path} (device={self.device}, dtype={dtype_label})")
        self.mdl = SentenceTransformer(
            self.model_path,
            device=self.device,
            model_kwargs={"torch_dtype": dtype},
        )
        # Tunable max tokens (best effort: keep the model default if it can't be set)
        try:
            self.mdl.max_seq_length = int(GEMMA_MAX_TOK)
        except Exception:
            pass
        # Embedding width for empty results. It is taken from the model because GEMMA_DIR
        # may not be a 768-dim model; 768 is used only if the model doesn't report one.
        dim = None
        try:
            dim = self.mdl.get_sentence_embedding_dimension()
        except Exception:
            pass
        self.dim = int(dim) if dim else 768
        self.mdl.eval()

    def encode(self, texts: List[str], is_query=False, batch=64, progress_desc="Gemma encode", max_length: Optional[int]=None):
        """
        Encode *texts* as queries or documents.

        A positive int *max_length* temporarily overrides ``max_seq_length`` for this
        call; the previous value is restored afterwards. Returns a float32 array of
        shape (len(texts), dim).
        """
        if not texts:
            return np.zeros((0, self.dim), dtype=np.float32)
        # Per-call max length override
        old_len = getattr(self.mdl, "max_seq_length", None)
        if isinstance(max_length, int) and max_length > 0:
            try:
                self.mdl.max_seq_length = max_length
            except Exception:
                pass
        show = not _EVAL_SILENT
        # Do NOT manually add prompts; prefer encode_query / encode_document when available.
        try:
            if is_query and hasattr(self.mdl, "encode_query"):
                embs = self.mdl.encode_query(
                    texts, batch_size=batch, convert_to_numpy=True,
                    normalize_embeddings=True, show_progress_bar=show
                )
            elif (not is_query) and hasattr(self.mdl, "encode_document"):
                embs = self.mdl.encode_document(
                    texts, batch_size=batch, convert_to_numpy=True,
                    normalize_embeddings=True, show_progress_bar=show
                )
            else:
                # Fallback: encode(prompt=...) when supported (avoids manual concatenation)
                prompt = "query: " if is_query else "passage: "
                try:
                    embs = self.mdl.encode(
                        texts, batch_size=batch, convert_to_numpy=True,
                        normalize_embeddings=True, show_progress_bar=show,
                        prompt=prompt
                    )
                except TypeError:
                    # Last resort: plain encode (no prompt)
                    embs = self.mdl.encode(
                        texts, batch_size=batch, convert_to_numpy=True,
                        normalize_embeddings=True, show_progress_bar=show
                    )
        finally:
            if old_len is not None:
                try:
                    self.mdl.max_seq_length = old_len
                except Exception:
                    pass
        embs = np.asarray(embs)
        if embs.ndim == 1:
            embs = embs[None, :]
        return embs.astype(np.float32)
# ======================= BGE reranker ============================
class BGEReranker:
    """Cross-encoder reranker (sequence-classification head) that scores (query, passage) pairs."""
    def __init__(self, device=None):
        fallback_names = ["bge-reranker-hsrc-pairwise-rrf-V1.4", "bge-v2-m3", "bge-m3"]
        all_fallbacks = [pathlib.Path(BGE_DIR).name] + fallback_names if BGE_DIR else fallback_names
        self.model_path = _resolve_model_path(BGE_DIR, all_fallbacks)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        _log(f"[BGE] Loading reranker from: {self.model_path} (device={self.device})")
        self.tok = AutoTokenizer.from_pretrained(self.model_path)
        self.mdl = AutoModelForSequenceClassification.from_pretrained(
            self.model_path, torch_dtype=torch.float16 if self.device == "cuda" else None, trust_remote_code=True
        ).to(self.device)
        self.mdl.eval()

    def score_pairs(self, q: str, passages: List[str], batch=32, max_len=512) -> List[float]:
        """
        Return one relevance score per passage for query *q*, in input order.

        Only the passage is truncated to fit *max_len* tokens. Single-logit heads are
        used as-is; multi-logit heads use column 1 (presumably the "relevant" class —
        confirm against the checkpoint's label map).
        """
        out: List[float] = []
        with torch.no_grad():  # inference only: do not build autograd graphs
            for i in range(0, len(passages), batch):
                chunk = passages[i:i + batch]
                enc = self.tok([q] * len(chunk), chunk,
                               truncation="only_second", max_length=max_len,
                               padding=True, return_tensors="pt").to(self.device)
                logits = self.mdl(**enc).logits
                if logits.ndim == 1:
                    s = logits
                elif logits.shape[1] == 1:
                    s = logits.squeeze(-1)
                else:
                    s = logits[:, 1]
                out += s.float().cpu().tolist()
                del enc, logits
        return [float(x) for x in out]
# ======================== Hybrid Searcher ========================
class HybridSearcher:
    """
    Stage-1 retrieval: WRRF(BM25, E5, Gemma) -> candidate ids with WRRF scores.
    Stage-2 reranking is done outside, in predict().

    Dense rankings map row i of ``e5_corpus`` / ``gemma_corpus`` to
    ``bm25.doc_ids[i]``, so all three must share one document order
    (presumably guaranteed by preprocess(); confirm for external BM25 backends).
    """
    def __init__(self, bm25: "BM25Index",
                 e5: "E5Embedder", e5_corpus: np.ndarray,
                 gemma: "GemmaEmbedder", gemma_corpus: np.ndarray,
                 id2text: Dict[str, str], id2norm: Dict[str, str]):
        self.bm25 = bm25
        self.e5 = e5
        self.e5_corpus = e5_corpus
        self.gemma = gemma
        self.gemma_corpus = gemma_corpus
        self.id2text = id2text
        self.id2norm = id2norm
        # One-entry memo: the last query and its FULL fused ranking (untruncated, so a
        # repeated query with a larger topk still gets every candidate).
        self._last_q: Optional[str] = None
        self._last_fused: List[Tuple[str, float]] = []

    def _wrrf_fuse3(self, bm_ids: List[str], e5_ids: List[str], gm_ids: List[str], k=60,
                    w_bm25=1.0, w_e5=1.0, w_gm=1.0) -> List[Tuple[str, float]]:
        """
        Weighted reciprocal-rank fusion of three ranked id lists.

        Each list adds w / (k + rank + 1) (rank 0-based) to every id it contains.
        Returns (id, score) pairs by descending score. Ties keep first-seen order,
        BM25 ids first, then E5, then Gemma.
        """
        scores = defaultdict(float)
        for ids, w in ((bm_ids, w_bm25), (e5_ids, w_e5), (gm_ids, w_gm)):
            ranks = {pid: i for i, pid in enumerate(ids)}  # a duplicated id keeps its last rank
            for pid, r in ranks.items():
                scores[pid] += w * (1.0 / (k + r + 1))
        return sorted(scores.items(), key=lambda x: -x[1])

    @staticmethod
    def _dense_topk(corpus: np.ndarray, qvec: np.ndarray, k: int) -> np.ndarray:
        """Return row indices of the *k* highest dot products ``corpus @ qvec``, best first."""
        sims = corpus @ qvec  # cosine, since both sides are L2-normalized upstream
        k = min(k, len(sims))
        if k <= 0:  # empty corpus or non-positive k: argpartition would fail
            return np.zeros(0, dtype=np.intp)
        top = np.argpartition(-sims, k - 1)[:k]
        return top[np.argsort(-sims[top])]

    def search(self, query: str, topk: int = 200) -> List[Tuple[str, float]]:
        """Return up to *topk* (doc_id, WRRF score) pairs for *query*, best first."""
        if self._last_q == query and self._last_fused:
            return self._last_fused[:topk]
        # BM25 list
        bm_ids = self.bm25.search(query, topk=TOP_BM25)
        # E5 list (per-query normalization)
        qe = self.e5.encode([norm_e5_query(query)], is_query=True, batch=1, progress_desc="E5 query")[0]
        e5_ids = [self.bm25.doc_ids[i] for i in self._dense_topk(self.e5_corpus, qe, TOP_E5)]
        # Gemma list (per-query normalization)
        qg = self.gemma.encode([norm_gemma_query(query)], is_query=True, batch=1,
                               progress_desc="Gemma query", max_length=GEMMA_MAX_TOK)[0]
        gm_ids = [self.bm25.doc_ids[i] for i in self._dense_topk(self.gemma_corpus, qg, TOP_GEMMA)]
        fused = self._wrrf_fuse3(
            bm_ids, e5_ids, gm_ids, k=RRF_K,
            w_bm25=WRRF_BM25_W, w_e5=WRRF_E5_W, w_gm=WRRF_GEMMA_W
        )
        self._last_q = query
        self._last_fused = fused
        return fused[:topk]
# =========================== Globals ===========================
# Module-level registry written by preprocess() and consulted first by predict().
_STATE = dict()
# =========================== Helpers ===========================
def _sha1_ids(ids: List[str]) -> str:
    """Order-sensitive SHA-1 hex fingerprint of *ids*; each id is hashed as UTF-8 plus a newline."""
    digest = hashlib.sha1()
    for doc_id in ids:
        digest.update(doc_id.encode("utf-8") + b"\n")
    return digest.hexdigest()
def _normalize_min_max(scores: List[float]) -> List[float]:
    """
    Min-max scale *scores* into [0, 1].

    Fewer than two scores, or a range below 1e-9, maps every entry to 0.5.
    """
    count = len(scores)
    if count < 2:
        return [0.5] * count
    lo = min(scores)
    hi = max(scores)
    span = hi - lo
    if span < 1e-9:
        return [0.5] * count
    return [(x - lo) / span for x in scores]
# =========================== API funcs =========================
def _cached_corpus_embeddings(tag: str, cache_dir: str, file_prefix: str, ids: List[str],
                              model_path: str, model_name: str, compute):
    """
    Return ``(matrix, note)`` for corpus embeddings, using an on-disk cache when enabled.

    With a non-empty *cache_dir*, ``<file_prefix>_corpus.npy`` is reused only if
    ``<file_prefix>_meta.json`` matches the id fingerprint, document count and model
    name, and the stored matrix is 2-D with one row per id. Otherwise the zero-arg
    *compute* callable is run and its result saved (best effort).
    *note* is "loaded", "saved" or None. Log lines are prefixed with ``[tag]``.
    """
    mat, note = None, None
    sha = _sha1_ids(ids) if cache_dir else None
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
        meta_p = os.path.join(cache_dir, f"{file_prefix}_meta.json")
        npy_p = os.path.join(cache_dir, f"{file_prefix}_corpus.npy")
        if os.path.isfile(meta_p) and os.path.isfile(npy_p):
            try:
                with open(meta_p, "r", encoding="utf-8") as f:
                    m = json.load(f)
                if (m.get("sha1_ids") == sha
                        and model_name_key(m.get("model_path", "")) == model_name
                        and m.get("num_docs") == len(ids)):
                    _log(f"[{tag}] Loading cached corpus embeddings from {npy_p}")
                    cached = np.load(npy_p, mmap_mode=None)
                    # Reject a stale/foreign .npy next to a matching meta file: row i
                    # must correspond to ids[i] for dense retrieval to return the right ids.
                    if cached.ndim != 2 or cached.shape[0] != len(ids):
                        raise ValueError(f"cached matrix has shape {cached.shape}, expected {len(ids)} rows")
                    mat, note = cached, "loaded"
            except Exception as e:
                _log(f"[{tag}] Cache read failed: {e} — recomputing.")
    if mat is None:
        _log(f"[{tag}] Computing corpus embeddings...")
        t0 = time.time()
        mat = compute()
        _log(f"[{tag}] Done in {time.time()-t0:.1f}s — shape={mat.shape}")
        if cache_dir:
            try:
                np.save(os.path.join(cache_dir, f"{file_prefix}_corpus.npy"), mat)
                meta = {"sha1_ids": sha, "num_docs": len(ids), "model_path": model_path,
                        "dim": int(mat.shape[1]), "created": time.time()}
                with open(os.path.join(cache_dir, f"{file_prefix}_meta.json"), "w", encoding="utf-8") as f:
                    json.dump(meta, f, ensure_ascii=False, indent=2)
                note = "saved"
                _log(f"[{tag}] Saved cache to {cache_dir}")
            except Exception as e:
                _log(f"[{tag}] Cache save failed: {e}")
    return mat, note


def preprocess(corpus_dict: Dict[str, Dict]) -> Dict:
    """
    Build every retrieval component for *corpus_dict* and register them in _STATE.

    Each value's text is taken from "passage", else "text", else "". Builds the BM25
    index, E5 and Gemma corpus embeddings (optionally cached on disk) and loads the
    BGE reranker. Returns the objects predict() needs plus an "_eval" metadata dict.
    """
    ids, texts = [], []
    # -------- Per-paragraph normalization before indexing --------
    bm25_norms, e5_passage_norms, gm_passage_norms = [], [], []
    for pid, obj in corpus_dict.items():
        t = obj.get("passage") or obj.get("text") or ""
        ids.append(str(pid))
        texts.append(t)
        bm25_norms.append(norm_bm25(t))
        e5_passage_norms.append(norm_e5_passage(t))
        gm_passage_norms.append(norm_gemma_passage(t))
    _log("=" * 60)
    _log("PREPROCESS: Building BM25 + E5 + Gemma embeddings + loading BGE")
    _log("=" * 60)
    # BM25
    bm25 = BM25Index(k1=BM25_K1, b=BM25_B, logger=_log)
    bm25.build(ids, bm25_norms)
    # E5 encoder + caching
    e5 = E5Embedder()
    e5_mat, cache_note_e5 = _cached_corpus_embeddings(
        "E5", E5_EVAL_CACHE_DIR, "e5", ids, e5.model_path, e5.model_name,
        lambda: e5.encode(e5_passage_norms, is_query=False, batch=64, progress_desc="E5 corpus"),
    )
    # Gemma encoder + caching
    gemma = GemmaEmbedder()
    gemma_mat, cache_note_gm = _cached_corpus_embeddings(
        "Gemma", GEMMA_EVAL_CACHE_DIR, "gemma", ids, gemma.model_path, gemma.model_name,
        lambda: gemma.encode(gm_passage_norms, is_query=False, batch=64,
                             progress_desc="Gemma corpus", max_length=GEMMA_MAX_TOK),
    )
    # Reranker
    rr = BGEReranker()
    id2text = dict(zip(ids, texts))
    id2norm = dict(zip(ids, bm25_norms))
    hybrid = HybridSearcher(bm25, e5, e5_mat, gemma, gemma_mat, id2text, id2norm)
    _STATE.update({
        "bm25": bm25, "id2text": id2text, "id2norm": id2norm,
        "e5": e5, "e5_corpus": e5_mat,
        "gemma": gemma, "gemma_corpus": gemma_mat,
        "reranker": rr, "hybrid": hybrid
    })
    reranker_params = {
        "CE_POOL": CE_POOL, "CE_MAXLEN": CE_MAXLEN, "CE_BATCH": CE_BATCH,
        "FINAL_SCORE_BGE_WEIGHT": FINAL_SCORE_BGE_WEIGHT
    }
    meta = {
        "stage1_name": "WRRF(BM25, E5, Gemma)",
        "stage1_params": {
            "TOP_BM25": TOP_BM25, "TOP_E5": TOP_E5, "TOP_GEMMA": TOP_GEMMA, "RRF_K": RRF_K,
            "WRRF_WEIGHTS": {"bm25": WRRF_BM25_W, "e5": WRRF_E5_W, "gemma": WRRF_GEMMA_W}
        },
        "reranker_name": "BGE + Hybrid Fusion (Conditional Boost)",
        "reranker_params": reranker_params,
        "candidate_pool_cap": CE_POOL,
        "stage1_search_key": "bm25",
        "bm25_backend": getattr(bm25, "_backend_name", "unknown"),
        "e5_model_path": e5.model_path,
        "gemma_model_path": gemma.model_path,
        "bge_model_path": rr.model_path,
        "cache_dir_e5": E5_EVAL_CACHE_DIR or None,
        "cache_dir_gemma": GEMMA_EVAL_CACHE_DIR or None,
        "e5_cache": cache_note_e5 or ("unused" if not E5_EVAL_CACHE_DIR else "miss"),
        "gemma_cache": cache_note_gm or ("unused" if not GEMMA_EVAL_CACHE_DIR else "miss"),
    }
    _log("✓ PREPROCESS complete.")
    return {
        "bm25": hybrid, "id2text": id2text, "id2norm": id2norm,
        "reranker": rr, "num_documents": len(ids), "_eval": meta
    }
def predict(query: Dict, pre: Dict):
    """
    Rank candidates for ``query["query"]``.

    Returns [{"paragraph_uuid": id, "score": float}] sorted by descending score, or []
    for an empty query or no stage-1 candidates. Components registered in _STATE take
    precedence over the ones in *pre*.
    """
    q = query.get("query", "")
    if not q:
        return []
    hyb = _STATE.get("hybrid") or pre["bm25"]
    rr = _STATE.get("reranker") or pre["reranker"]
    id2text = _STATE.get("id2text") or pre["id2text"]

    # Stage-1: WRRF retrieval
    cand_id_scores = hyb.search(q, topk=CE_POOL)
    if not cand_id_scores:
        return []
    cand_ids = [pid for pid, _ in cand_id_scores]
    rrf_scores = [score for _, score in cand_id_scores]

    # Stage-2: BGE reranker, fed through its own normalizers
    bge_scores = rr.score_pairs(
        norm_bge_query(q),
        [norm_bge_passage(id2text[pid]) for pid in cand_ids],
        batch=CE_BATCH,
        max_len=CE_MAXLEN,
    )

    # Stage-3: conditional boost. The RRF signal only fills the headroom the reranker
    # left: final = bge + w_rrf * rrf * (1 - bge), staying in [0, 1] for weights in [0, 1].
    bge_n = _normalize_min_max(bge_scores)
    rrf_n = _normalize_min_max(rrf_scores)
    w_rrf = 1.0 - FINAL_SCORE_BGE_WEIGHT
    final_scores = [b + w_rrf * r * (1.0 - b) for b, r in zip(bge_n, rrf_n)]

    ranked = sorted(zip(cand_ids, final_scores), key=lambda pair: -pair[1])
    return [{"paragraph_uuid": pid, "score": float(s)} for pid, s in ranked]