import streamlit as st import pandas as pd import numpy as np import jieba import requests import os import sys import subprocess from openai import OpenAI from rank_bm25 import BM25Okapi from sklearn.metrics.pairwise import cosine_similarity # ================= 1. 全局配置与 CSS注入 ================= API_KEY = os.getenv("SILICONFLOW_API_KEY") API_BASE = "https://api.siliconflow.cn/v1" EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" DATA_FILENAME = "comsol_embedded.parquet" DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" st.set_page_config( page_title="COMSOL Dark Expert", page_icon="🌌", layout="wide", initial_sidebar_state="expanded" ) # --- 注入自定义 CSS (保持之前的审美) --- st.markdown(""" """, unsafe_allow_html=True) # ================= 2. 核心逻辑(数据与RAG) ================= if not API_KEY: st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") st.stop() def download_with_curl(url, output_path): try: cmd = [ "curl", "-L", "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", "-o", output_path, "--fail", url ] result = subprocess.run(cmd, capture_output=True, text=True) if result.returncode != 0: raise Exception(f"Curl failed: {result.stderr}") return True except Exception as e: print(f"Curl download error: {e}") return False def get_data_file_path(): possible_paths = [ DATA_FILENAME, os.path.join("/app", DATA_FILENAME), os.path.join("processed_data", DATA_FILENAME), os.path.join("src", DATA_FILENAME), os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME ] for path in possible_paths: if os.path.exists(path): return path download_target = "/app/" + DATA_FILENAME try: os.makedirs(os.path.dirname(download_target), exist_ok=True) except: download_target = "/tmp/" + DATA_FILENAME status_container = st.empty() status_container.info("📡 正在接入神经元网络... (下载核心数据中)") if download_with_curl(DATA_URL, download_target): status_container.empty() return download_target try: headers = {'User-Agent': 'Mozilla/5.0'} r = requests.get(DATA_URL, headers=headers, stream=True) r.raise_for_status() with open(download_target, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) status_container.empty() return download_target except Exception as e: st.error(f"❌ 数据链路中断。Error: {e}") st.stop() class FullRetriever: def __init__(self, parquet_path): try: self.df = pd.read_parquet(parquet_path) except Exception as e: st.error(f"Memory Matrix Load Failed: {e}"); st.stop() self.documents = self.df['content'].tolist() self.embeddings = np.stack(self.df['embedding'].values) self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents]) self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) # Reranker 初始化移到这里,减少重复调用 self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"} self.rerank_url = f"{API_BASE}/rerank" def _get_emb(self, q): try: return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding except: return [0.0] * 1024 def hybrid_search(self, query: str, top_k=5): # 1. Vector q_emb = self._get_emb(query) vec_scores = cosine_similarity([q_emb], self.embeddings)[0] vec_idx = np.argsort(vec_scores)[-100:][::-1] # 2. Keyword kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1] # 3. RRF Fusion fused = {} for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x:x[1], reverse=True)[:50]] c_docs = [self.documents[i] for i in c_idxs] # 4. Rerank try: payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k} resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10) results = resp.json().get('results', []) except: results = [{"index": i, "relevance_score": 0.0} for i in range(len(c_docs))][:top_k] final_res = [] context = "" for i, item in enumerate(results): orig_idx = c_idxs[item['index']] row = self.df.iloc[orig_idx] final_res.append({ "score": item['relevance_score'], "filename": row['filename'], "content": row['content'] }) context += f"[文档{i+1}]: {row['content']}\n\n" return final_res, context @st.cache_resource def load_engine(): real_path = get_data_file_path() return FullRetriever(real_path) # ================= 3. UI 主程序 ================= def main(): st.markdown("""
🌌
COMSOL DARK EXPERT
NEURAL SIMULATION ASSISTANT V4.1 Fixed
""", unsafe_allow_html=True) retriever = load_engine() with st.sidebar: st.markdown("### ⚙️ 控制台") top_k = st.slider("检索深度", 1, 10, 4) temp = st.slider("发散度", 0.0, 1.0, 0.3) st.markdown("---") if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True): st.session_state.messages = [] st.session_state.current_refs = [] st.rerun() if "messages" not in st.session_state: st.session_state.messages = [] if "current_refs" not in st.session_state: st.session_state.current_refs = [] col_chat, col_evidence = st.columns([0.65, 0.35], gap="large") # ------------------ 处理输入源 ------------------ # 我们定义一个变量 user_input,不管它来自按钮还是输入框 user_input = None with col_chat: # 1. 如果历史为空,显示快捷按钮 if not st.session_state.messages: st.markdown("##### 💡 初始化提问序列 (Starter Sequence)") c1, c2, c3 = st.columns(3) # 点击按钮直接赋值给 user_input if c1.button("🌊 流固耦合接口设置"): user_input = "怎么设置流固耦合接口?" elif c2.button("⚡ 低频电磁场网格"): user_input = "低频电磁场网格划分有哪些技巧?" elif c3.button("📉 求解器不收敛"): user_input = "求解器不收敛通常怎么解决?" # 2. 渲染历史消息 for msg in st.session_state.messages: with st.chat_message(msg["role"]): st.markdown(msg["content"]) # 3. 处理底部输入框 (如果有按钮输入,这里会被跳过,因为 user_input 已经有值了) if not user_input: user_input = st.chat_input("输入指令或物理参数问题...") # ------------------ 统一处理消息追加 ------------------ if user_input: st.session_state.messages.append({"role": "user", "content": user_input}) # 强制刷新以立即在 UI 上显示用户的提问(对于按钮点击尤为重要) st.rerun() # ------------------ 统一触发生成 (修复的核心) ------------------ # 检查:如果有消息,且最后一条是 User 发的,说明需要 Assistant 回答 if st.session_state.messages and st.session_state.messages[-1]["role"] == "user": # 获取最后一条用户消息 last_query = st.session_state.messages[-1]["content"] with col_chat: # 确保在聊天栏显示 with st.spinner("🔍 正在扫描向量空间..."): refs, context = retriever.hybrid_search(last_query, top_k=top_k) st.session_state.current_refs = refs system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。 要求: 1. 语气专业、客观,逻辑严密。 2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。 3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。 参考文档: {context} """ with st.chat_message("assistant"): resp_cont = st.empty() full_resp = "" client = OpenAI(base_url=API_BASE, api_key=API_KEY) try: stream = client.chat.completions.create( model=GEN_MODEL_NAME, messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:], # 除去当前的System temperature=temp, stream=True ) for chunk in stream: txt = chunk.choices[0].delta.content if txt: full_resp += txt resp_cont.markdown(full_resp + " ▌") resp_cont.markdown(full_resp) st.session_state.messages.append({"role": "assistant", "content": full_resp}) except Exception as e: st.error(f"Neural Generation Failed: {e}") # ------------------ 渲染右侧证据栏 ------------------ with col_evidence: st.markdown("### 📚 神经记忆 (Evidence)") if st.session_state.current_refs: for i, ref in enumerate(st.session_state.current_refs): score = ref['score'] score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c" with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)): st.markdown(f"""
Relevance: {score:.4f}
""", unsafe_allow_html=True) st.code(ref['content'], language="text") else: st.info("等待输入指令以检索知识库...") st.markdown("""
Waiting for query signal...
Index Status: Ready
Awaiting Input
""", unsafe_allow_html=True) if __name__ == "__main__": main()