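# Hugging Face file-view header (page residue captured with the file, not program code):
# author: Ariyan-Pro | commit message: "Update app.py" | commit: 7beadc0 (verified)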
import os
import sys
import time
from pathlib import Path
# Make this file's directory importable. The directory is computed two ways
# (os.path and pathlib) and both results are appended, in that order.
sys.path.extend([os.path.dirname(__file__), str(Path(__file__).parent)])
# ============================================================================
# RAG SYSTEM INITIALIZATION WITH PROPER ERROR HANDLING
# ============================================================================
def _create_embedding_cache(cache_path):
    """Create an empty SQLite embedding cache database at *cache_path*.

    Creates the ``embedding_cache`` table and its ``created_at`` index if
    they do not exist. Any ``sqlite3`` error propagates to the caller; the
    connection is closed whether or not the schema creation succeeds.
    """
    import sqlite3

    conn = sqlite3.connect(cache_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)"
        )
        conn.commit()
    finally:
        # Always release the handle, even when a statement above fails.
        conn.close()


def _build_faiss_index():
    """Try to build the FAISS index; return True on success, False otherwise.

    First imports and calls ``scripts.initialize_rag.initialize_rag``. If
    that raises ``ImportError``, falls back to running
    ``scripts/initialize_rag.py`` as a subprocess with ``/app`` as the
    working directory.
    """
    try:
        from scripts.initialize_rag import initialize_rag
        initialize_rag()
        print("✅ FAISS index created successfully")
        return True
    except ImportError as e:
        print(f"⚠️ Import error: {e}")
        print(" Running initialization script directly...")
        import subprocess
        try:
            result = subprocess.run(
                [sys.executable, "scripts/initialize_rag.py"],
                capture_output=True,
                text=True,
                cwd="/app",
            )
        except OSError as run_error:
            # e.g. the working directory or the interpreter is missing;
            # report it as a normal failure instead of escaping the function.
            print(f"⚠️ Failed to create FAISS index: {run_error}")
            return False
        if result.returncode == 0:
            print("✅ FAISS index created via subprocess")
            return True
        print(f"⚠️ Failed to create FAISS index: {result.stderr}")
        return False
    except Exception as e:
        print(f"⚠️ Initialization error: {e}")
        return False


def initialize_rag_system(data_dir="/app/data"):
    """Ensure the FAISS index and the embedding cache exist on disk.

    Args:
        data_dir: Directory that holds ``faiss_index.bin`` and
            ``embedding_cache.db``. It is created (with parents) if it
            does not exist. Defaults to ``/app/data``.

    Returns:
        ``False`` if the FAISS index was missing and could not be built;
        ``True`` otherwise. A failure to create the embedding cache is
        only logged and does not affect the return value.
    """
    print("🔧 Initializing FAISS index and cache...")

    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    faiss_path = data_dir / "faiss_index.bin"
    cache_path = data_dir / "embedding_cache.db"

    # --- FAISS index ---
    if faiss_path.exists():
        print(f"✅ FAISS index found at {faiss_path}")
    else:
        print(f"⚠ WARNING: FAISS index not found at {faiss_path}")
        print(" Creating new FAISS index...")
        if not _build_faiss_index():
            return False

    # --- Embedding cache ---
    if cache_path.exists():
        print(f"✅ Embedding cache found at {cache_path}")
    else:
        print(f"⚠ WARNING: Embedding cache not found at {cache_path}")
        print(" It will be created automatically on first use.")
        try:
            _create_embedding_cache(cache_path)
            print("✅ Embedding cache created")
        except Exception as e:
            # Best effort: a missing cache is not fatal for startup.
            print(f"⚠️ Could not create embedding cache: {e}")

    print("✅ Configuration validated successfully")
    return True
# Run initialization at import time. Problems are reported on stdout and
# never stop the rest of the module from loading.
try:
    init_success = initialize_rag_system()
except Exception as startup_error:
    print(f"⚠️ Unexpected initialization error: {startup_error}")
    print(" Continuing with limited functionality...")
else:
    if not init_success:
        print("⚠️ RAG system initialization had issues, but continuing anyway...")
# ============================================================================
# GRADIO APP IMPORTS AND SETUP
# ============================================================================
import gradio as gr
# Module-level slots for the lazily created objects. Each starts as None
# and is filled in by the matching get_*() function on first use.
_naive_rag = _optimized_rag = _no_compromise_rag = _embedding_model = None
def get_embedding_model():
    """Return the shared embedding model, loading it on the first call.

    The model is stored in the module-level ``_embedding_model`` slot.
    If loading fails for any reason the error is printed and ``None`` is
    returned; the next call will try again.
    """
    global _embedding_model

    # Already loaded: hand back the cached instance.
    if _embedding_model is not None:
        return _embedding_model

    try:
        from sentence_transformers import SentenceTransformer
        print("Loading embedding model...")
        _embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embedding model loaded successfully")
    except Exception as load_error:
        print(f"⚠️ Error loading embedding model: {load_error}")
        _embedding_model = None
    return _embedding_model
def get_naive():
    """Return the shared NaiveRAG instance, creating it on first use.

    Returns ``None`` (after printing the reason) when the class cannot be
    imported or its constructor raises.
    """
    global _naive_rag

    if _naive_rag is not None:
        return _naive_rag

    try:
        from app.rag_naive import NaiveRAG
        print("Initializing Naive RAG...")
        _naive_rag = NaiveRAG()
        print("✅ Naive RAG initialized")
    except Exception as failure:
        # Import problems get their own message; everything else is
        # reported as an initialization error.
        if isinstance(failure, ImportError):
            print(f"⚠️ Could not import NaiveRAG: {failure}")
        else:
            print(f"⚠️ Error initializing Naive RAG: {failure}")
        return None
    return _naive_rag
def get_optimized():
    """Return the shared OptimizedRAG instance, creating it on first use.

    Returns ``None`` (after printing the reason) when the class cannot be
    imported or its constructor raises.
    """
    global _optimized_rag

    if _optimized_rag is not None:
        return _optimized_rag

    try:
        from app.rag_optimized import OptimizedRAG
        print("Initializing Optimized RAG...")
        _optimized_rag = OptimizedRAG()
        print("✅ Optimized RAG initialized")
    except Exception as failure:
        # Import problems get their own message; everything else is
        # reported as an initialization error.
        if isinstance(failure, ImportError):
            print(f"⚠️ Could not import OptimizedRAG: {failure}")
        else:
            print(f"⚠️ Error initializing Optimized RAG: {failure}")
        return None
    return _optimized_rag
def get_no_compromise():
    """Return the shared NoCompromiseRAG instance, creating it on first use.

    Returns ``None`` (after printing the reason) when the class cannot be
    imported or its constructor raises.
    """
    global _no_compromise_rag

    if _no_compromise_rag is not None:
        return _no_compromise_rag

    try:
        from app.no_compromise_rag import NoCompromiseRAG
        print("Initializing No-Compromise RAG...")
        _no_compromise_rag = NoCompromiseRAG()
        print("✅ No-Compromise RAG initialized")
    except Exception as failure:
        # Import problems get their own message; everything else is
        # reported as an initialization error.
        if isinstance(failure, ImportError):
            print(f"⚠️ Could not import NoCompromiseRAG: {failure}")
        else:
            print(f"⚠️ Error initializing No-Compromise RAG: {failure}")
        return None
    return _no_compromise_rag
def query_naive(question):
    """Run *question* through the Naive RAG and format the result for the UI.

    Returns a 4-tuple of strings: (answer, latency, chunks used, cache hit).
    Empty input, an unavailable RAG, and query errors all yield a message
    in the first slot with "0 ms", "0", "No" in the others.
    """
    # Missing or whitespace-only input never reaches the pipeline.
    if not (question and question.strip()):
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_naive()
        if pipeline is not None:
            # Only the query call itself is timed.
            t0 = time.perf_counter()
            answer, chunks_used, cache_hit = pipeline.query(question)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            hit_label = "Yes" if cache_hit else "No"
            return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
        return "RAG system not available. Check logs.", "0 ms", "0", "No"
    except Exception as exc:
        error_msg = f"Error in Naive RAG: {exc!s}"
        print(error_msg)
        return error_msg, "0 ms", "0", "No"
def query_optimized(question):
    """Run *question* through the Optimized RAG and format the result for the UI.

    Returns a 4-tuple of strings: (answer, latency, chunks used, cache hit).
    Empty input, an unavailable RAG, and query errors all yield a message
    in the first slot with "0 ms", "0", "No" in the others.
    """
    # Missing or whitespace-only input never reaches the pipeline.
    if not (question and question.strip()):
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_optimized()
        if pipeline is not None:
            # Only the query call itself is timed.
            t0 = time.perf_counter()
            answer, chunks_used, cache_hit = pipeline.query(question)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            hit_label = "Yes" if cache_hit else "No"
            return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
        return "RAG system not available. Check logs.", "0 ms", "0", "No"
    except Exception as exc:
        error_msg = f"Error in Optimized RAG: {exc!s}"
        print(error_msg)
        return error_msg, "0 ms", "0", "No"
def query_no_compromise(question):
    """Run *question* through the No-Compromise RAG and format the result for the UI.

    Returns a 4-tuple of strings: (answer, latency, chunks used, cache hit).
    Empty input, an unavailable RAG, and query errors all yield a message
    in the first slot with "0 ms", "0", "No" in the others.
    """
    # Missing or whitespace-only input never reaches the pipeline.
    if not (question and question.strip()):
        return "Please enter a question.", "0 ms", "0", "No"

    try:
        pipeline = get_no_compromise()
        if pipeline is not None:
            # Only the query call itself is timed.
            t0 = time.perf_counter()
            answer, chunks_used, cache_hit = pipeline.query(question)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            hit_label = "Yes" if cache_hit else "No"
            return answer, f"{elapsed_ms:.1f} ms", str(chunks_used), hit_label
        return "RAG system not available. Check logs.", "0 ms", "0", "No"
    except Exception as exc:
        error_msg = f"Error in No-Compromise RAG: {exc!s}"
        print(error_msg)
        return error_msg, "0 ms", "0", "No"
# ============================================================================
# BUILD THE GRADIO INTERFACE
# ============================================================================
# NOTE(review): the nesting below is reconstructed from statement order
# (the captured source had no indentation) -- confirm the Row/Tab grouping
# matches the intended layout.
with gr.Blocks(title="RAG Latency Optimization", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("""
# ⚡ RAG Latency Optimization
### Compare Naive, Optimized, and No‑Compromise RAG on CPU‑only hardware
**Proven 2.7× speedup (247ms → 92ms)** – now interactive!
""")

    with gr.Tabs():
        # ----- Naive RAG tab -----
        with gr.TabItem("🐢 Naive RAG (Baseline)"):
            with gr.Row():
                question_naive = gr.Textbox(
                    label="Your Question",
                    lines=2,
                    placeholder="e.g., What is RAG?",
                )
                submit_naive = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_naive = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_naive = gr.Textbox(label="Latency", interactive=False)
                chunks_naive = gr.Textbox(label="Chunks Used", interactive=False)
                cache_naive = gr.Textbox(label="Cache Hit", interactive=False)
            # The button click and pressing Enter in the textbox run the same handler.
            for trigger in (submit_naive.click, question_naive.submit):
                trigger(
                    query_naive,
                    inputs=question_naive,
                    outputs=[answer_naive, latency_naive, chunks_naive, cache_naive],
                )

        # ----- Optimized RAG tab -----
        with gr.TabItem("⚡ Optimized RAG (Production)"):
            with gr.Row():
                question_opt = gr.Textbox(
                    label="Your Question",
                    lines=2,
                    placeholder="e.g., What is RAG?",
                )
                submit_opt = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_opt = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_opt = gr.Textbox(label="Latency", interactive=False)
                chunks_opt = gr.Textbox(label="Chunks Used", interactive=False)
                cache_opt = gr.Textbox(label="Cache Hit", interactive=False)
            # The button click and pressing Enter in the textbox run the same handler.
            for trigger in (submit_opt.click, question_opt.submit):
                trigger(
                    query_optimized,
                    inputs=question_opt,
                    outputs=[answer_opt, latency_opt, chunks_opt, cache_opt],
                )

        # ----- No‑Compromise RAG tab -----
        with gr.TabItem("🚀 No‑Compromise RAG (Max Speed)"):
            with gr.Row():
                question_nc = gr.Textbox(
                    label="Your Question",
                    lines=2,
                    placeholder="e.g., What is RAG?",
                )
                submit_nc = gr.Button("Ask", variant="primary")
            with gr.Row():
                answer_nc = gr.Textbox(label="Answer", lines=4, interactive=False)
            with gr.Row():
                latency_nc = gr.Textbox(label="Latency", interactive=False)
                chunks_nc = gr.Textbox(label="Chunks Used", interactive=False)
                cache_nc = gr.Textbox(label="Cache Hit", interactive=False)
            # The button click and pressing Enter in the textbox run the same handler.
            for trigger in (submit_nc.click, question_nc.submit):
                trigger(
                    query_no_compromise,
                    inputs=question_nc,
                    outputs=[answer_nc, latency_nc, chunks_nc, cache_nc],
                )

    # Page footer.
    gr.Markdown("""
---
**Architecture**: CPU‑only | **Embeddings**: `all-MiniLM-L6-v2` | **Vector Store**: FAISS
**Caching**: SQLite (Optimized) + LRU memory | **Generation**: Simulated (real LLM can be plugged in)
💡 **Tip**: Press Enter to submit your question quickly!
""")
# ============================================================================
# LAUNCH THE APP
# ============================================================================
if __name__ == "__main__":
    # Announce startup, then serve on all interfaces at port 7860.
    for banner_line in (
        "🚀 Starting RAG Latency Optimization App...",
        "📍 Server will run on http://0.0.0.0:7860",
    ):
        print(banner_line)
    demo.launch(server_port=7860, server_name="0.0.0.0")