# semantic_search/app.py
import os
from typing import Any, List, Tuple

import chromadb
import gradio as gr
from chromadb.config import Settings
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Config
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_NAME = "abisee/cnn_dailymail"
DATASET_CONFIG = "3.0.0"
DATASET_SLICE = "train[:100]" # keep small for quick startup
CHUNK_MAX_WORDS = 500
CHUNK_OVERLAP = 50
CHROMA_DIR = "chromadb"
COLLECTION_NAME = "cnn_dm_chunks"
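
# Note: all-MiniLM-L6-v2 truncates input at its max_seq_length (256 tokens by
# default), so of a 500-word chunk only roughly the first 200 words contribute
# to its embedding; word-level chunking here is a simplicity trade-off, not a
# model requirement.
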
def chunk_text(text: str, max_words: int = CHUNK_MAX_WORDS, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping word chunks."""
    words = text.split()
    chunks: List[str] = []
    step = max_words - overlap  # assumes overlap < max_words, else range() below gets a step <= 0
    n = len(words)
    for start in range(0, n, step):
        end = start + max_words
        chunks.append(" ".join(words[start:end]))
        if end >= n:
            break  # final chunk reached the end; skip a trailing overlap-only duplicate
    return chunks
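
# Worked example: with max_words=500 and overlap=50 the window advances by 450
# words, so a 1,000-word article yields chunks over words [0:500], [450:950],
# and [900:1000], each sharing 50 words with its predecessor so sentences
# spanning a boundary remain searchable.
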
def load_articles() -> List[str]:
    ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=DATASET_SLICE)
    return ds["article"]
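
# cnn_dailymail rows also carry "highlights" and "id" columns; only the
# article bodies are indexed here.
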
def build_chunks(articles: List[str]) -> Tuple[List[str], List[str]]:
    texts: List[str] = []
    ids: List[str] = []
    for i, article in enumerate(articles):
        for j, chunk in enumerate(chunk_text(article)):
            ids.append(f"{i}-{j}")  # stable id: "<article index>-<chunk index>"
            texts.append(chunk)
    return ids, texts
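
# Chroma requires ids to be unique strings within a collection, which the
# "<article>-<chunk>" scheme guarantees for a fixed dataset slice.
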
def ensure_collection() -> Tuple[Any, SentenceTransformer]:
    """Get or build the Chroma collection, persisting to disk."""
    client = chromadb.PersistentClient(path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False))
    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    model = SentenceTransformer(MODEL_NAME)
    if collection.count() == 0:  # empty collection means a fresh start: index the corpus once
        articles = load_articles()
        ids, texts = build_chunks(articles)
        embeddings = model.encode(texts, show_progress_bar=True, batch_size=32)
        collection.add(documents=texts, embeddings=embeddings.tolist(), ids=ids)
    return collection, model
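
# Note: Chroma collections default to L2 distance. This model ships with a
# Normalize layer, so embeddings are unit-length and L2 ranking coincides with
# cosine ranking; to request cosine explicitly you could pass
# metadata={"hnsw:space": "cosine"} to get_or_create_collection.
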
collection, model = ensure_collection()
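# The call above runs at import time, so indexing happens during startup rather
# than on the first request; the persisted CHROMA_DIR is only reused across
# restarts if the filesystem survives (on Spaces that takes persistent storage).
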
def semantic_search(query: str, k: int = 3) -> str:
    query_embedding = model.encode([query])
    results = collection.query(query_embeddings=query_embedding.tolist(), n_results=k)
    docs = results["documents"][0]  # documents for the first (and only) query
    return "\n\n---\n".join(docs)
demo = gr.Interface(
    fn=semantic_search,
    inputs=gr.Textbox(label="Enter your search query"),
    outputs=gr.Textbox(label="Top Matches"),
    title="Semantic Search Engine",
    description="Search over CNN/DailyMail chunks using semantic similarity.",
)
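
# Gradio passes only the textbox value, so k keeps its default of 3. A sketch
# of exposing it (hypothetical, not part of the original UI):
#     inputs=[gr.Textbox(label="Enter your search query"),
#             gr.Slider(1, 10, value=3, step=1, label="Results (k)")]
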
if __name__ == "__main__":
    # 7860 is the default port Hugging Face Spaces expects; PORT overrides it.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))