Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
# Add project root to path.
# The directory containing this file is added exactly once, as an absolute
# path, and only when it is not already present, so repeated imports of this
# module do not keep growing sys.path with duplicate entries.
_PROJECT_ROOT = str(Path(__file__).resolve().parent)
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
# ============================================================================
# RAG SYSTEM INITIALIZATION WITH PROPER ERROR HANDLING
# ============================================================================
def _ensure_faiss_index(faiss_path):
    """Return True if the FAISS index exists or a build step succeeded, else False."""
    if faiss_path.exists():
        print(f"✅ FAISS index found at {faiss_path}")
        return True

    print(f"⚠ WARNING: FAISS index not found at {faiss_path}")
    print(" Creating new FAISS index...")
    try:
        # Preferred route: call the project's initializer in-process.
        from scripts.initialize_rag import initialize_rag
        initialize_rag()
        print("✅ FAISS index created successfully")
    except ImportError as e:
        print(f"⚠️ Import error: {e}")
        print(" Running initialization script directly...")
        # Fallback: run the same script in a child interpreter.
        import subprocess
        try:
            result = subprocess.run(
                [sys.executable, "scripts/initialize_rag.py"],
                capture_output=True,
                text=True,
                cwd="/app",
            )
        except OSError as run_err:
            # e.g. the working directory does not exist; report it instead of
            # letting the error escape from inside the ImportError handler.
            print(f"⚠️ Failed to create FAISS index: {run_err}")
            return False
        if result.returncode != 0:
            print(f"⚠️ Failed to create FAISS index: {result.stderr}")
            return False
        print("✅ FAISS index created via subprocess")
    except Exception as e:
        print(f"⚠️ Initialization error: {e}")
        return False
    return True


