import os
import streamlit as st
import numpy as np
import time
from sentence_transformers import SentenceTransformer
import datetime
import feedparser
from huggingface_hub import hf_hub_download
import faiss
import pickle
import aiohttp
import asyncio
import sqlite3

# System message that seeds every new conversation.
SYSTEM_PROMPT = "You are a helpful public health awareness chatbot."


def _new_chat_history():
    """Return a fresh message list containing only the system prompt."""
    return [{"role": "system", "content": SYSTEM_PROMPT}]


# -------------------
# Query cache (SQLite)
# -------------------
def init_cache_db():
    """Open ``query_cache.db`` (creating the ``cache`` table if missing).

    Returns:
        The open ``sqlite3.Connection``.
    """
    conn = sqlite3.connect("query_cache.db")
    c = conn.cursor()
    c.execute("""
        CREATE TABLE IF NOT EXISTS cache (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT UNIQUE,
            answer TEXT,
            embedding BLOB,
            frequency INTEGER DEFAULT 1
        )
    """)
    conn.commit()
    return conn


cache_conn = init_cache_db()


def store_in_cache(query, answer, embedding):
    """Insert or replace the cached answer for ``query``.

    The embedding is stored as raw float32 bytes because ``search_cache``
    decodes blobs with ``dtype=np.float32``; it is coerced here so a float64
    (or non-contiguous) input cannot corrupt later similarity lookups.

    Args:
        query: The user's question text (unique key of the cache).
        answer: The generated answer text.
        embedding: 1-D array-like embedding of ``query``.
    """
    blob = np.ascontiguousarray(embedding, dtype=np.float32).tobytes()
    c = cache_conn.cursor()
    # The sub-select reads the previous frequency before REPLACE drops the row.
    c.execute(
        """
        INSERT OR REPLACE INTO cache (query, answer, embedding, frequency)
        VALUES (?, ?, ?, COALESCE(
            (SELECT frequency FROM cache WHERE query=?), 0
        ) + 1)
        """,
        (query, answer, blob, query),
    )
    cache_conn.commit()


