samlax12 committed
Commit d20b012 · verified · 1 Parent(s): 7a1e626

Update modules/knowledge_base/vector_store.py

Files changed (1)
  1. modules/knowledge_base/vector_store.py +186 -186
modules/knowledge_base/vector_store.py CHANGED
@@ -1,187 +1,187 @@
-from typing import List, Dict
-import requests
-import numpy as np
-from elasticsearch import Elasticsearch
-import urllib3
-from dotenv import load_dotenv
-import os
-
-load_dotenv()
-
-urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
-class VectorStore:
-    def __init__(self):
-        # Connection settings for ES 8.x
-        self.es = Elasticsearch(
-            "https://samlax12-elastic.hf.space",
-            basic_auth=("elastic", os.getenv("PASSWORD")),
-            verify_certs=False,
-            request_timeout=30,
-            # Ignore system-index warnings
-            headers={"accept": "application/vnd.elasticsearch+json; compatible-with=8"},
-        )
-        self.api_key = os.getenv("API_KEY")
-        self.api_base = os.getenv("BASE_URL")
-
-    def get_embedding(self, text: str) -> List[float]:
-        """Call the SiliconFlow embedding API to get a vector."""
-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
-
-        response = requests.post(
-            f"{self.api_base}/embeddings",
-            headers=headers,
-            json={
-                "model": "BAAI/bge-m3",
-                "input": text
-            }
-        )
-
-        if response.status_code == 200:
-            return response.json()["data"][0]["embedding"]
-        else:
-            raise Exception(f"Error getting embedding: {response.text}")
-
-    def store(self, documents: List[Dict], index_name: str) -> None:
-        """Store documents in Elasticsearch."""
-        # Create the index if it does not exist
-        if not self.es.indices.exists(index=index_name):
-            self.create_index(index_name)
-
-        # Get the current number of documents in the index
-        try:
-            response = self.es.count(index=index_name)
-            last_id = response['count'] - 1  # document count minus 1 is used as the last ID
-            if last_id < 0:
-                last_id = -1
-        except Exception as e:
-            print(f"Error getting document count, assuming -1: {str(e)}")
-            last_id = -1
-
-        # Bulk-index the documents
-        bulk_data = []
-        for i, doc in enumerate(documents, start=last_id + 1):
-            # Get the document's embedding vector
-            vector = self.get_embedding(doc['content'])
-
-            # Prepare the index action
-            bulk_data.append({
-                "index": {
-                    "_index": index_name,
-                    "_id": f"doc_{i}"
-                }
-            })
-
-            # Build the document body, including the new img_url field
-            doc_data = {
-                "content": doc['content'],
-                "vector": vector,
-                "metadata": {
-                    "file_name": doc['metadata'].get('file_name', 'unknown file'),
-                    "source": doc['metadata'].get('source', ''),
-                    "page": doc['metadata'].get('page', ''),
-                    "img_url": doc['metadata'].get('img_url', '')  # add the img_url field
-                }
-            }
-            bulk_data.append(doc_data)
-
-        # Bulk write
-        if bulk_data:
-            response = self.es.bulk(operations=bulk_data, refresh=True)
-            if response.get('errors'):
-                print("Errors occurred during bulk write:", response)
-
-    def get_files_in_index(self, index_name: str) -> List[str]:
-        """Get all file names in the index."""
-        try:
-            response = self.es.search(
-                index=index_name,
-                body={
-                    "size": 0,
-                    "aggs": {
-                        "unique_files": {
-                            "terms": {
-                                "field": "metadata.file_name",
-                                "size": 1000
-                            }
-                        }
-                    }
-                }
-            )
-
-            files = [bucket['key'] for bucket in response['aggregations']['unique_files']['buckets']]
-            return sorted(files)
-        except Exception as e:
-            print(f"Error getting file list: {str(e)}")
-            return []
-
-    def create_index(self, index_name: str):
-        """Create the Elasticsearch index."""
-        settings = {
-            "mappings": {
-                "properties": {
-                    "content": {"type": "text"},
-                    "vector": {
-                        "type": "dense_vector",
-                        "dims": 1024
-                    },
-                    "metadata": {
-                        "properties": {
-                            "file_name": {
-                                "type": "keyword",
-                                "ignore_above": 256
-                            },
-                            "source": {
-                                "type": "keyword"
-                            },
-                            "page": {
-                                "type": "keyword"
-                            },
-                            "img_url": {  # new image URL field
-                                "type": "keyword",
-                                "ignore_above": 2048
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        # If the index already exists, delete it first
-        if self.es.indices.exists(index=index_name):
-            self.es.indices.delete(index=index_name)
-
-        self.es.indices.create(index=index_name, body=settings)
-
-    def delete_index(self, index_id: str) -> bool:
-        """Delete an index."""
-        try:
-            if self.es.indices.exists(index=index_id):
-                self.es.indices.delete(index=index_id)
-                return True
-            return False
-        except Exception as e:
-            print(f"Error deleting index: {str(e)}")
-            return False
-
-    def delete_document(self, index_id: str, file_name: str) -> bool:
-        """Delete documents by file name."""
-        try:
-            response = self.es.delete_by_query(
-                index=index_id,
-                body={
-                    "query": {
-                        "term": {
-                            "metadata.file_name": file_name
-                        }
-                    }
-                },
-                refresh=True
-            )
-            return True
-        except Exception as e:
-            print(f"Error deleting documents: {str(e)}")
+from typing import List, Dict
+import requests
+import numpy as np
+from elasticsearch import Elasticsearch
+import urllib3
+from dotenv import load_dotenv
+import os
+
+load_dotenv()
+
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+class VectorStore:
+    def __init__(self):
+        # Connection settings for ES 8.x
+        self.es = Elasticsearch(
+            "https://elastic.aixiao.xyz",
+            basic_auth=("elastic", os.getenv("PASSWORD")),
+            verify_certs=False,
+            request_timeout=30,
+            # Ignore system-index warnings
+            headers={"accept": "application/vnd.elasticsearch+json; compatible-with=8"},
+        )
+        self.api_key = os.getenv("API_KEY")
+        self.api_base = os.getenv("BASE_URL")
+
+    def get_embedding(self, text: str) -> List[float]:
+        """Call the SiliconFlow embedding API to get a vector."""
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+
+        response = requests.post(
+            f"{self.api_base}/embeddings",
+            headers=headers,
+            json={
+                "model": "BAAI/bge-m3",
+                "input": text
+            }
+        )
+
+        if response.status_code == 200:
+            return response.json()["data"][0]["embedding"]
+        else:
+            raise Exception(f"Error getting embedding: {response.text}")
+
+    def store(self, documents: List[Dict], index_name: str) -> None:
+        """Store documents in Elasticsearch."""
+        # Create the index if it does not exist
+        if not self.es.indices.exists(index=index_name):
+            self.create_index(index_name)
+
+        # Get the current number of documents in the index
+        try:
+            response = self.es.count(index=index_name)
+            last_id = response['count'] - 1  # document count minus 1 is used as the last ID
+            if last_id < 0:
+                last_id = -1
+        except Exception as e:
+            print(f"Error getting document count, assuming -1: {str(e)}")
+            last_id = -1
+
+        # Bulk-index the documents
+        bulk_data = []
+        for i, doc in enumerate(documents, start=last_id + 1):
+            # Get the document's embedding vector
+            vector = self.get_embedding(doc['content'])
+
+            # Prepare the index action
+            bulk_data.append({
+                "index": {
+                    "_index": index_name,
+                    "_id": f"doc_{i}"
+                }
+            })
+
+            # Build the document body, including the new img_url field
+            doc_data = {
+                "content": doc['content'],
+                "vector": vector,
+                "metadata": {
+                    "file_name": doc['metadata'].get('file_name', 'unknown file'),
+                    "source": doc['metadata'].get('source', ''),
+                    "page": doc['metadata'].get('page', ''),
+                    "img_url": doc['metadata'].get('img_url', '')  # add the img_url field
+                }
+            }
+            bulk_data.append(doc_data)
+
+        # Bulk write
+        if bulk_data:
+            response = self.es.bulk(operations=bulk_data, refresh=True)
+            if response.get('errors'):
+                print("Errors occurred during bulk write:", response)
+
+    def get_files_in_index(self, index_name: str) -> List[str]:
+        """Get all file names in the index."""
+        try:
+            response = self.es.search(
+                index=index_name,
+                body={
+                    "size": 0,
+                    "aggs": {
+                        "unique_files": {
+                            "terms": {
+                                "field": "metadata.file_name",
+                                "size": 1000
+                            }
+                        }
+                    }
+                }
+            )
+
+            files = [bucket['key'] for bucket in response['aggregations']['unique_files']['buckets']]
+            return sorted(files)
+        except Exception as e:
+            print(f"Error getting file list: {str(e)}")
+            return []
+
+    def create_index(self, index_name: str):
+        """Create the Elasticsearch index."""
+        settings = {
+            "mappings": {
+                "properties": {
+                    "content": {"type": "text"},
+                    "vector": {
+                        "type": "dense_vector",
+                        "dims": 1024
+                    },
+                    "metadata": {
+                        "properties": {
+                            "file_name": {
+                                "type": "keyword",
+                                "ignore_above": 256
+                            },
+                            "source": {
+                                "type": "keyword"
+                            },
+                            "page": {
+                                "type": "keyword"
+                            },
+                            "img_url": {  # new image URL field
+                                "type": "keyword",
+                                "ignore_above": 2048
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        # If the index already exists, delete it first
+        if self.es.indices.exists(index=index_name):
+            self.es.indices.delete(index=index_name)
+
+        self.es.indices.create(index=index_name, body=settings)
+
+    def delete_index(self, index_id: str) -> bool:
+        """Delete an index."""
+        try:
+            if self.es.indices.exists(index=index_id):
+                self.es.indices.delete(index=index_id)
+                return True
+            return False
+        except Exception as e:
+            print(f"Error deleting index: {str(e)}")
+            return False
+
+    def delete_document(self, index_id: str, file_name: str) -> bool:
+        """Delete documents by file name."""
+        try:
+            response = self.es.delete_by_query(
+                index=index_id,
+                body={
+                    "query": {
+                        "term": {
+                            "metadata.file_name": file_name
+                        }
+                    }
+                },
+                refresh=True
+            )
+            return True
+        except Exception as e:
+            print(f"Error deleting documents: {str(e)}")
             return False
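
This commit changes only the Elasticsearch endpoint, from the Hugging Face Space URL to https://elastic.aixiao.xyz. The sketch below is a minimal, hypothetical usage example of the class, not part of the commit: it assumes the PASSWORD, API_KEY, and BASE_URL variables read by VectorStore are set in .env and the new endpoint is reachable, and the index name kb_demo and sample document are illustrative only.

# Hypothetical usage sketch; kb_demo and handbook.pdf are made-up names,
# and the environment variables used by VectorStore must already be set.
from modules.knowledge_base.vector_store import VectorStore

store = VectorStore()

docs = [
    {
        "content": "Each chunk is stored together with its 1024-dim BGE-M3 embedding.",
        "metadata": {
            "file_name": "handbook.pdf",
            "source": "uploads/handbook.pdf",
            "page": "3",
            "img_url": "",
        },
    }
]

store.store(docs, index_name="kb_demo")           # embeds via the API and bulk-indexes
print(store.get_files_in_index("kb_demo"))        # -> ['handbook.pdf']
store.delete_document("kb_demo", "handbook.pdf")  # remove all chunks of that file
store.delete_index("kb_demo")                     # drop the whole index

Note that this file only writes to and manages the index; similarity search over the stored dense_vector field is not implemented here.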