import os
from typing import Any, List, Tuple

import chromadb
from chromadb.config import Settings
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import gradio as gr

# Config
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DATASET_NAME = "abisee/cnn_dailymail"
DATASET_CONFIG = "3.0.0"
DATASET_SLICE = "train[:100]"  # keep small for quick startup
CHUNK_MAX_WORDS = 500
CHUNK_OVERLAP = 50
CHROMA_DIR = "chromadb"
COLLECTION_NAME = "cnn_dm_chunks"
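
# Note: all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings. The
# train[:100] slice is just a convenience default for quick startup; widening
# it (or swapping the dataset/model) only means the first run re-embeds more
# chunks before the index is persisted.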

def chunk_text(text: str, max_words: int = CHUNK_MAX_WORDS, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping word chunks."""
    words = text.split()
    chunks: List[str] = []
    step = max_words - overlap  # each new chunk starts `overlap` words before the previous one ends
    n = len(words)
    for start in range(0, n, step):
        end = start + max_words
        chunks.append(" ".join(words[start:end]))
    return chunks
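
# Illustrative example (not executed): with max_words=500 and overlap=50 the
# step is 450, so a 1,200-word article becomes three chunks covering words
# 0-499, 450-949, and 900-1199, each sharing 50 words with its neighbour.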

def load_articles() -> List[str]:
    """Load the article texts from the CNN/DailyMail slice."""
    ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=DATASET_SLICE)
    return ds["article"]

def build_chunks(articles: List[str]) -> Tuple[List[str], List[str]]:
    """Chunk every article; IDs take the form '<article index>-<chunk index>'."""
    ids: List[str] = []
    texts: List[str] = []
    for i, article in enumerate(articles):
        for j, chunk in enumerate(chunk_text(article)):
            ids.append(f"{i}-{j}")
            texts.append(chunk)
    return ids, texts
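
# Illustrative example (not executed): two articles of ~700 words each would
# yield IDs "0-0", "0-1", "1-0", "1-1", with texts aligned index-for-index.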

def ensure_collection() -> Tuple[Any, SentenceTransformer]:
    """Get or build the Chroma collection, persisting to disk."""
    client = chromadb.PersistentClient(path=CHROMA_DIR, settings=Settings(anonymized_telemetry=False))
    collection = client.get_or_create_collection(name=COLLECTION_NAME)
    model = SentenceTransformer(MODEL_NAME)
    if collection.count() == 0:
        # First run: embed every chunk and persist it so later restarts skip this step.
        articles = load_articles()
        ids, texts = build_chunks(articles)
        embeddings = model.encode(texts, show_progress_bar=True, batch_size=32)
        collection.add(documents=texts, embeddings=embeddings.tolist(), ids=ids)
    return collection, model
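
# Build (or reopen) the index once at import time so search requests stay fast;
# because the client persists to CHROMA_DIR, only the very first start pays the
# embedding cost.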
collection, model = ensure_collection()

def semantic_search(query: str, k: int = 3) -> str:
    """Return the k chunks most similar to the query, separated by '---'."""
    query_embedding = model.encode([query])
    results = collection.query(query_embeddings=query_embedding.tolist(), n_results=k)
    docs = results["documents"][0]
    return "\n\n---\n".join(docs)
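
# Illustrative usage (not executed): semantic_search("hurricane damage in Florida", k=2)
# would return the two closest chunks joined by the "---" separator; the query
# string here is only an example, not taken from the dataset.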

demo = gr.Interface(
    fn=semantic_search,
    inputs=gr.Textbox(label="Enter your search query"),
    outputs=gr.Textbox(label="Top Matches"),
    title="Semantic Search Engine",
    description="Search over CNN/DailyMail chunks using semantic similarity.",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
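
# To run locally (a sketch; assumes the imported packages are installed, e.g.
#   pip install chromadb datasets sentence-transformers gradio):
#   python app.py
# then open http://localhost:7860 (or the port set in the PORT env var).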