def search_cache(query, embed_model, threshold=0.85):
    """Return the cached answer most similar to ``query``, or ``None``.

    Similarity is the cosine between the query embedding and each stored
    embedding; only rows strictly above ``threshold`` qualify. On a hit the
    row's ``frequency`` is incremented so the "Most Asked Questions" list
    reflects repeated questions (a repeated query is always served from the
    cache and therefore never reaches ``store_in_cache``).

    Args:
        query: The user's question text.
        embed_model: Object whose ``encode([text], convert_to_numpy=True)``
            returns one embedding per input.
        threshold: Minimum cosine similarity for a cache hit.

    Returns:
        The cached answer string, or ``None`` when nothing is similar enough.
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True)[0]
    q_norm = np.linalg.norm(q_emb)
    if q_norm == 0:
        # A zero vector has no direction; cosine similarity is undefined.
        return None

    c = cache_conn.cursor()
    c.execute("SELECT query, answer, embedding FROM cache")

    best_sim = -1
    best_row = None
    for qry, ans, emb_blob in c.fetchall():
        if emb_blob is None:
            continue
        emb = np.frombuffer(emb_blob, dtype=np.float32).reshape(-1)
        if emb.shape != q_emb.shape:
            # Stored with a different embedding size; not comparable.
            continue
        denom = q_norm * np.linalg.norm(emb)
        if denom == 0:
            continue
        sim = np.dot(q_emb, emb) / denom
        if sim > threshold and sim > best_sim:
            best_sim = sim
            best_row = (qry, ans)

    if best_row is None:
        return None

    matched_query, cached_answer = best_row
    c.execute(
        "UPDATE cache SET frequency = frequency + 1 WHERE query = ?",
        (matched_query,),
    )
    cache_conn.commit()
    return cached_answer  # return only answer


# -------------------
# Load FAISS index + metadata
# -------------------
@st.cache_resource
def load_index():
    """Download and load the FAISS index, its metadata and the embedder.

    Returns:
        Tuple ``(index, metadata, embed_model)``.
    """
    faiss_path = hf_hub_download(
        repo_id="krishnasimha/health-chatbot-data",
        filename="health_index.faiss",
        repo_type="dataset",
    )
    pkl_path = hf_hub_download(
        repo_id="krishnasimha/health-chatbot-data",
        filename="health_metadata.pkl",
        repo_type="dataset",
    )
    index = faiss.read_index(faiss_path)
    # SECURITY NOTE(review): pickle.load executes arbitrary code embedded in
    # the file. This is only safe while the dataset repo above is trusted;
    # consider a non-executable format (JSON/Parquet) for the metadata.
    with open(pkl_path, "rb") as f:
        metadata = pickle.load(f)
    embed_model = SentenceTransformer("all-MiniLM-L6-v2")
    return index, metadata, embed_model


index, metadata, embed_model = load_index()


# -------------------
# FAISS Benchmark
# -------------------
def benchmark_faiss(n_queries=100, k=3):
    """Time ``n_queries`` FAISS searches and report the mean in the sidebar.

    Args:
        n_queries: Number of searches to time; nothing is reported if <= 0.
        k: Number of neighbours requested per search.
    """
    if n_queries <= 0:
        return
    queries = ["What is diabetes?", "How to prevent malaria?", "Symptoms of dengue?"]
    query_embs = embed_model.encode(queries, convert_to_numpy=True)
    times = []
    for _ in range(n_queries):
        q = query_embs[np.random.randint(0, len(query_embs))].reshape(1, -1)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        index.search(q, k)
        times.append(time.perf_counter() - start)
    avg_time = np.mean(times) * 1000
    st.sidebar.write(f"⚡ FAISS Benchmark: {avg_time:.2f} ms/query over {n_queries} queries")


# -------------------
# Chat session management
# -------------------
if "chats" not in st.session_state:
    st.session_state.chats = {}
if "current_chat" not in st.session_state:
    st.session_state.current_chat = "New Chat 1"
    st.session_state.chats["New Chat 1"] = _new_chat_history()

st.sidebar.header("Chat Manager")
if st.sidebar.button("➕ New Chat"):
    chat_count = len(st.session_state.chats) + 1
    new_chat_name = f"New Chat {chat_count}"
    # After renames the counter-based name may already exist; never overwrite
    # an existing conversation.
    while new_chat_name in st.session_state.chats:
        chat_count += 1
        new_chat_name = f"New Chat {chat_count}"
    st.session_state.chats[new_chat_name] = _new_chat_history()
    st.session_state.current_chat = new_chat_name

# NOTE(review): the collapsed source does not show whether this call sat
# inside the button branch; it is placed at top level (runs every rerun).
benchmark_faiss()


# -------------------
# Most Asked Questions
# -------------------
def get_top_cached_queries(limit=5):
    """Return up to ``limit`` ``(query, frequency)`` rows, most frequent first."""
    c = cache_conn.cursor()
    c.execute(
        """
        SELECT query, frequency
        FROM cache
        ORDER BY frequency DESC
        LIMIT ?
        """,
        (limit,),
    )
    return c.fetchall()


st.sidebar.subheader("🔥 Most Asked Questions")
top_qs = get_top_cached_queries()
for q, freq in top_qs:
    st.sidebar.write(f"**{q}** — used {freq} times")

# -------------------
# Chat selector
# -------------------
chat_list = list(st.session_state.chats.keys())
selected_chat = st.sidebar.selectbox(
    "Your chats:",
    chat_list,
    index=chat_list.index(st.session_state.current_chat),
    key="chat_select",
)
st.session_state.current_chat = selected_chat

new_name = st.sidebar.text_input("Rename Chat:", st.session_state.current_chat)
if new_name and new_name != st.session_state.current_chat:
    # Renaming onto an existing chat is ignored so no history is lost.
    if new_name not in st.session_state.chats:
        st.session_state.chats[new_name] = st.session_state.chats.pop(
            st.session_state.current_chat
        )
        st.session_state.current_chat = new_name

# -------------------
# RSS News Fetcher (async)
# -------------------
RSS_URL = "https://news.google.com/rss/search?q=health+disease+awareness&hl=en-IN&gl=IN&ceid=IN:en"


async def fetch_rss_url(url):
    """Fetch ``url`` and return the response body as text."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return await resp.text()


def _parse_news(raw_xml):
    """Parse RSS XML into a list of up to five article dicts."""
    feed = feedparser.parse(raw_xml)
    return [
        {
            "title": entry.get("title", ""),
            "link": entry.get("link", ""),
            # Some feed entries carry no publication date.
            "published": entry.get("published", ""),
        }
        for entry in feed.entries[:5]
    ]


def fetch_news():
    """Download the health RSS feed and return its first five articles.

    Returns:
        List of dicts with ``title``, ``link`` and ``published`` keys.
    """
    raw_xml = asyncio.run(fetch_rss_url(RSS_URL))
    return _parse_news(raw_xml)


