Spaces:
Running
Running
File size: 19,449 Bytes
34f70c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 |
import os
import gradio as gr
import requests
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import numpy as np
import faiss
from collections import deque
from langchain_core.embeddings import Embeddings
import threading
import queue
from langchain_core.messages import HumanMessage, AIMessage
from sentence_transformers import SentenceTransformer
import pickle
import torch
import time
from tqdm import tqdm
import logging
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 获取环境变量
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
raise ValueError("OPENROUTER_API_KEY 未设置")
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
raise ValueError("SILICONFLOW_API_KEY 未设置")
BYTEZ_API_KEY = os.getenv("BYTEZ_API_KEY")
if not BYTEZ_API_KEY:
raise ValueError("BYTEZ_API_KEY no set")
# SiliconFlow API 配置
BYTEZ_API_URL = "https://api.bytez.com/models/v2/BAAI/bge-reranker-v2-m3"
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"
# 自定义嵌入类,优化查询缓存
class SentenceTransformerEmbeddings(Embeddings):
def __init__(self, model_name="BAAI/bge-m3"):
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = SentenceTransformer(model_name, device=device)
self.batch_size = 32 # 减小批次大小以适应低内存
self.query_cache = {}
self.cache_lock = threading.Lock()
def embed_documents(self, texts):
embeddings_list = []
batch_size = 1000 # 减小批次以降低内存压力
total_chunks = len(texts)
logger.info(f"生成嵌入,文档数: {total_chunks}")
with torch.no_grad():
for i in tqdm(range(0, total_chunks, batch_size), desc="生成嵌入"):
batch_texts = [text.page_content for text in texts[i:i + batch_size]]
batch_emb = self.model.encode(
batch_texts,
normalize_embeddings=True,
batch_size=self.batch_size
)
embeddings_list.append(batch_emb)
embeddings_array = np.vstack(embeddings_list)
np.save("embeddings.npy", embeddings_array)
return embeddings_array
def embed_query(self, text):
with self.cache_lock:
if text in self.query_cache:
return self.query_cache[text]
with torch.no_grad():
emb = self.model.encode([text], normalize_embeddings=True, batch_size=1)[0]
with self.cache_lock:
self.query_cache[text] = emb
if len(self.query_cache) > 1000: # 限制缓存大小
self.query_cache.pop(next(iter(self.query_cache)))
return emb
# 重排序函数
def rerank_documents(query, documents, top_n=15):
try:
doc_texts_with_meta = []
for doc in documents[:50]:
if isinstance(doc, tuple):
actual_doc = doc[0]
else:
actual_doc = doc
if hasattr(actual_doc, 'page_content'):
text = actual_doc.page_content[:2048]
book_meta = actual_doc.metadata.get("book", "unknow source")
doc_texts_with_meta.append((text, book_meta))
else:
logger.warning(f"skip the invalid texts: {type(doc)}")
headers = {"Authorization": f"Bearer {BYTEZ_API_KEY}", "Content-Type": "application/json"}
data = {"query": query, "text": [text for text, _ in doc_texts_with_meta], "top_n": top_n}
response = requests.post(BYTEZ_API_URL, headers=headers, json=data)
response.raise_for_status()
result = response.json()
'''
import json
print("---api result---")
print(json.dumps(result, indent=2, ensure_ascii=False))
print("-----------------------")
'''
reranked_docs = []
for res in result["output"]:
score = res["score"]
pass
reranked_results = []
for i, res in enumerate(result["output"]):
score = res["score"]
if i < len(documents):
if isinstance(documents[i], tuple):
original_doc = documents[i][0]
else:
original_doc = documents[i]
reranked_results.append((original_doc, score))
return sorted(reranked_results, key=lambda x: x[1], reverse=True)[:top_n]
except Exception as e:
logger.error(f"重排序失败: {str(e)}")
raise
# 构建 HNSW 索引
def build_hnsw_index(knowledge_base_path, index_path):
loader = DirectoryLoader(knowledge_base_path, glob="*.txt", loader_cls=lambda path: TextLoader(path, encoding="utf-8"))
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(documents)
for i, doc in enumerate(texts):
doc.metadata["book"] = os.path.basename(doc.metadata.get("source", "未知来源")).replace(".txt", "")
embeddings_array = embeddings.embed_documents(texts)
dimension = embeddings_array.shape[1]
index = faiss.IndexHNSWFlat(dimension, 16)
index.hnsw.efConstruction = 100
index.add(embeddings_array)
vector_store = FAISS.from_embeddings([(doc.page_content, embeddings_array[i]) for i, doc in enumerate(texts)], embeddings)
vector_store.index = index
vector_store.save_local(index_path)
with open("chunks.pkl", "wb") as f:
pickle.dump(texts, f)
return vector_store, texts
# 初始化嵌入模型和索引
embeddings = SentenceTransformerEmbeddings()
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"
if not os.path.exists(index_path):
vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
else:
vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
vector_store.index.hnsw.efSearch = 200 # 降低 efSearch 以提升速度
with open("chunks.pkl", "rb") as f:
all_documents = pickle.load(f)
# 初始化 LLM
llm = ChatOpenAI(
model="tngtech/deepseek-r1t2-chimera:free",
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1",
timeout=100,
temperature=0.3,
max_tokens=130000,
streaming=True
)
# 提示词模板
'''
prompt_template = PromptTemplate(
input_variables=["context", "question", "chat_history"],
template="""
你是一个研究李敖的专家,根据用户提出的问题{question}、最近7轮对话历史{chat_history}以及从李敖相关书籍和评论中检索的至少10篇文本内容{context}回答问题。
在回答时,请注意以下几点:
- 结合李敖的写作风格和思想,筛选出与问题和对话历史最相关的检索内容,避免无关信息。
- 必须在回答中引用至少10篇不同的文本内容,引用格式为[引用: 文本序号],例如[引用: 1][引用: 2],并确保每篇文本在回答中都有明确使用。
- 在回答的末尾,必须以“引用文献”标题列出所有引用的文本序号及其内容摘要(每篇不超过50字)以及具体的书目信息(例如书名和章节),格式为:
- 引用文献:
1. [文本 1] 摘要... 出自:书名,第X页/章节。
2. [文本 2] 摘要... 出自:书名,第X页/章节。
(依此类推,至少10篇)
- 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。
- 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。
- 如果检索内容和历史不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。
- 只能基于提供的知识库内容{context}和对话历史{chat_history}回答,不得引入外部信息。
- 对于列举类问题,控制在10个要点以内,并优先提供最相关项。
- 如果回答较长,结构化分段总结,分点作答控制在8个点以内。
- 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。
- 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。
- 你的回答应该综合多个相关知识库内容来回答,不能重复引用一个知识库内容。
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
"""
)
'''
prompt_template = PromptTemplate(
input_variables=["context", "question", "chat_history"],
template="""
# 身份与风格设定:文坛巨匠李敖
你是一位特立独行、才华横溢的文坛巨匠——李敖。
你的职责是以李敖的风格,对用户的问题进行犀利、批判、且充满文采的回答。
## 风格要求(请严格遵守):
1. 语言风格:桀骜不驯、自信、尖锐,善用反讽和黑色幽默。
2. 修辞手法:必须灵活运用排比句、引经据典,以及李敖的经典口头禅。
3. 语气倾向:决不能使用温和、中立、模糊或讨好的词汇。
---
## 强制性思考过程(Chain-of-Thought)
在生成最终的回答之前,你**必须**在 `<思考过程>` 标签内,按照以下四个强制性步骤完成你的推理和规划。
### 步骤 A:核心观点解析与策略制定
分析用户问题【{question}】和对话历史【{chat_history}】,将其转化为一个李敖会批判的社会或历史议题。从知识库【{context}】中,锁定最能支撑李敖批判立场的 1-2 个核心观点。
### 步骤 B:RAG 内容的整合与提炼
将知识库中**至少10篇**文本内容转化为可用的、具有李敖色彩的素材。确保选取的10篇文本都与问题和对话历史最相关,并能提供证据链。
### 步骤 C:初稿生成与风格检验
根据步骤 A 的规划和步骤 B 的提炼素材,写出回答的**初稿全文**。检查初稿是否符合李敖的风格、批判性和排比要求。
### 步骤 D:最终润色与定稿
根据风格检验结果,对初稿中不符合李敖风格的部分进行**具体修改**,生成最终回答。
---
## 通用指令与格式要求
请**只能基于**提供的知识库内容【{context}】和对话历史【{chat_history}】回答,不得引入外部信息。
### 引用规范(绝对强制):
1. **引用数量:** 必须在回答中引用**至少10篇不同的文本内容**,并确保每篇文本在回答中都有明确使用。
2. **格式要求:** 引用格式为**[引用: 文本序号]**,例如[引用: 1][引用: 2]。
3. **参考文献:** 在回答的末尾,**必须**以“引用文献”标题列出所有引用的文本序号及其内容摘要(每篇不超过50字)以及具体的书目信息(例如书名和章节)。
- 引用文献:
1. [文本 1] 摘要... 出自:书名,第X页/章节。
2. [文本 2] 摘要... 出自:书名,第X页/章节。
(依此类推,至少10篇)
### 回答规范:
1. **优先引用:** 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。
2. **结构化:** 回答应结构化、分段落,确保逻辑清晰,语言生动。
3. **篇幅控制:** 对于列举类问题,控制在10个要点以内,并优先提供最相关项;如果回答较长,分点作答控制在8个点以内。
4. **推测说明:** 如果检索内容和历史不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但**需说明这是推测**。
---
## 最终输出格式
你的输出**必须**包含两个部分,并严格使用以下 Markdown 标签:
1. **<思考过程>**:包含你执行上面“强制性思考过程”的全部内容。
2. **<最终回答>**:只包含最终润色后的回答文本和末尾的“引用文献”列表。
"""
)
# 对话历史管理
class ConversationHistory:
def __init__(self, max_length=5): # 减少历史轮数
self.history = deque(maxlen=max_length)
def add_turn(self, question, answer):
self.history.append((question, answer))
def get_history(self):
return [(q, a) for q, a in self.history]
# 用户会话状态
class UserSession:
def __init__(self):
self.conversation = ConversationHistory()
self.output_queue = queue.Queue()
self.stop_flag = threading.Event()
# 生成回答
def generate_answer_thread(question, session):
stop_flag = session.stop_flag
output_queue = session.output_queue
conversation = session.conversation
stop_flag.clear()
try:
# 打印用户问题到控制台
logger.info(f"用户问题: {question}")
history_list = conversation.get_history()
history_text = "\n".join([f"问: {q}\n答: {a}" for q, a in history_list[-3:]]) # 只用最后3轮
query_with_context = f"{history_text}\n问题: {question}" if history_text else question
# 异步生成查询嵌入
embed_queue = queue.Queue()
def embed_task():
start = time.time()
emb = embeddings.embed_query(query_with_context)
embed_queue.put((emb, time.time() - start))
embed_thread = threading.Thread(target=embed_task)
embed_thread.start()
embed_thread.join()
query_embedding, embed_time = embed_queue.get()
if stop_flag.is_set():
output_queue.put("生成已停止")
return
# 初始检索
start = time.time()
docs_with_scores = vector_store.similarity_search_with_score_by_vector(query_embedding, k=50)
search_time = time.time() - start
if stop_flag.is_set():
output_queue.put("生成已停止")
return
# 重排序
initial_docs = [doc for doc, _ in docs_with_scores]
start = time.time()
reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs)
rerank_time = time.time() - start
final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]
# 打印重排序结果到控制台
logger.info("重排序结果(最终保留的片段及其得分):")
for i, (doc, score) in enumerate(reranked_docs_with_scores[:10], 1):
logger.info(f"片段 {i}:")
logger.info(f" 内容: {doc.page_content[:100]}...")
logger.info(f" 来源: {doc.metadata.get('book', '未知来源')}")
logger.info(f" 得分: {score:.4f}")
context = "\n".join([f"[文本 {i+1}] {doc.page_content} (出处: {doc.metadata.get('book')})" for i, doc in enumerate(final_docs)])
prompt = prompt_template.format(context=context, question=question, chat_history=history_text)
# 将时间信息加入回答开头
timing_info = (
f"处理时间统计:\n"
f"- 嵌入时间: {embed_time:.2f} 秒\n"
f"- 检索时间: {search_time:.2f} 秒\n"
f"- 重排序时间: {rerank_time:.2f} 秒\n\n"
)
answer = timing_info
output_queue.put(answer) # 先显示时间信息
# LLM 生成回答
start = time.time()
for chunk in llm.stream([HumanMessage(content=prompt)]):
if stop_flag.is_set():
output_queue.put(answer + "\n(生成已停止)")
return
answer += chunk.content
output_queue.put(answer)
llm_time = time.time() - start
answer += f"\n\n生成耗时: {llm_time:.2f} 秒"
output_queue.put(answer)
conversation.add_turn(question, answer)
output_queue.put(answer)
except Exception as e:
output_queue.put(f"Error: {str(e)}")
# Gradio 接口
def answer_question(question, session_state):
if session_state is None:
session_state = UserSession()
thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
thread.start()
while thread.is_alive() or not session_state.output_queue.empty():
try:
output = session_state.output_queue.get(timeout=0.1)
yield output, session_state
except queue.Empty:
continue
def stop_generation(session_state):
if session_state:
session_state.stop_flag.set()
return "生成已停止"
def clear_conversation():
return "对话已清空", UserSession()
# 自动提问功能:每天触发一次“介绍一下李敖”
def auto_ask_question():
auto_session = UserSession()
last_run_time = 0
interval = 24 * 60 * 60 # 24小时(单位:秒)
while True:
current_time = time.time()
if current_time - last_run_time >= interval:
logger.info("自动触发问题:介绍一下李敖")
thread = threading.Thread(target=generate_answer_thread, args=("介绍一下李敖", auto_session))
thread.start()
thread.join() # 等待回答生成完成
last_run_time = current_time
time.sleep(60) # 每分钟检查一次,避免占用过多资源
# Gradio 界面
with gr.Blocks(title="AI李敖助手") as interface:
gr.Markdown("## AI李敖助手")
gr.Markdown("### 作者:爱华山樱")
gr.Markdown("基于李敖163本相关书籍构建的知识库,支持上下文关联,记住最近5轮对话,输入问题以获取李敖风格的回答。")
gr.Markdown("提问之后红框存在期间表示正在生成回答,如果红框消失之后答案没出来,说明生成有问题(偶尔会这样),重来一次即可。")
session_state = gr.State(value=None)
question_input = gr.Textbox(label="问题")
submit_button = gr.Button("提交")
clear_button = gr.Button("新建对话")
stop_button = gr.Button("停止生成")
output_text = gr.Textbox(label="回答", interactive=False)
submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)
if __name__ == "__main__":
# 启动自动提问线程
auto_thread = threading.Thread(target=auto_ask_question, daemon=True)
auto_thread.start()
# 启动 Gradio 界面
interface.launch(share=True)
|