leezhuuu committed on
Commit b38ace2 · verified · 1 Parent(s): 0ba9cfe

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +257 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,259 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+import pandas as pd
+import numpy as np
+import jieba
+import requests
+import time
+import os
+from openai import OpenAI
+from rank_bm25 import BM25Okapi
+from sklearn.metrics.pairwise import cosine_similarity
+
+# ================= 1. Security configuration and initialization =================
+
+# Try to read the key from an environment variable (standard practice for Docker / HF Spaces)
+API_KEY = os.getenv("SILICONFLOW_API_KEY")
+
+# Basic page settings
+st.set_page_config(
+    page_title="COMSOL Dark Expert",
+    page_icon="🌌",
+    layout="wide",
+    initial_sidebar_state="collapsed"
+)
+
+# Safety check: if no key is configured, stop execution so the public app does not leak error details
+if not API_KEY:
+    st.error("⚠️ No API key detected.")
+    st.info("In the Hugging Face Space, go to 'Settings' -> 'Variables and secrets' and add a Secret named `SILICONFLOW_API_KEY`.")
+    st.stop()
+
+# API configuration
+API_BASE = "https://api.siliconflow.cn/v1"
+EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
+RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
+GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
+
+# Data source configuration
+DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
+LOCAL_DATA_PATH = "/app/comsol_embedded.parquet"  # Path inside the Docker container; alternatively just "comsol_embedded.parquet"
+
+# ================= 2. Resource loading (cached) =================
+
+@st.cache_resource
+def load_data_and_engine():
+    """Download the data and initialize the retrieval engine; runs only once globally."""
+
+    # 1. Download the data automatically
+    if not os.path.exists(LOCAL_DATA_PATH):
+        try:
+            print(f"Downloading data from {DATA_URL} ...")
+            headers = {'User-Agent': 'Mozilla/5.0'}  # Avoid being blocked by naive anti-scraping checks
+            r = requests.get(DATA_URL, headers=headers, stream=True)
+            r.raise_for_status()
+            with open(LOCAL_DATA_PATH, 'wb') as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print("✅ Data download complete")
+        except Exception as e:
+            st.error(f"❌ Failed to download the data file: {str(e)}")
+            st.stop()
+
+    # 2. Initialize the engine
+    return FullRetriever(LOCAL_DATA_PATH)
+
+# ================= 3. Core backend classes =================
+
+class RerankClient:
+    def __init__(self, api_base, api_key, model):
+        self.api_url = f"{api_base}/rerank"
+        self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+        self.model = model
+
+    def rerank(self, query: str, documents: list, top_n: int = 5):
+        if not documents: return []
+        payload = {"model": self.model, "query": query, "documents": documents, "top_n": top_n}
+        try:
+            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
+            response.raise_for_status()
+            return response.json()['results']
+        except Exception as e:
+            print(f"Rerank Warning: {e}")
+            # Fallback: if the rerank API is unreachable, keep the original order
+            return [{"index": i, "relevance_score": 0.0} for i in range(len(documents))]
+
+class FullRetriever:
+    def __init__(self, parquet_path):
+        try:
+            self.df = pd.read_parquet(parquet_path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to read Parquet file: {e}")
+
+        self.documents = self.df['content'].tolist()
+        # Make sure the embedding column becomes a numpy matrix
+        self.embeddings = np.stack(self.df['embedding'].values)
+        self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
+        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
+        self.reranker = RerankClient(API_BASE, API_KEY, RERANK_MODEL)
+
+    def _get_emb(self, q):
+        try:
+            resp = self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q])
+            return resp.data[0].embedding
+        except Exception:
+            return [0.0] * 1024  # Keep the whole app from crashing if the API is down
+
+    def hybrid_search(self, query: str, top_k=5):
+        # 1. Vector retrieval
+        query_emb = self._get_emb(query)
+        vec_scores = cosine_similarity([query_emb], self.embeddings)[0]
+        vec_idx = np.argsort(vec_scores)[-100:][::-1]
+
+        # 2. Keyword retrieval
+        kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
+
+        # 3. RRF fusion
+        fused = {}
+        for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
+        for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1)
+
+        c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:50]]
+        c_docs = [self.documents[i] for i in c_idxs]
+
+        # 4. Reranking
+        results = self.reranker.rerank(query, c_docs, top_n=top_k)
+
+        final_res = []
+        context = ""
+        for i, item in enumerate(results):
+            orig_idx = c_idxs[item['index']]
+            row = self.df.iloc[orig_idx]
+            final_res.append({
+                "rank": i+1,
+                "score": item['relevance_score'],
+                "filename": row['filename'],
+                "content": row['content']
+            })
+            context += f"[Document {i+1}]: {row['content']}\n\n"
+        return final_res, context
+
+# ================= 4. UI rendering =================
+
+# Inject CSS styles
+st.markdown("""
+<style>
+    .stApp { background-color: #0E1117; color: #E0E0E0; }
+    .main-header {
+        background: linear-gradient(90deg, #0f2027 0%, #203a43 50%, #2c5364 100%);
+        padding: 1.5rem; border-radius: 0 0 15px 15px; color: #fff;
+        margin-bottom: 2rem; display: flex; align-items: center; justify-content: space-between;
+    }
+    .header-title { font-size: 1.8rem; font-weight: 700; color: white; margin:0;}
+    [data-testid="stChatMessage"] { background-color: #1E1E1E; border: 1px solid #333; }
+    .ref-card {
+        background-color: #161B22; border: 1px solid #30363D;
+        border-left: 4px solid #29B5E8; padding: 12px; margin-bottom: 12px;
+    }
+    .ref-title { font-weight: 600; color: #58A6FF; font-size: 0.95rem; }
+    .ref-snippet { font-size: 0.85rem; color: #8B949E; margin-top: 5px; font-family: monospace;}
+</style>
+""", unsafe_allow_html=True)
+
+def main():
+    # Top bar
+    st.markdown("""
+    <div class="main-header">
+        <div>
+            <div class="header-title">COMSOL Intelligent Simulation Expert</div>
+            <div style="color: #bbb; font-size: 0.8rem;">V3.0 Dark | Secured Docker Edition</div>
+        </div>
+    </div>
+    """, unsafe_allow_html=True)
+
+    # Load the engine (includes the download logic)
+    with st.spinner("🚀 Syncing data from the cloud and initializing the neural core..."):
+        retriever = load_data_and_engine()
+
+    # Sidebar
+    with st.sidebar:
+        st.header("🛠️ Parameter controls")
+        top_k = st.slider("Retrieval depth", 1, 10, 4)
+        temp = st.slider("Temperature", 0.0, 1.0, 0.3)
+        if st.button("🧹 Clear conversation"):
+            st.session_state.messages = []
+            st.session_state.current_refs = []
+            st.rerun()
+
+    # Session state initialization
+    if "messages" not in st.session_state: st.session_state.messages = []
+    if "current_refs" not in st.session_state: st.session_state.current_refs = []
+
+    # Layout
+    col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")
+
+    with col_chat:
+        for msg in st.session_state.messages:
+            with st.chat_message(msg["role"]):
+                st.markdown(msg["content"])
+
+        if prompt := st.chat_input("Ask a COMSOL question..."):
+            st.session_state.messages.append({"role": "user", "content": prompt})
+            with st.chat_message("user"): st.markdown(prompt)
+
+            # Retrieval stage
+            with st.status("📡 Searching the knowledge base...", expanded=False):
+                refs, context = retriever.hybrid_search(prompt, top_k=top_k)
+                st.session_state.current_refs = refs
+
+            # Generation stage
+            system_prompt = f"""You are a COMSOL expert. Answer the question based on the reference documents below. If the documents contain no relevant information, say so explicitly.
+
+Reference documents:
+{context}
+"""
+
+            with st.chat_message("assistant"):
+                resp_cont = st.empty()
+                full_resp = ""
+
+                # Create a new client instance (uses the global API_KEY)
+                client = OpenAI(base_url=API_BASE, api_key=API_KEY)
+
+                try:
+                    stream = client.chat.completions.create(
+                        model=GEN_MODEL_NAME,
+                        messages=[
+                            {"role": "system", "content": system_prompt},
+                            *st.session_state.messages[-6:]  # carry recent history
+                        ],
+                        temperature=temp,
+                        stream=True
+                    )
+                    for chunk in stream:
+                        txt = chunk.choices[0].delta.content
+                        if txt:
+                            full_resp += txt
+                            resp_cont.markdown(full_resp + "▌")
+                    resp_cont.markdown(full_resp)
+                    st.session_state.messages.append({"role": "assistant", "content": full_resp})
+                    st.rerun()  # Force a rerun so the evidence panel on the right updates
+                except Exception as e:
+                    st.error(f"Generation interrupted: {e}")
+
+    with col_evidence:
+        st.caption("📚 Retrieved evidence")
+        if st.session_state.current_refs:
+            for ref in st.session_state.current_refs:
+                st.markdown(f"""
+                <div class="ref-card">
+                    <div class="ref-title">📄 {ref['filename']} (Score: {ref['score']:.2f})</div>
+                    <div class="ref-snippet">{ref['content'][:120]}...</div>
+                </div>
+                """, unsafe_allow_html=True)
+                with st.expander("Show full text"):
+                    st.text(ref['content'])
+        else:
+            st.info("No retrieval results yet")
+
+if __name__ == "__main__":
+    main()
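
The `hybrid_search` method added above fuses the vector ranking and the BM25 ranking with Reciprocal Rank Fusion, adding 1/(60 + rank + 1) for each list a document appears in before handing the top candidates to the reranker. A minimal standalone sketch of that fusion step, using invented toy rankings rather than real retriever output:

```python
# Minimal sketch of the RRF fusion used in hybrid_search (k=60, as in the diff).
# The ranked index lists below are toy data, not output of the real retrievers.
def rrf_fuse(vec_idx, kw_idx, k=60, top_n=3):
    fused = {}
    for rank, doc_id in enumerate(vec_idx):
        fused[doc_id] = fused.get(doc_id, 0) + 1 / (k + rank + 1)
    for rank, doc_id in enumerate(kw_idx):
        fused[doc_id] = fused.get(doc_id, 0) + 1 / (k + rank + 1)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_n]

print(rrf_fuse(vec_idx=[7, 2, 5], kw_idx=[2, 9, 7]))
# Docs 2 and 7 appear in both lists, so they outrank docs seen in only one list.
```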