def update_news_hourly():
    """Refresh ``st.session_state.news_articles`` if older than one hour."""
    now = datetime.datetime.now()
    last = st.session_state.get("last_news_update")
    # total_seconds() covers the whole interval; timedelta.seconds wraps
    # every 24 hours and would skip refreshes after day boundaries.
    if last is None or (now - last).total_seconds() > 3600:
        st.session_state.last_news_update = now
        st.session_state.news_articles = fetch_news()


# -------------------
# Async Together API
# -------------------
async def async_together_chat(messages):
    """Send ``messages`` to the Together chat-completions API.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The content string of the first returned choice.

    Raises:
        KeyError: If ``TOGETHER_API_KEY`` is not set in the environment.
        aiohttp.ClientResponseError: If the API answers with an HTTP error
            status (instead of failing later on a missing ``choices`` key).
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "deepseek-ai/DeepSeek-V3",
        "messages": messages,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            resp.raise_for_status()
            result = await resp.json()
            return result["choices"][0]["message"]["content"]


# -------------------
# Query function
# -------------------
def retrieve_answer(query, k=3):
    """Answer ``query`` from the cache, or via FAISS retrieval plus the LLM.

    Args:
        query: The user's question text.
        k: Number of passages to retrieve from the FAISS index.

    Returns:
        Tuple ``(answer, sources)``; ``sources`` is empty on a cache hit.
    """
    # 1. Try the semantic cache first.
    cached_answer = search_cache(query, embed_model)
    if cached_answer:
        st.sidebar.success("⚡ Retrieved from cache")
        return cached_answer, []  # no FAISS sources

    # 2. Cache miss: retrieve context from FAISS and ask the model.
    query_emb = embed_model.encode([query], convert_to_numpy=True)
    _, neighbor_ids = index.search(query_emb, k)
    # FAISS pads with -1 when fewer than k neighbours exist; a negative index
    # would silently select the *last* metadata entry.
    hits = [int(i) for i in neighbor_ids[0] if i >= 0]
    retrieved = [metadata["texts"][i] for i in hits]
    sources = [metadata["sources"][i] for i in hits]
    context = "\n".join(retrieved)

    history = st.session_state.chats[st.session_state.current_chat]
    history.append({
        "role": "user",
        "content": f"Answer based on the context below:\n\n{context}\n\nQuestion: {query}",
    })
    answer = asyncio.run(async_together_chat(history))

    # 3. Save the new query + embedding + answer into the cache.
    store_in_cache(query, answer, query_emb[0])
    history.append({"role": "assistant", "content": answer})
    return answer, sources


# -------------------
# Background news task
# -------------------
async def background_news_updater():
    """Refresh the news articles in session state once per hour, forever."""
    while True:
        # Await the fetch directly: fetch_news() uses asyncio.run(), which
        # raises RuntimeError when called from a running event loop.
        raw_xml = await fetch_rss_url(RSS_URL)
        st.session_state.news_articles = _parse_news(raw_xml)
        await asyncio.sleep(3600)  # refresh every hour


# NOTE(review): this loop is created but never run, so the task below never
# executes; the hourly refresh actually comes from update_news_hourly().
# Consider removing this block (it keeps an unclosed event loop per session).
if "news_task" not in st.session_state:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    st.session_state.news_task = loop.create_task(background_news_updater())

# -------------------
# Streamlit UI
# -------------------
st.title(st.session_state.current_chat)

update_news_hourly()
st.subheader("📰 Latest Health Updates")
if "news_articles" in st.session_state:
    for art in st.session_state.news_articles:
        # Two trailing spaces before "\n" produce a Markdown hard line break.
        st.markdown(
            f"**{art['title']}**  \n[Read more]({art['link']})  \n*Published: {art['published']}*"
        )
        st.write("---")

user_query = st.text_input("Ask me about health, prevention, or awareness:")
if user_query:
    with st.spinner("Searching knowledge base..."):
        answer, sources = retrieve_answer(user_query)
    st.write("### 💡 Answer")
    st.write(answer)
    st.write("### 📖 Sources")
    for src in sources:
        st.write(f"- {src}")

# NOTE(review): history rendering is placed at top level (shown on every
# rerun); the collapsed source does not show its original nesting.
for msg in st.session_state.chats[st.session_state.current_chat]:
    if msg["role"] == "user":
        st.write(f"🧑 **You:** {msg['content']}")
    elif msg["role"] == "assistant":
        st.write(f"🤖 **Bot:** {msg['content']}")