Spaces:
Sleeping
Sleeping
RobertoBarrosoLuque
commited on
Commit
·
75361de
1
Parent(s):
4cba650
Add reranking
Browse files- src/app.py +86 -56
- src/config.py +1 -1
- src/fireworks/inference.py +53 -5
- src/search/vector_search.py +27 -3
src/app.py
CHANGED
|
@@ -2,14 +2,17 @@ import gradio as gr
|
|
| 2 |
import time
|
| 3 |
from typing import List, Dict, Tuple, Callable
|
| 4 |
from pathlib import Path
|
| 5 |
-
import os
|
| 6 |
from config import (
|
| 7 |
GRADIO_THEME,
|
| 8 |
CUSTOM_CSS,
|
| 9 |
EXAMPLE_QUERIES_BY_CATEGORY,
|
| 10 |
)
|
| 11 |
from src.search.bm25_lexical_search import search_bm25
|
| 12 |
-
from src.search.vector_search import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from src.data_prep.data_prep import load_clean_amazon_product_data
|
| 14 |
from src.constants.code_snippets import (
|
| 15 |
CODE_STAGE_1,
|
|
@@ -63,21 +66,26 @@ def format_results(results: List[Dict], stage_name: str, metrics: Dict) -> str:
|
|
| 63 |
Args:
|
| 64 |
results: List of dicts with keys: product_name, description, main_category, secondary_category, score
|
| 65 |
stage_name: Name of the search stage
|
| 66 |
-
metrics: Dict with keys:
|
| 67 |
"""
|
| 68 |
html_parts = [
|
| 69 |
f"## 🔍 {stage_name}\n\n",
|
| 70 |
f"""
|
| 71 |
-
<div style="display: flex; gap:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
<div class="metric-box" style="flex: 1;">
|
| 73 |
-
<div style="color: #6720FF; font-size: 0.
|
| 74 |
-
<div style="font-size:
|
| 75 |
-
<div style="color: #64748B; font-size: 0.
|
| 76 |
</div>
|
| 77 |
<div class="metric-box" style="flex: 1;">
|
| 78 |
-
<div style="color: #6720FF; font-size: 0.
|
| 79 |
-
<div style="font-size:
|
| 80 |
-
<div style="color: #64748B; font-size: 0.
|
| 81 |
</div>
|
| 82 |
</div>
|
| 83 |
""",
|
|
@@ -114,14 +122,40 @@ def get_average_score(results: List[Dict]) -> float:
|
|
| 114 |
return sum(r["score"] for r in results) / len(results) if results else 0
|
| 115 |
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
def search_stage_1(query: str) -> Tuple[str, Dict]:
|
| 118 |
"""Stage 1: Baseline BM25 keyword search."""
|
| 119 |
results, latency = run_search_function_and_time(query, search_bm25)
|
| 120 |
-
|
| 121 |
-
|
|
|
|
| 122 |
|
| 123 |
metrics = {
|
| 124 |
-
"
|
|
|
|
| 125 |
"latency_ms": latency,
|
| 126 |
}
|
| 127 |
print(f"Searched BM25 for {query} in {latency}ms")
|
|
@@ -132,10 +166,13 @@ def search_stage_1(query: str) -> Tuple[str, Dict]:
|
|
| 132 |
def search_stage_2(query: str) -> Tuple[str, Dict]:
|
| 133 |
"""Stage 2: Vector Embeddings using FAISS."""
|
| 134 |
results, latency = run_search_function_and_time(query, search_vector)
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
|
| 137 |
metrics = {
|
| 138 |
-
"
|
|
|
|
| 139 |
"latency_ms": latency,
|
| 140 |
}
|
| 141 |
print(f"Searched vector embeddings for '{query}' in {latency}ms")
|
|
@@ -145,12 +182,14 @@ def search_stage_2(query: str) -> Tuple[str, Dict]:
|
|
| 145 |
|
| 146 |
def search_stage_3(query: str) -> Tuple[str, Dict]:
|
| 147 |
"""Stage 3: Query Expansion + Vector Embeddings."""
|
| 148 |
-
|
| 149 |
results, latency = run_search_function_and_time(query, search_vector_with_expansion)
|
| 150 |
-
|
|
|
|
|
|
|
| 151 |
|
| 152 |
metrics = {
|
| 153 |
-
"
|
|
|
|
| 154 |
"latency_ms": latency,
|
| 155 |
}
|
| 156 |
|
|
@@ -158,29 +197,19 @@ def search_stage_3(query: str) -> Tuple[str, Dict]:
|
|
| 158 |
|
| 159 |
|
| 160 |
def search_stage_4(query: str) -> Tuple[str, Dict]:
|
| 161 |
-
"""Stage 4:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
# Placeholder: Simulated reranking with correct format
|
| 165 |
-
results = [
|
| 166 |
-
{
|
| 167 |
-
"product_name": product["title"],
|
| 168 |
-
"description": product["description"],
|
| 169 |
-
"main_category": product["category"],
|
| 170 |
-
"secondary_category": "Placeholder",
|
| 171 |
-
"score": 0.85 + (idx * 0.025),
|
| 172 |
-
}
|
| 173 |
-
for idx, product in enumerate(SAMPLE_PRODUCTS[:5])
|
| 174 |
-
]
|
| 175 |
|
| 176 |
-
|
|
|
|
| 177 |
|
| 178 |
metrics = {
|
| 179 |
-
"
|
| 180 |
-
"
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
-
return format_results(results, "Stage 4:
|
| 184 |
|
| 185 |
|
| 186 |
def search_all_stages(query: str) -> Tuple[str, str, str, str, str]:
|
|
@@ -210,26 +239,37 @@ def generate_comparison_table(all_metrics: List[Dict]) -> str:
|
|
| 210 |
|
| 211 |
# Build markdown table
|
| 212 |
html = "## Stage-by-Stage Comparison\n\n"
|
| 213 |
-
html += "| Stage |
|
| 214 |
-
html += "
|
| 215 |
|
| 216 |
for name, metrics in zip(stage_names, all_metrics):
|
| 217 |
-
html += f"| **{name}** | {metrics['
|
| 218 |
|
| 219 |
-
# Calculate improvements
|
| 220 |
-
|
| 221 |
(
|
| 222 |
-
(all_metrics[3]["
|
| 223 |
-
/ all_metrics[0]["
|
| 224 |
* 100
|
| 225 |
)
|
| 226 |
-
if all_metrics[0]["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
else 0
|
| 228 |
)
|
| 229 |
|
| 230 |
html += "\n---\n\n"
|
| 231 |
html += "## Key Insights\n\n"
|
| 232 |
-
html += f"- **
|
|
|
|
| 233 |
html += f"- **Latency** stays under **{max(m['latency_ms'] for m in all_metrics)}ms** maintaining fast performance\n"
|
| 234 |
html += "- Each stage progressively enhances search relevance while keeping response times low\n"
|
| 235 |
html += "- Vector embeddings provide the biggest jump in semantic understanding\n"
|
|
@@ -363,17 +403,7 @@ with gr.Blocks(
|
|
| 363 |
scale=3,
|
| 364 |
elem_classes="search-box",
|
| 365 |
)
|
| 366 |
-
|
| 367 |
-
val = os.getenv("FIREWORKS_API_KEY", "") # pragma: allowlist secret
|
| 368 |
-
api_key_value = gr.Textbox( # pragma: allowlist secret
|
| 369 |
-
label="API Key",
|
| 370 |
-
type="password",
|
| 371 |
-
placeholder="Enter your Fireworks AI API key",
|
| 372 |
-
value=val,
|
| 373 |
-
container=True,
|
| 374 |
-
elem_classes="compact-input",
|
| 375 |
-
)
|
| 376 |
-
# Clean example query selector
|
| 377 |
with gr.Row():
|
| 378 |
gr.Markdown(
|
| 379 |
"**Try Example Queries:** Select a category and specificity level to auto-load an example"
|
|
|
|
| 2 |
import time
|
| 3 |
from typing import List, Dict, Tuple, Callable
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
from config import (
|
| 6 |
GRADIO_THEME,
|
| 7 |
CUSTOM_CSS,
|
| 8 |
EXAMPLE_QUERIES_BY_CATEGORY,
|
| 9 |
)
|
| 10 |
from src.search.bm25_lexical_search import search_bm25
|
| 11 |
+
from src.search.vector_search import (
|
| 12 |
+
search_vector,
|
| 13 |
+
search_vector_with_expansion,
|
| 14 |
+
search_vector_with_reranking,
|
| 15 |
+
)
|
| 16 |
from src.data_prep.data_prep import load_clean_amazon_product_data
|
| 17 |
from src.constants.code_snippets import (
|
| 18 |
CODE_STAGE_1,
|
|
|
|
| 66 |
Args:
|
| 67 |
results: List of dicts with keys: product_name, description, main_category, secondary_category, score
|
| 68 |
stage_name: Name of the search stage
|
| 69 |
+
metrics: Dict with keys: top1_score, top5_avg, latency_ms
|
| 70 |
"""
|
| 71 |
html_parts = [
|
| 72 |
f"## 🔍 {stage_name}\n\n",
|
| 73 |
f"""
|
| 74 |
+
<div style="display: flex; gap: 16px; margin-bottom: 28px;">
|
| 75 |
+
<div class="metric-box" style="flex: 1;">
|
| 76 |
+
<div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">TOP-1 SCORE</div>
|
| 77 |
+
<div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['top1_score']:.3f}</div>
|
| 78 |
+
<div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Best result</div>
|
| 79 |
+
</div>
|
| 80 |
<div class="metric-box" style="flex: 1;">
|
| 81 |
+
<div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">TOP-5 AVG</div>
|
| 82 |
+
<div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['top5_avg']:.3f}</div>
|
| 83 |
+
<div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Overall quality</div>
|
| 84 |
</div>
|
| 85 |
<div class="metric-box" style="flex: 1;">
|
| 86 |
+
<div style="color: #6720FF; font-size: 0.85em; font-weight: 600; margin-bottom: 6px; letter-spacing: 0.5px;">LATENCY</div>
|
| 87 |
+
<div style="font-size: 2em; font-weight: 700; color: #1E293B;">{metrics['latency_ms']}<span style="font-size: 0.45em; color: #64748B; font-weight: 400;">ms</span></div>
|
| 88 |
+
<div style="color: #64748B; font-size: 0.75em; margin-top: 4px;">Response time</div>
|
| 89 |
</div>
|
| 90 |
</div>
|
| 91 |
""",
|
|
|
|
| 122 |
return sum(r["score"] for r in results) / len(results) if results else 0
|
| 123 |
|
| 124 |
|
| 125 |
+
def get_weighted_score(results: List[Dict]) -> float:
|
| 126 |
+
"""
|
| 127 |
+
Calculate position-weighted average score.
|
| 128 |
+
|
| 129 |
+
Top positions get higher weight (5x for #1, 4x for #2, etc.)
|
| 130 |
+
This rewards ranking quality - putting best results at the top.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
results: List of search results with 'score' field
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
Weighted average score (0-1 scale)
|
| 137 |
+
"""
|
| 138 |
+
if not results:
|
| 139 |
+
return 0.0
|
| 140 |
+
|
| 141 |
+
weights = [5, 4, 3, 2, 1]
|
| 142 |
+
total_weight = sum(weights)
|
| 143 |
+
|
| 144 |
+
weighted_sum = sum((weights[i] * r["score"]) for i, r in enumerate(results[:5]))
|
| 145 |
+
|
| 146 |
+
return weighted_sum / total_weight
|
| 147 |
+
|
| 148 |
+
|
| 149 |
def search_stage_1(query: str) -> Tuple[str, Dict]:
|
| 150 |
"""Stage 1: Baseline BM25 keyword search."""
|
| 151 |
results, latency = run_search_function_and_time(query, search_bm25)
|
| 152 |
+
|
| 153 |
+
top1_score = results[0]["score"] / 5.0 if results else 0.0 # Normalize BM25 scores
|
| 154 |
+
top5_avg = get_average_score(results) / 5.0 if results else 0.0
|
| 155 |
|
| 156 |
metrics = {
|
| 157 |
+
"top1_score": min(1.0, top1_score),
|
| 158 |
+
"top5_avg": min(1.0, top5_avg),
|
| 159 |
"latency_ms": latency,
|
| 160 |
}
|
| 161 |
print(f"Searched BM25 for {query} in {latency}ms")
|
|
|
|
| 166 |
def search_stage_2(query: str) -> Tuple[str, Dict]:
|
| 167 |
"""Stage 2: Vector Embeddings using FAISS."""
|
| 168 |
results, latency = run_search_function_and_time(query, search_vector)
|
| 169 |
+
|
| 170 |
+
top1_score = results[0]["score"] if results else 0.0
|
| 171 |
+
top5_avg = get_average_score(results)
|
| 172 |
|
| 173 |
metrics = {
|
| 174 |
+
"top1_score": top1_score,
|
| 175 |
+
"top5_avg": top5_avg,
|
| 176 |
"latency_ms": latency,
|
| 177 |
}
|
| 178 |
print(f"Searched vector embeddings for '{query}' in {latency}ms")
|
|
|
|
| 182 |
|
| 183 |
def search_stage_3(query: str) -> Tuple[str, Dict]:
|
| 184 |
"""Stage 3: Query Expansion + Vector Embeddings."""
|
|
|
|
| 185 |
results, latency = run_search_function_and_time(query, search_vector_with_expansion)
|
| 186 |
+
|
| 187 |
+
top1_score = results[0]["score"] if results else 0.0
|
| 188 |
+
top5_avg = get_average_score(results)
|
| 189 |
|
| 190 |
metrics = {
|
| 191 |
+
"top1_score": top1_score,
|
| 192 |
+
"top5_avg": top5_avg,
|
| 193 |
"latency_ms": latency,
|
| 194 |
}
|
| 195 |
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
def search_stage_4(query: str) -> Tuple[str, Dict]:
|
| 200 |
+
"""Stage 4: Query Expansion + Vector Embeddings + Reranking."""
|
| 201 |
+
results, latency = run_search_function_and_time(query, search_vector_with_reranking)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
top1_score = results[0]["score"] if results else 0.0
|
| 204 |
+
top5_avg = get_average_score(results)
|
| 205 |
|
| 206 |
metrics = {
|
| 207 |
+
"top1_score": top1_score,
|
| 208 |
+
"top5_avg": top5_avg,
|
| 209 |
+
"latency_ms": latency,
|
| 210 |
}
|
| 211 |
|
| 212 |
+
return format_results(results, "Stage 4: Reranking", metrics), metrics
|
| 213 |
|
| 214 |
|
| 215 |
def search_all_stages(query: str) -> Tuple[str, str, str, str, str]:
|
|
|
|
| 239 |
|
| 240 |
# Build markdown table
|
| 241 |
html = "## Stage-by-Stage Comparison\n\n"
|
| 242 |
+
html += "| Stage | Top-1 Score | Top-5 Avg | Latency (ms) |\n"
|
| 243 |
+
html += "|-------|-------------|-----------|-------------|\n"
|
| 244 |
|
| 245 |
for name, metrics in zip(stage_names, all_metrics):
|
| 246 |
+
html += f"| **{name}** | {metrics['top1_score']:.3f} | {metrics['top5_avg']:.3f} | {metrics['latency_ms']} |\n"
|
| 247 |
|
| 248 |
+
# Calculate improvements based on top-5 average
|
| 249 |
+
top5_improvement = (
|
| 250 |
(
|
| 251 |
+
(all_metrics[3]["top5_avg"] - all_metrics[0]["top5_avg"])
|
| 252 |
+
/ all_metrics[0]["top5_avg"]
|
| 253 |
* 100
|
| 254 |
)
|
| 255 |
+
if all_metrics[0]["top5_avg"] > 0
|
| 256 |
+
else 0
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
top1_improvement = (
|
| 260 |
+
(
|
| 261 |
+
(all_metrics[3]["top1_score"] - all_metrics[0]["top1_score"])
|
| 262 |
+
/ all_metrics[0]["top1_score"]
|
| 263 |
+
* 100
|
| 264 |
+
)
|
| 265 |
+
if all_metrics[0]["top1_score"] > 0
|
| 266 |
else 0
|
| 267 |
)
|
| 268 |
|
| 269 |
html += "\n---\n\n"
|
| 270 |
html += "## Key Insights\n\n"
|
| 271 |
+
html += f"- **Top-1 Score** improves by **{top1_improvement:.0f}%** from baseline to final stage\n"
|
| 272 |
+
html += f"- **Top-5 Average** improves by **{top5_improvement:.0f}%** from baseline to final stage\n"
|
| 273 |
html += f"- **Latency** stays under **{max(m['latency_ms'] for m in all_metrics)}ms** maintaining fast performance\n"
|
| 274 |
html += "- Each stage progressively enhances search relevance while keeping response times low\n"
|
| 275 |
html += "- Vector embeddings provide the biggest jump in semantic understanding\n"
|
|
|
|
| 403 |
scale=3,
|
| 404 |
elem_classes="search-box",
|
| 405 |
)
|
| 406 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
with gr.Row():
|
| 408 |
gr.Markdown(
|
| 409 |
"**Try Example Queries:** Select a category and specificity level to auto-load an example"
|
src/config.py
CHANGED
|
@@ -3,7 +3,7 @@ import gradio as gr
|
|
| 3 |
# Fireworks AI Model Configuration
|
| 4 |
EMBEDDING_MODEL = "accounts/fireworks/models/qwen3-embedding-8b"
|
| 5 |
LLM_MODEL = "accounts/fireworks/models/qwen3-8b"
|
| 6 |
-
RERANKER_MODEL = "fireworks/qwen3-reranker-8b"
|
| 7 |
|
| 8 |
GRADIO_THEME = gr.themes.Base(
|
| 9 |
primary_hue=gr.themes.colors.purple,
|
|
|
|
| 3 |
# Fireworks AI Model Configuration
|
| 4 |
EMBEDDING_MODEL = "accounts/fireworks/models/qwen3-embedding-8b"
|
| 5 |
LLM_MODEL = "accounts/fireworks/models/qwen3-8b"
|
| 6 |
+
RERANKER_MODEL = "accounts/fireworks/models/qwen3-reranker-8b"
|
| 7 |
|
| 8 |
GRADIO_THEME = gr.themes.Base(
|
| 9 |
primary_hue=gr.themes.colors.purple,
|
src/fireworks/inference.py
CHANGED
|
@@ -2,14 +2,18 @@ import os
|
|
| 2 |
import yaml
|
| 3 |
from openai import OpenAI
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
-
from typing import List
|
| 6 |
from pathlib import Path
|
| 7 |
-
|
|
|
|
| 8 |
|
| 9 |
load_dotenv()
|
| 10 |
|
| 11 |
_FILE_PATH = Path(__file__).parents[2]
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def load_prompt_library():
|
| 15 |
"""Load prompts from YAML configuration."""
|
|
@@ -17,15 +21,15 @@ def load_prompt_library():
|
|
| 17 |
return yaml.safe_load(f)
|
| 18 |
|
| 19 |
|
| 20 |
-
def create_client(
|
| 21 |
"""
|
| 22 |
Create client for FW inference
|
| 23 |
"""
|
| 24 |
-
api_key = os.getenv("FIREWORKS_API_KEY"
|
| 25 |
assert api_key is not None, "FIREWORKS_API_KEY not found in environment variables"
|
| 26 |
return OpenAI(
|
| 27 |
api_key=api_key,
|
| 28 |
-
base_url=
|
| 29 |
)
|
| 30 |
|
| 31 |
|
|
@@ -75,3 +79,47 @@ def expand_query(query: str) -> str:
|
|
| 75 |
|
| 76 |
expanded = response.choices[0].message.content.strip()
|
| 77 |
return expanded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import yaml
|
| 3 |
from openai import OpenAI
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
+
from typing import List, Dict
|
| 6 |
from pathlib import Path
|
| 7 |
+
import requests
|
| 8 |
+
from src.config import EMBEDDING_MODEL, LLM_MODEL, RERANKER_MODEL
|
| 9 |
|
| 10 |
load_dotenv()
|
| 11 |
|
| 12 |
_FILE_PATH = Path(__file__).parents[2]
|
| 13 |
|
| 14 |
+
RERANK_URL = "https://api.fireworks.ai/inference/v1/rerank"
|
| 15 |
+
INFERENCE_URL = "https://api.fireworks.ai/inference/v1"
|
| 16 |
+
|
| 17 |
|
| 18 |
def load_prompt_library():
|
| 19 |
"""Load prompts from YAML configuration."""
|
|
|
|
| 21 |
return yaml.safe_load(f)
|
| 22 |
|
| 23 |
|
| 24 |
+
def create_client() -> OpenAI:
|
| 25 |
"""
|
| 26 |
Create client for FW inference
|
| 27 |
"""
|
| 28 |
+
api_key = os.getenv("FIREWORKS_API_KEY")
|
| 29 |
assert api_key is not None, "FIREWORKS_API_KEY not found in environment variables"
|
| 30 |
return OpenAI(
|
| 31 |
api_key=api_key,
|
| 32 |
+
base_url=INFERENCE_URL,
|
| 33 |
)
|
| 34 |
|
| 35 |
|
|
|
|
| 79 |
|
| 80 |
expanded = response.choices[0].message.content.strip()
|
| 81 |
return expanded
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def rerank_results(query: str, results: List[Dict], top_n: int = 5) -> List[Dict]:
|
| 85 |
+
"""
|
| 86 |
+
Rerank search results using Fireworks AI reranker model.
|
| 87 |
+
|
| 88 |
+
Takes search results and reranks them based on relevance to the query
|
| 89 |
+
using a specialized reranking model that considers cross-attention between
|
| 90 |
+
query and documents.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
query: Original search query
|
| 94 |
+
results: List of dictionaries containing product information and scores
|
| 95 |
+
top_n: Number of top results to return after reranking (default: 5)
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
List of dictionaries containing reranked product information with updated scores
|
| 99 |
+
"""
|
| 100 |
+
# Prepare documents as text for reranker (product name + description)
|
| 101 |
+
documents = [f"{r['product_name']}. {r['description']}" for r in results]
|
| 102 |
+
|
| 103 |
+
payload = {
|
| 104 |
+
"model": RERANKER_MODEL,
|
| 105 |
+
"query": query,
|
| 106 |
+
"documents": documents,
|
| 107 |
+
"top_n": top_n,
|
| 108 |
+
"return_documents": False,
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
headers = {
|
| 112 |
+
"Authorization": f"Bearer {os.getenv('FIREWORKS_API_KEY')}",
|
| 113 |
+
"Content-Type": "application/json",
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
response = requests.post(RERANK_URL, json=payload, headers=headers)
|
| 117 |
+
rerank_data = response.json()
|
| 118 |
+
|
| 119 |
+
# Map reranked results back to original product data
|
| 120 |
+
reranked_results = []
|
| 121 |
+
for item in rerank_data.get("data", []):
|
| 122 |
+
idx = item["index"]
|
| 123 |
+
reranked_results.append({**results[idx], "score": item["relevance_score"]})
|
| 124 |
+
|
| 125 |
+
return reranked_results
|
src/search/vector_search.py
CHANGED
|
@@ -2,13 +2,13 @@ import numpy as np
|
|
| 2 |
import faiss
|
| 3 |
from typing import List, Dict
|
| 4 |
from pathlib import Path
|
| 5 |
-
from src.fireworks.inference import get_embedding, expand_query
|
| 6 |
from constants.constants import FAISS_INDEX, PRODUCTS_DF
|
| 7 |
|
| 8 |
_FILE_PATH = Path(__file__).parents[2]
|
| 9 |
|
| 10 |
|
| 11 |
-
def search_vector(query: str, top_k: int =
|
| 12 |
"""
|
| 13 |
Search products using vector embeddings and FAISS for semantic search.
|
| 14 |
|
|
@@ -46,7 +46,7 @@ def search_vector(query: str, top_k: int = 10) -> List[Dict[str, any]]:
|
|
| 46 |
]
|
| 47 |
|
| 48 |
|
| 49 |
-
def search_vector_with_expansion(query: str, top_k: int =
|
| 50 |
"""
|
| 51 |
Search products using vector embeddings and FAISS for semantic search with query expansion.
|
| 52 |
|
|
@@ -64,3 +64,27 @@ def search_vector_with_expansion(query: str, top_k: int = 10) -> List[Dict[str,
|
|
| 64 |
print(f"Original: {query}")
|
| 65 |
print(f"Expanded: {expanded_query}")
|
| 66 |
return search_vector(expanded_query, top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import faiss
|
| 3 |
from typing import List, Dict
|
| 4 |
from pathlib import Path
|
| 5 |
+
from src.fireworks.inference import get_embedding, expand_query, rerank_results
|
| 6 |
from constants.constants import FAISS_INDEX, PRODUCTS_DF
|
| 7 |
|
| 8 |
_FILE_PATH = Path(__file__).parents[2]
|
| 9 |
|
| 10 |
|
| 11 |
+
def search_vector(query: str, top_k: int = 5) -> List[Dict[str, any]]:
|
| 12 |
"""
|
| 13 |
Search products using vector embeddings and FAISS for semantic search.
|
| 14 |
|
|
|
|
| 46 |
]
|
| 47 |
|
| 48 |
|
| 49 |
+
def search_vector_with_expansion(query: str, top_k: int = 5) -> List[Dict[str, any]]:
|
| 50 |
"""
|
| 51 |
Search products using vector embeddings and FAISS for semantic search with query expansion.
|
| 52 |
|
|
|
|
| 64 |
print(f"Original: {query}")
|
| 65 |
print(f"Expanded: {expanded_query}")
|
| 66 |
return search_vector(expanded_query, top_k)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def search_vector_with_reranking(query: str, top_k: int = 5) -> List[Dict[str, any]]:
|
| 70 |
+
"""
|
| 71 |
+
Search products using vector embeddings and FAISS for semantic search with reranking.
|
| 72 |
+
|
| 73 |
+
This is Stage 4: semantic search using vector embeddings to understand
|
| 74 |
+
query meaning and intent beyond exact keyword matching, with reranking.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
query: Search query string
|
| 78 |
+
top_k: Number of top results to return (default: 10)
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
List of dictionaries containing product information with preserved cosine scores
|
| 82 |
+
"""
|
| 83 |
+
results = search_vector_with_expansion(query, top_k)
|
| 84 |
+
cosine_scores = {r["product_name"]: r["score"] for r in results}
|
| 85 |
+
reranked_results = rerank_results(query=query, results=results)
|
| 86 |
+
|
| 87 |
+
for r in reranked_results:
|
| 88 |
+
r["score"] = cosine_scores[r["product_name"]]
|
| 89 |
+
|
| 90 |
+
return reranked_results
|