RobertoBarrosoLuque committed
Commit 75361de · 1 Parent(s): 4cba650

Add reranking

src/app.py CHANGED
@@ -2,14 +2,17 @@ import gradio as gr
 import time
 from typing import List, Dict, Tuple, Callable
 from pathlib import Path
-import os
 from config import (
     GRADIO_THEME,
     CUSTOM_CSS,
     EXAMPLE_QUERIES_BY_CATEGORY,
 )
 from src.search.bm25_lexical_search import search_bm25
-from src.search.vector_search import search_vector, search_vector_with_expansion
+from src.search.vector_search import (
+    search_vector,
+    search_vector_with_expansion,
+    search_vector_with_reranking,
+)
 from src.data_prep.data_prep import load_clean_amazon_product_data
 from src.constants.code_snippets import (
     CODE_STAGE_1,
@@ -63,21 +66,26 @@ def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
     Args:
         results: List of dicts with keys: product_name, description, main_category, secondary_category, score
         stage_name: Name of the search stage
-        metrics: Dict with keys: semantic_match, diversity, latency_ms
+        metrics: Dict with keys: top1_score, top5_avg, latency_ms
     """
     html_parts = [
         f"## 🔍 {stage_name}\n\n",
         f"""
-    <div style="display: flex; gap: 20px; margin-bottom: 28px;">
+    <div style="display: flex; gap: 16px; margin-bottom: 28px;">
+        <div class="metric-box" style="flex: 1;">
+            <div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">TOP-1 SCORE</div>
+            <div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['top1_score']:.3f}</div>
+            <div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Best result</div>
+        </div>
         <div class="metric-box" style="flex: 1;">
-            <div style="color: #6720FF; font-size: 0.9em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">SEMANTIC MATCH</div>
-            <div style="font-size: 2.2em; font-weight: 700; color: #1E293B;">{metrics['semantic_match']:.3f}</div>
-            <div style="color: #64748B; font-size: 0.8em; margin-top: 4px;">Higher is better</div>
+            <div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">TOP-5 AVG</div>
+            <div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['top5_avg']:.3f}</div>
+            <div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Overall quality</div>
         </div>
         <div class="metric-box" style="flex: 1;">
-            <div style="color: #6720FF; font-size: 0.9em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">LATENCY</div>
-            <div style="font-size: 2.2em; font-weight: 700; color: #1E293B;">{metrics['latency_ms']}<span style="font-size: 0.45em; color: #64748B; font-weight: 400;">ms</span></div>
-            <div style="color: #64748B; font-size: 0.8em; margin-top: 4px;">Response time</div>
+            <div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">LATENCY</div>
+            <div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['latency_ms']}<span style="font-size: 0.45em; color: #64748B; font-weight: 400;">ms</span></div>
+            <div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Response time</div>
         </div>
     </div>
     """,
@@ -114,14 +122,40 @@ def get_average_score(results: List[Dict]) -> float:
     return sum(r["score"] for r in results) / len(results) if results else 0


+def get_weighted_score(results: List[Dict]) -> float:
+    """
+    Calculate position-weighted average score.
+
+    Top positions get higher weight (5x for #1, 4x for #2, etc.)
+    This rewards ranking quality - putting best results at the top.
+
+    Args:
+        results: List of search results with 'score' field
+
+    Returns:
+        Weighted average score (0-1 scale)
+    """
+    if not results:
+        return 0.0
+
+    weights = [5, 4, 3, 2, 1]
+    total_weight = sum(weights)
+
+    weighted_sum = sum((weights[i] * r["score"]) for i, r in enumerate(results[:5]))
+
+    return weighted_sum / total_weight
+
+
 def search_stage_1(query: str) -> Tuple[str, Dict]:
     """Stage 1: Baseline BM25 keyword search."""
     results, latency = run_search_function_and_time(query, search_bm25)
-    avg_score = get_average_score(results)
-    semantic_match = min(1.0, avg_score / len(results))
+
+    top1_score = results[0]["score"] / 5.0 if results else 0.0  # Normalize BM25 scores
+    top5_avg = get_average_score(results) / 5.0 if results else 0.0

     metrics = {
-        "semantic_match": semantic_match,
+        "top1_score": min(1.0, top1_score),
+        "top5_avg": min(1.0, top5_avg),
         "latency_ms": latency,
     }
     print(f"Searched BM25 for {query} in {latency}ms")
@@ -132,10 +166,13 @@ def search_stage_1(query: str) -> Tuple[str, Dict]:
 def search_stage_2(query: str) -> Tuple[str, Dict]:
     """Stage 2: Vector Embeddings using FAISS."""
     results, latency = run_search_function_and_time(query, search_vector)
-    semantic_match = get_average_score(results)
+
+    top1_score = results[0]["score"] if results else 0.0
+    top5_avg = get_average_score(results)

     metrics = {
-        "semantic_match": semantic_match,
+        "top1_score": top1_score,
+        "top5_avg": top5_avg,
         "latency_ms": latency,
     }
     print(f"Searched vector embeddings for '{query}' in {latency}ms")
@@ -145,12 +182,14 @@ def search_stage_2(query: str) -> Tuple[str, Dict]:

 def search_stage_3(query: str) -> Tuple[str, Dict]:
     """Stage 3: Query Expansion + Vector Embeddings."""
-
     results, latency = run_search_function_and_time(query, search_vector_with_expansion)
-    semantic_match = get_average_score(results)
+
+    top1_score = results[0]["score"] if results else 0.0
+    top5_avg = get_average_score(results)

     metrics = {
-        "semantic_match": semantic_match,
+        "top1_score": top1_score,
+        "top5_avg": top5_avg,
         "latency_ms": latency,
     }

@@ -158,29 +197,19 @@ def search_stage_3(query: str) -> Tuple[str, Dict]:


 def search_stage_4(query: str) -> Tuple[str, Dict]:
-    """Stage 4: BM25 + Embeddings + Query Expansion + LLM Reranking."""
-    start_time = time.time()
-
-    # Placeholder: Simulated reranking with correct format
-    results = [
-        {
-            "product_name": product["title"],
-            "description": product["description"],
-            "main_category": product["category"],
-            "secondary_category": "Placeholder",
-            "score": 0.85 + (idx * 0.025),
-        }
-        for idx, product in enumerate(SAMPLE_PRODUCTS[:5])
-    ]
+    """Stage 4: Query Expansion + Vector Embeddings + Reranking."""
+    results, latency = run_search_function_and_time(query, search_vector_with_reranking)

-    latency = int((time.time() - start_time) * 1000)
+    top1_score = results[0]["score"] if results else 0.0
+    top5_avg = get_average_score(results)

     metrics = {
-        "semantic_match": 0.88,
-        "latency_ms": max(200, latency),
+        "top1_score": top1_score,
+        "top5_avg": top5_avg,
+        "latency_ms": latency,
     }

-    return format_results(results, "Stage 4: + LLM Reranking", metrics), metrics
+    return format_results(results, "Stage 4: Reranking", metrics), metrics


 def search_all_stages(query: str) -> Tuple[str, str, str, str, str]:
@@ -210,26 +239,37 @@ def generate_comparison_table(all_metrics: List[Dict]) -> str:

     # Build markdown table
     html = "## Stage-by-Stage Comparison\n\n"
-    html += "| Stage | Semantic Match | Latency (ms) |\n"
-    html += "|-------|----------------|--------------|\n"
+    html += "| Stage | Top-1 Score | Top-5 Avg | Latency (ms) |\n"
+    html += "|-------|-------------|-----------|-------------|\n"

     for name, metrics in zip(stage_names, all_metrics):
-        html += f"| **{name}** | {metrics['semantic_match']:.3f} | {metrics['latency_ms']} |\n"
+        html += f"| **{name}** | {metrics['top1_score']:.3f} | {metrics['top5_avg']:.3f} | {metrics['latency_ms']} |\n"

-    # Calculate improvements
-    semantic_improvement = (
+    # Calculate improvements based on top-5 average
+    top5_improvement = (
         (
-            (all_metrics[3]["semantic_match"] - all_metrics[0]["semantic_match"])
-            / all_metrics[0]["semantic_match"]
+            (all_metrics[3]["top5_avg"] - all_metrics[0]["top5_avg"])
+            / all_metrics[0]["top5_avg"]
             * 100
         )
-        if all_metrics[0]["semantic_match"] > 0
+        if all_metrics[0]["top5_avg"] > 0
+        else 0
+    )
+
+    top1_improvement = (
+        (
+            (all_metrics[3]["top1_score"] - all_metrics[0]["top1_score"])
+            / all_metrics[0]["top1_score"]
+            * 100
+        )
+        if all_metrics[0]["top1_score"] > 0
         else 0
     )

     html += "\n---\n\n"
     html += "## Key Insights\n\n"
-    html += f"- **Semantic Match** improves by **{semantic_improvement:.0f}%** from baseline to final stage\n"
+    html += f"- **Top-1 Score** improves by **{top1_improvement:.0f}%** from baseline to final stage\n"
+    html += f"- **Top-5 Average** improves by **{top5_improvement:.0f}%** from baseline to final stage\n"
     html += f"- **Latency** stays under **{max(m['latency_ms'] for m in all_metrics)}ms** maintaining fast performance\n"
     html += "- Each stage progressively enhances search relevance while keeping response times low\n"
     html += "- Vector embeddings provide the biggest jump in semantic understanding\n"
@@ -363,17 +403,7 @@ with gr.Blocks(
             scale=3,
             elem_classes="search-box",
         )
-        with gr.Column(scale=1):
-            val = os.getenv("FIREWORKS_API_KEY", "")  # pragma: allowlist secret
-            api_key_value = gr.Textbox(  # pragma: allowlist secret
-                label="API Key",
-                type="password",
-                placeholder="Enter your Fireworks AI API key",
-                value=val,
-                container=True,
-                elem_classes="compact-input",
-            )
-    # Clean example query selector
+
     with gr.Row():
         gr.Markdown(
             "**Try Example Queries:** Select a category and specificity level to auto-load an example"
src/config.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 # Fireworks AI Model Configuration
 EMBEDDING_MODEL = "accounts/fireworks/models/qwen3-embedding-8b"
 LLM_MODEL = "accounts/fireworks/models/qwen3-8b"
-RERANKER_MODEL = "fireworks/qwen3-reranker-8b"
+RERANKER_MODEL = "accounts/fireworks/models/qwen3-reranker-8b"

 GRADIO_THEME = gr.themes.Base(
     primary_hue=gr.themes.colors.purple,
src/fireworks/inference.py CHANGED
@@ -2,14 +2,18 @@ import os
 import yaml
 from openai import OpenAI
 from dotenv import load_dotenv
-from typing import List
+from typing import List, Dict
 from pathlib import Path
-from src.config import EMBEDDING_MODEL, LLM_MODEL
+import requests
+from src.config import EMBEDDING_MODEL, LLM_MODEL, RERANKER_MODEL

 load_dotenv()

 _FILE_PATH = Path(__file__).parents[2]

+RERANK_URL = "https://api.fireworks.ai/inference/v1/rerank"
+INFERENCE_URL = "https://api.fireworks.ai/inference/v1"
+

 def load_prompt_library():
     """Load prompts from YAML configuration."""
@@ -17,15 +21,15 @@ def load_prompt_library():
         return yaml.safe_load(f)


-def create_client(api_key: str = None) -> OpenAI:
+def create_client() -> OpenAI:
     """
     Create client for FW inference
     """
-    api_key = os.getenv("FIREWORKS_API_KEY", api_key)
+    api_key = os.getenv("FIREWORKS_API_KEY")
     assert api_key is not None, "FIREWORKS_API_KEY not found in environment variables"
     return OpenAI(
         api_key=api_key,
-        base_url="https://api.fireworks.ai/inference/v1",
+        base_url=INFERENCE_URL,
     )


@@ -75,3 +79,47 @@ def expand_query(query: str) -> str:

     expanded = response.choices[0].message.content.strip()
     return expanded
+
+
+def rerank_results(query: str, results: List[Dict], top_n: int = 5) -> List[Dict]:
+    """
+    Rerank search results using Fireworks AI reranker model.
+
+    Takes search results and reranks them based on relevance to the query
+    using a specialized reranking model that considers cross-attention between
+    query and documents.
+
+    Args:
+        query: Original search query
+        results: List of dictionaries containing product information and scores
+        top_n: Number of top results to return after reranking (default: 5)
+
+    Returns:
+        List of dictionaries containing reranked product information with updated scores
+    """
+    # Prepare documents as text for reranker (product name + description)
+    documents = [f"{r['product_name']}. {r['description']}" for r in results]
+
+    payload = {
+        "model": RERANKER_MODEL,
+        "query": query,
+        "documents": documents,
+        "top_n": top_n,
+        "return_documents": False,
+    }
+
+    headers = {
+        "Authorization": f"Bearer {os.getenv('FIREWORKS_API_KEY')}",
+        "Content-Type": "application/json",
+    }
+
+    response = requests.post(RERANK_URL, json=payload, headers=headers)
+    rerank_data = response.json()
+
+    # Map reranked results back to original product data
+    reranked_results = []
+    for item in rerank_data.get("data", []):
+        idx = item["index"]
+        reranked_results.append({**results[idx], "score": item["relevance_score"]})
+
+    return reranked_results
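
For reference, rerank_results assumes the /rerank endpoint returns JSON shaped roughly like the sketch below. The field names (data, index, relevance_score) are taken from how the function reads the response; the values are illustrative only, and the real payload may carry additional fields.

```python
# Illustrative response shape for the rerank call above; field names come from
# how rerank_results reads the response, values are invented.
rerank_data = {
    "data": [
        {"index": 3, "relevance_score": 0.91},  # 4th submitted document ranked most relevant
        {"index": 0, "relevance_score": 0.74},
        {"index": 1, "relevance_score": 0.52},
    ]
}

# The same back-mapping rerank_results performs: keep the original product
# fields, overwrite "score" with the reranker's relevance score.
results = [{"product_name": f"Product {i}", "description": "placeholder"} for i in range(5)]
reranked = [
    {**results[item["index"]], "score": item["relevance_score"]}
    for item in rerank_data.get("data", [])
]
print(reranked[0])  # {'product_name': 'Product 3', 'description': 'placeholder', 'score': 0.91}
```

Because index points back into the submitted documents list, the function can splice the reranker's score onto the original result dict without re-fetching product data.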
src/search/vector_search.py CHANGED
@@ -2,13 +2,13 @@ import numpy as np
 import faiss
 from typing import List, Dict
 from pathlib import Path
-from src.fireworks.inference import get_embedding, expand_query
+from src.fireworks.inference import get_embedding, expand_query, rerank_results
 from constants.constants import FAISS_INDEX, PRODUCTS_DF

 _FILE_PATH = Path(__file__).parents[2]


-def search_vector(query: str, top_k: int = 10) -> List[Dict[str, any]]:
+def search_vector(query: str, top_k: int = 5) -> List[Dict[str, any]]:
     """
     Search products using vector embeddings and FAISS for semantic search.

@@ -46,7 +46,7 @@ def search_vector(query: str, top_k: int = 10) -> List[Dict[str, any]]:
     ]


-def search_vector_with_expansion(query: str, top_k: int = 10) -> List[Dict[str, any]]:
+def search_vector_with_expansion(query: str, top_k: int = 5) -> List[Dict[str, any]]:
     """
     Search products using vector embeddings and FAISS for semantic search with query expansion.

@@ -64,3 +64,27 @@ def search_vector_with_expansion(query: str, top_k: int = 10) -> List[Dict[str, any]]:
     print(f"Original: {query}")
     print(f"Expanded: {expanded_query}")
     return search_vector(expanded_query, top_k)
+
+
+def search_vector_with_reranking(query: str, top_k: int = 5) -> List[Dict[str, any]]:
+    """
+    Search products using vector embeddings and FAISS for semantic search with reranking.
+
+    This is Stage 4: semantic search using vector embeddings to understand
+    query meaning and intent beyond exact keyword matching, with reranking.
+
+    Args:
+        query: Search query string
+        top_k: Number of top results to return (default: 10)
+
+    Returns:
+        List of dictionaries containing product information with preserved cosine scores
+    """
+    results = search_vector_with_expansion(query, top_k)
+    cosine_scores = {r["product_name"]: r["score"] for r in results}
+    reranked_results = rerank_results(query=query, results=results)
+
+    for r in reranked_results:
+        r["score"] = cosine_scores[r["product_name"]]
+
+    return reranked_results
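
Taken together, Stage 4 now runs query expansion, vector search, and reranking in sequence: the reranker decides the ordering, while the displayed scores stay the original cosine similarities because search_vector_with_reranking copies them back. A minimal usage sketch, assuming FIREWORKS_API_KEY is set and the FAISS index and product data used by search_vector have been built; the query string is only an example.

```python
# Minimal end-to-end sketch of the Stage 4 path added in this commit.
# Assumes FIREWORKS_API_KEY is set and the FAISS index / product data that
# search_vector relies on exist; the query is an arbitrary example.
from src.search.vector_search import search_vector_with_reranking

results = search_vector_with_reranking("wireless headphones for running", top_k=5)

for rank, r in enumerate(results, start=1):
    # Order comes from the reranker; "score" is still the original cosine
    # similarity, restored after reranking.
    print(f"{rank}. {r['product_name']}  (cosine score: {r['score']:.3f})")
```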