import os
from typing import Any, List, Tuple

import chromadb
from chromadb.config import Settings
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import gradio as gr

# Config
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_NAME = "abisee/cnn_dailymail"
DATASET_CONFIG = "3.0.0"
DATASET_SLICE = "train[:100]"  # keep small for quick startup
CHUNK_MAX_WORDS = 500
CHUNK_OVERLAP = 50
CHROMA_DIR = "chromadb"
COLLECTION_NAME = "cnn_dm_chunks"


def chunk_text(text: str, max_words: int = CHUNK_MAX_WORDS, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping word chunks."""
    words = text.split()
    chunks = []
    step = max_words - overlap
    n = len(words)
    for start in range(0, n, step):
        end = start + max_words
        chunks.append(" ".join(words[start:end]))
    return chunks


def load_articles() -> List[str]:
    """Load a small slice of CNN/DailyMail articles."""
    ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=DATASET_SLICE)
    return ds["article"]


def build_chunks(articles: List[str]) -> Tuple[List[str], List[str]]:
    """Chunk every article and give each chunk an "articleIdx-chunkIdx" id."""
    texts: List[str] = []
    ids: List[str] = []
    for i, article in enumerate(articles):
        for j, chunk in enumerate(chunk_text(article)):
            ids.append(f"{i}-{j}")
            texts.append(chunk)
    return ids, texts


def ensure_collection() -> Tuple[Any, SentenceTransformer]:
    """Get or build the Chroma collection, persisting to disk."""
    client = chromadb.PersistentClient(path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False))
    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    model = SentenceTransformer(MODEL_NAME)
    if collection.count() == 0:
        # First run: embed all chunks and index them; later runs reuse the persisted collection.
        articles = load_articles()
        ids, texts = build_chunks(articles)
        embeddings = model.encode(texts, show_progress_bar=True, batch_size=32)
        collection.add(documents=texts, embeddings=embeddings.tolist(), ids=ids)
    return collection, model


# Build (or reuse) the index once at startup.
collection, model = ensure_collection()


def semantic_search(query: str, k: int = 3) -> str:
    """Embed the query and return the k most similar chunks, separated by rules."""
    query_embedding = model.encode([query])
    results = collection.query(query_embeddings=query_embedding.tolist(), n_results=k)
    docs = results["documents"][0]
    return "\n\n---\n".join(docs)


demo = gr.Interface(
    fn=semantic_search,
    inputs=gr.Textbox(label="Enter your search query"),
    outputs=gr.Textbox(label="Top Matches"),
    title="Semantic Search Engine",
    description="Search over CNN/DailyMail chunks using semantic similarity.",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))