def _ensure_embedding_cache(cache_path):
    """Create the SQLite embedding cache at *cache_path* if it is missing (best effort)."""
    if cache_path.exists():
        print(f"✅ Embedding cache found at {cache_path}")
        return

    print(f"⚠ WARNING: Embedding cache not found at {cache_path}")
    print(" It will be created automatically on first use.")

    import sqlite3
    from contextlib import closing

    try:
        # closing() guarantees the connection is released even when one of the
        # statements below raises; sqlite3's own context manager only manages
        # the transaction and does not close the connection.
        with closing(sqlite3.connect(cache_path)) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS embedding_cache (
                    text_hash TEXT PRIMARY KEY,
                    embedding BLOB NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    access_count INTEGER DEFAULT 0
                )
            """)
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)"
            )
            conn.commit()
        print("✅ Embedding cache created")
    except Exception as e:
        # Deliberately non-fatal: a missing cache is reported, not raised.
        print(f"⚠️ Could not create embedding cache: {e}")


def initialize_rag_system(data_dir="/app/data"):
    """Initialize FAISS index and embedding cache with proper error handling.

    Args:
        data_dir: Directory expected to hold ``faiss_index.bin`` and
            ``embedding_cache.db``. It is created when missing. Defaults to
            ``/app/data``.
            NOTE(review): the index build step (``initialize_rag`` or the
            ``scripts/initialize_rag.py`` subprocess) is not passed this
            directory; confirm where it writes before using a custom value.

    Returns:
        bool: ``False`` when the FAISS index was missing and the build step
        failed. ``True`` otherwise; a failure to create the embedding cache is
        only printed and does not change the result.
    """
    print("🔧 Initializing FAISS index and cache...")

    # Create data directory if it doesn't exist
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    if not _ensure_faiss_index(data_dir / "faiss_index.bin"):
        return False
    _ensure_embedding_cache(data_dir / "embedding_cache.db")

    print("✅ Configuration validated successfully")
    return True
# Run initialization
# Start from "not initialized" so init_success is always defined, including
# when the initializer itself raises and the except branch below is taken.
init_success = False
try:
    init_success = initialize_rag_system()
    if not init_success:
        print("⚠️ RAG system initialization had issues, but continuing anyway...")
except Exception as e:
    # Startup is best effort: report the problem and let the UI come up anyway.
    print(f"⚠️ Unexpected initialization error: {e}")
    print(" Continuing with limited functionality...")
# ============================================================================
# GRADIO APP IMPORTS AND SETUP
# ============================================================================
| import gradio as gr | |
# Module-level slots for the lazily created objects. Each one stays None until
# the matching get_*() accessor has successfully built its object.
_naive_rag = _optimized_rag = _no_compromise_rag = _embedding_model = None
def get_embedding_model():
    """Return the shared embedding model, loading it on the first call.

    The model is stored in the module-level ``_embedding_model`` so later
    calls reuse it. If loading fails for any reason the error is printed and
    ``None`` is returned; the next call will try again.
    """
    global _embedding_model

    # Fast path: a previous call already loaded the model.
    if _embedding_model is not None:
        return _embedding_model

    try:
        # Imported lazily so the dependency is only needed when a model is requested.
        from sentence_transformers import SentenceTransformer
        print("Loading embedding model...")
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded successfully")
    except Exception as load_err:
        print(f"⚠️ Error loading embedding model: {load_err}")
        _embedding_model = None

    return _embedding_model
def get_naive():
    """Return the shared NaiveRAG instance, creating it on first use.

    Returns None (after printing the reason) when ``app.rag_naive`` cannot be
    imported or when constructing ``NaiveRAG`` raises.
    """
    global _naive_rag

    # Already built by an earlier call: hand back the same object.
    if _naive_rag is not None:
        return _naive_rag

    try:
        from app.rag_naive import NaiveRAG
        print("Initializing Naive RAG...")
        _naive_rag = NaiveRAG()
        print("✅ Naive RAG initialized")
    except ImportError as import_err:
        print(f"⚠️ Could not import NaiveRAG: {import_err}")
        return None
    except Exception as init_err:
        print(f"⚠️ Error initializing Naive RAG: {init_err}")
        return None

    return _naive_rag
def get_optimized():
    """Return the shared OptimizedRAG instance, creating it on first use.

    Returns None (after printing the reason) when ``app.rag_optimized`` cannot
    be imported or when constructing ``OptimizedRAG`` raises.
    """
    global _optimized_rag

    # Already built by an earlier call: hand back the same object.
    if _optimized_rag is not None:
        return _optimized_rag

    try:
        from app.rag_optimized import OptimizedRAG
        print("Initializing Optimized RAG...")
        _optimized_rag = OptimizedRAG()
        print("✅ Optimized RAG initialized")
    except ImportError as import_err:
        print(f"⚠️ Could not import OptimizedRAG: {import_err}")
        return None
    except Exception as init_err:
        print(f"⚠️ Error initializing Optimized RAG: {init_err}")
        return None

    return _optimized_rag
def get_no_compromise():
    """Return the shared NoCompromiseRAG instance, creating it on first use.

    Returns None (after printing the reason) when ``app.no_compromise_rag``
    cannot be imported or when constructing ``NoCompromiseRAG`` raises.
    """
    global _no_compromise_rag

    # Already built by an earlier call: hand back the same object.
    if _no_compromise_rag is not None:
        return _no_compromise_rag

    try:
        from app.no_compromise_rag import NoCompromiseRAG
        print("Initializing No-Compromise RAG...")
        _no_compromise_rag = NoCompromiseRAG()
        print("✅ No-Compromise RAG initialized")
    except ImportError as import_err:
        print(f"⚠️ Could not import NoCompromiseRAG: {import_err}")
        return None
    except Exception as init_err:
        print(f"⚠️ Error initializing No-Compromise RAG: {init_err}")
        return None

    return _no_compromise_rag
def query_naive(question):
    """Answer *question* with the Naive RAG pipeline for the Gradio UI.

    Returns a 4-tuple of display strings: (answer, latency, chunks used,
    cache hit). Empty or whitespace-only input, an unavailable pipeline, and
    any exception raised while querying all produce a message in the first
    slot together with "0 ms", "0", "No".
    """
    # Nothing to ask: skip the pipeline entirely.
    if not question or not question.strip():
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_naive()
        if pipeline is None:
            return "RAG system not available. Check logs.", "0 ms", "0", "No"

        # Only the query call itself is timed, not the lazy pipeline construction.
        started = time.perf_counter()
        answer, chunks_used, cache_hit = pipeline.query(question)
        elapsed_ms = (time.perf_counter() - started) * 1000

        hit_label = "Yes" if cache_hit else "No"
        return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
    except Exception as exc:
        failure = f"Error in Naive RAG: {exc!s}"
        print(failure)
        return failure, "0 ms", "0", "No"
def query_optimized(question):
    """Answer *question* with the Optimized RAG pipeline for the Gradio UI.

    Returns a 4-tuple of display strings: (answer, latency, chunks used,
    cache hit). Empty or whitespace-only input, an unavailable pipeline, and
    any exception raised while querying all produce a message in the first
    slot together with "0 ms", "0", "No".
    """
    # Nothing to ask: skip the pipeline entirely.
    if not question or not question.strip():
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_optimized()
        if pipeline is None:
            return "RAG system not available. Check logs.", "0 ms", "0", "No"

        # Only the query call itself is timed, not the lazy pipeline construction.
        started = time.perf_counter()
        answer, chunks_used, cache_hit = pipeline.query(question)
        elapsed_ms = (time.perf_counter() - started) * 1000

        hit_label = "Yes" if cache_hit else "No"
        return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
    except Exception as exc:
        failure = f"Error in Optimized RAG: {exc!s}"
        print(failure)
        return failure, "0 ms", "0", "No"
def query_no_compromise(question):
    """Answer *question* with the No-Compromise RAG pipeline for the Gradio UI.

    Returns a 4-tuple of display strings: (answer, latency, chunks used,
    cache hit). Empty or whitespace-only input, an unavailable pipeline, and
    any exception raised while querying all produce a message in the first
    slot together with "0 ms", "0", "No".
    """
    # Nothing to ask: skip the pipeline entirely.
    if not question or not question.strip():
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_no_compromise()
        if pipeline is None:
            return "RAG system not available. Check logs.", "0 ms", "0", "No"

        # Only the query call itself is timed, not the lazy pipeline construction.
        started = time.perf_counter()
        answer, chunks_used, cache_hit = pipeline.query(question)
        elapsed_ms = (time.perf_counter() - started) * 1000

        hit_label = "Yes" if cache_hit else "No"
        return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
    except Exception as exc:
        failure = f"Error in No-Compromise RAG: {exc!s}"
        print(failure)
        return failure, "0 ms", "0", "No"
# ============================================================================
# BUILD THE GRADIO INTERFACE
# ============================================================================
# Builds the whole UI as `demo`: a header, three tabs (one per RAG variant),
# and a footer. Each tab has a question box with an "Ask" button, a read-only
# answer box, and three read-only metric boxes filled from the 4-tuple that the
# tab's query_* handler returns.
# NOTE(review): the nesting below is reconstructed from a listing that had lost
# its indentation — confirm the row/tab layout against the running app.
# The Markdown string contents are kept flush-left, exactly as listed.
with gr.Blocks(title="RAG Latency Optimization", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ⚡ RAG Latency Optimization
### Compare Naive, Optimized, and No‑Compromise RAG on CPU‑only hardware
**Proven 2.7× speedup (247ms → 92ms)** – now interactive!
""")
    with gr.Tabs():
        # ----- Naive RAG tab -----
        with gr.TabItem("🐢 Naive RAG (Baseline)"):
            with gr.Row():
                question_naive = gr.Textbox(label="Your Question", lines=2, placeholder="e.g., What is RAG?")
                submit_naive = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_naive = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_naive = gr.Textbox(label="Latency", interactive=False)
                chunks_naive = gr.Textbox(label="Chunks Used", interactive=False)
                cache_naive = gr.Textbox(label="Cache Hit", interactive=False)
            # Button click and Enter in the question box both run the same handler.
            submit_naive.click(
                query_naive,
                inputs=question_naive,
                outputs=[answer_naive, latency_naive, chunks_naive, cache_naive]
            )
            question_naive.submit(
                query_naive,
                inputs=question_naive,
                outputs=[answer_naive, latency_naive, chunks_naive, cache_naive]
            )
        # ----- Optimized RAG tab -----
        with gr.TabItem("⚡ Optimized RAG (Production)"):
            with gr.Row():
                question_opt = gr.Textbox(label="Your Question", lines=2, placeholder="e.g., What is RAG?")
                submit_opt = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_opt = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_opt = gr.Textbox(label="Latency", interactive=False)
                chunks_opt = gr.Textbox(label="Chunks Used", interactive=False)
                cache_opt = gr.Textbox(label="Cache Hit", interactive=False)
            # Button click and Enter in the question box both run the same handler.
            submit_opt.click(
                query_optimized,
                inputs=question_opt,
                outputs=[answer_opt, latency_opt, chunks_opt, cache_opt]
            )
            question_opt.submit(
                query_optimized,
                inputs=question_opt,
                outputs=[answer_opt, latency_opt, chunks_opt, cache_opt]
            )
        # ----- No‑Compromise RAG tab -----
        with gr.TabItem("🚀 No‑Compromise RAG (Max Speed)"):
            with gr.Row():
                question_nc = gr.Textbox(label="Your Question", lines=2, placeholder="e.g., What is RAG?")
                submit_nc = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_nc = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_nc = gr.Textbox(label="Latency", interactive=False)
                chunks_nc = gr.Textbox(label="Chunks Used", interactive=False)
                cache_nc = gr.Textbox(label="Cache Hit", interactive=False)
            # Button click and Enter in the question box both run the same handler.
            submit_nc.click(
                query_no_compromise,
                inputs=question_nc,
                outputs=[answer_nc, latency_nc, chunks_nc, cache_nc]
            )
            question_nc.submit(
                query_no_compromise,
                inputs=question_nc,
                outputs=[answer_nc, latency_nc, chunks_nc, cache_nc]
            )
    # Footer shown below the tabs.
    gr.Markdown("""
---
**Architecture**: CPU‑only | **Embeddings**: `all-MiniLM-L6-v2` | **Vector Store**: FAISS
**Caching**: SQLite (Optimized) + LRU memory | **Generation**: Simulated (real LLM can be plugged in)
💡 **Tip**: Press Enter to submit your question quickly!
""")
# ============================================================================
# LAUNCH THE APP
# ============================================================================
if __name__ == "__main__":
    # Name the bind address and port once so the printed URL and the actual
    # launch arguments cannot drift apart. 0.0.0.0 listens on all interfaces.
    host, port = "0.0.0.0", 7860
    print("🚀 Starting RAG Latency Optimization App...")
    print(f"📍 Server will run on http://{host}:{port}")
    demo.launch(server_name=host, server_port=port)