samlax12 committed on
Commit
63b10db
·
verified ·
1 Parent(s): b4b1e3a

Update modules/knowledge_base/processor.py

Browse files
Files changed (1) hide show
  1. modules/knowledge_base/processor.py +229 -229
modules/knowledge_base/processor.py CHANGED
@@ -1,230 +1,230 @@
1
- from typing import List, Dict, Callable, Optional
2
- from langchain.text_splitter import RecursiveCharacterTextSplitter
3
- from langchain_community.document_loaders import (
4
- DirectoryLoader,
5
- UnstructuredMarkdownLoader,
6
- PyPDFLoader,
7
- TextLoader
8
- )
9
- import os
10
- import requests
11
- import base64
12
- from PIL import Image
13
- import io
14
-
15
- class DocumentLoader:
16
- """通用文档加载器"""
17
- def __init__(self, file_path: str, original_filename: str = None):
18
- self.file_path = file_path
19
- # 使用传入的原始文件名或者从路径提取的文件名
20
- self.original_filename = original_filename or os.path.basename(file_path)
21
- # 从原始文件名中获取扩展名,确保中文文件名也能正确识别文件类型
22
- self.extension = os.path.splitext(self.original_filename)[1].lower()
23
- self.api_key = os.getenv("API_KEY")
24
- self.api_base = os.getenv("BASE_URL")
25
-
26
- def process_image(self, image_path: str) -> str:
27
- """使用 SiliconFlow VLM 模型处理图片"""
28
- try:
29
- # 读取图片并转换为base64
30
- with open(image_path, 'rb') as image_file:
31
- image_data = image_file.read()
32
- base64_image = base64.b64encode(image_data).decode('utf-8')
33
-
34
- # 调用 SiliconFlow API
35
- headers = {
36
- "Authorization": f"Bearer {self.api_key}",
37
- "Content-Type": "application/json"
38
- }
39
-
40
- response = requests.post(
41
- f"{self.api_base}/chat/completions",
42
- headers=headers,
43
- json={
44
- "model": "Qwen/Qwen2.5-VL-72B-Instruct",
45
- "messages": [
46
- {
47
- "role": "user",
48
- "content": [
49
- {
50
- "type": "image_url",
51
- "image_url": {
52
- "url": f"data:image/jpeg;base64,{base64_image}",
53
- "detail": "high"
54
- }
55
- },
56
- {
57
- "type": "text",
58
- "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
59
- }
60
- ]
61
- }
62
- ],
63
- "temperature": 0.7,
64
- "max_tokens": 500
65
- }
66
- )
67
-
68
- if response.status_code != 200:
69
- raise Exception(f"图片处理API调用失败: {response.text}")
70
-
71
- description = response.json()["choices"][0]["message"]["content"]
72
- return description
73
-
74
- except Exception as e:
75
- print(f"处理图片时出错: {str(e)}")
76
- return "图片处理失败"
77
-
78
- def load(self):
79
- try:
80
- print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")
81
-
82
- if self.extension == '.md':
83
- try:
84
- loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
85
- return loader.load()
86
- except UnicodeDecodeError:
87
- # 如果UTF-8失败,尝试GBK
88
- loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
89
- return loader.load()
90
- elif self.extension == '.pdf':
91
- loader = PyPDFLoader(self.file_path)
92
- return loader.load()
93
- elif self.extension == '.txt':
94
- try:
95
- loader = TextLoader(self.file_path, encoding='utf-8')
96
- return loader.load()
97
- except UnicodeDecodeError:
98
- # 如果UTF-8失败,尝试GBK
99
- loader = TextLoader(self.file_path, encoding='gbk')
100
- return loader.load()
101
- elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
102
- # 处理图片
103
- description = self.process_image(self.file_path)
104
- # 创建一个包含图片描述的文档
105
- from langchain.schema import Document
106
- doc = Document(
107
- page_content=description,
108
- metadata={
109
- 'source': self.file_path,
110
- 'file_name': self.original_filename, # 使用原始文件名
111
- 'img_url': os.path.abspath(self.file_path) # 存储图片的绝对路径
112
- }
113
- )
114
- return [doc]
115
- else:
116
- print(f"不支持的文件扩展名: {self.extension}")
117
- raise ValueError(f"不支持的文件格式: {self.extension}")
118
-
119
- except UnicodeDecodeError:
120
- # 如果默认编码处理失败,尝试其他编码
121
- print(f"文件编码错误,尝试其他编码: {self.file_path}")
122
- if self.extension in ['.md', '.txt']:
123
- try:
124
- loader = TextLoader(self.file_path, encoding='gbk')
125
- return loader.load()
126
- except Exception as e:
127
- print(f"尝试GBK编码也失败: {str(e)}")
128
- raise
129
- except Exception as e:
130
- print(f"加载文件 {self.file_path} 时出错: {str(e)}")
131
- import traceback
132
- traceback.print_exc()
133
- raise
134
-
135
- class DocumentProcessor:
136
- def __init__(self):
137
- self.text_splitter = RecursiveCharacterTextSplitter(
138
- chunk_size=1000,
139
- chunk_overlap=200,
140
- length_function=len,
141
- )
142
-
143
- def get_index_name(self, path: str) -> str:
144
- """根据文件路径生成索引名称"""
145
- if os.path.isdir(path):
146
- # 如果是目录,使用目录名
147
- return f"rag_{os.path.basename(path).lower()}"
148
- else:
149
- # 如果是文件,使用文件名(不含扩展名)
150
- return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"
151
-
152
- def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: str = None) -> List[Dict]:
153
- """
154
- 加载并处理文档,支持目录或单个文件
155
- 参数:
156
- path: 文档路径
157
- progress_callback: 进度回调函数,用于报告处理进度
158
- original_filename: 原始文件名(包括中文)
159
- 返回:处理后的文档列表
160
- """
161
- if os.path.isdir(path):
162
- documents = []
163
- total_files = sum([len(files) for _, _, files in os.walk(path)])
164
- processed_files = 0
165
- processed_size = 0
166
-
167
- for root, _, files in os.walk(path):
168
- for file in files:
169
- file_path = os.path.join(root, file)
170
- try:
171
- # 更新处理进度
172
- if progress_callback:
173
- file_size = os.path.getsize(file_path)
174
- processed_size += file_size
175
- processed_files += 1
176
- progress_callback(processed_size, f"处理文件 {processed_files}/{total_files}: {file}")
177
-
178
- # 为目录中的每个文件,传递原始文件名
179
- loader = DocumentLoader(file_path, original_filename=file)
180
- docs = loader.load()
181
- # 添加文件名到metadata
182
- for doc in docs:
183
- doc.metadata['file_name'] = file # 使用原始文件名
184
- documents.extend(docs)
185
- except Exception as e:
186
- print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
187
- continue
188
- else:
189
- try:
190
- if progress_callback:
191
- file_size = os.path.getsize(path)
192
- progress_callback(file_size * 0.3, f"加载文件: {original_filename or os.path.basename(path)}")
193
-
194
- # 为单个文件,传递原始文件名
195
- loader = DocumentLoader(path, original_filename=original_filename)
196
- documents = loader.load()
197
-
198
- # 更新进度
199
- if progress_callback:
200
- progress_callback(file_size * 0.6, f"处理文件内容...")
201
-
202
- # 使用原始文件名而不是存储的文件名
203
- file_name = original_filename or os.path.basename(path)
204
- for doc in documents:
205
- doc.metadata['file_name'] = file_name
206
- except Exception as e:
207
- print(f"加载文件时出错: {str(e)}")
208
- raise
209
-
210
- # 分块
211
- chunks = self.text_splitter.split_documents(documents)
212
-
213
- # 更新进度
214
- if progress_callback:
215
- if os.path.isdir(path):
216
- progress_callback(processed_size, f"文档分块完成,共{len(chunks)}个文档片段")
217
- else:
218
- file_size = os.path.getsize(path)
219
- progress_callback(file_size * 0.9, f"文档分块完成,共{len(chunks)}个文档片段")
220
-
221
- # 处理成统一格式
222
- processed_docs = []
223
- for i, chunk in enumerate(chunks):
224
- processed_docs.append({
225
- 'id': f'doc_{i}',
226
- 'content': chunk.page_content,
227
- 'metadata': chunk.metadata
228
- })
229
-
230
  return processed_docs
 
1
+ from typing import List, Dict, Callable, Optional
2
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
3
+ from langchain_community.document_loaders import (
4
+ DirectoryLoader,
5
+ UnstructuredMarkdownLoader,
6
+ PyPDFLoader,
7
+ TextLoader
8
+ )
9
+ import os
10
+ import requests
11
+ import base64
12
+ from PIL import Image
13
+ import io
14
+
15
class DocumentLoader:
    """Load one file from disk as a list of LangChain ``Document`` objects.

    The file type is decided from the extension of ``original_filename``
    (falling back to the basename of ``file_path``), so a file stored under
    a renamed or extension-less path is still dispatched correctly.
    Markdown, PDF and plain-text files go through the matching LangChain
    loader; images are captioned by a vision-language model reached at
    ``{BASE_URL}/chat/completions`` and authenticated with ``API_KEY``
    (both read from the environment).
    """

    # Extensions that are captioned by the vision model instead of parsed.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')

    # Upper bound (seconds) for the captioning request. Without a timeout,
    # ``requests`` waits forever on a stalled connection.
    REQUEST_TIMEOUT = 120

    def __init__(self, file_path: str, original_filename: str = None):
        """Remember where the file lives and which name decides its type.

        Args:
            file_path: Path of the file on disk.
            original_filename: Name the file was uploaded with (may contain
                non-ASCII characters). Defaults to the basename of
                ``file_path``.
        """
        self.file_path = file_path
        self.original_filename = original_filename or os.path.basename(file_path)
        # The extension comes from the original name, not the stored path.
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def _guess_image_mime(self, image_path: str) -> str:
        """Return the image MIME type for the data URL, defaulting to JPEG."""
        import mimetypes

        for candidate in (image_path, self.original_filename):
            mime_type = mimetypes.guess_type(candidate)[0]
            if mime_type and mime_type.startswith("image/"):
                return mime_type
        return "image/jpeg"

    def process_image(self, image_path: str) -> str:
        """Ask the vision-language model for a textual description of an image.

        This is deliberately best-effort: any failure (unreadable file,
        network error, timeout, non-200 response, unexpected payload) is
        printed and the fixed string ``"图片处理失败"`` is returned instead
        of raising.

        Args:
            image_path: Path of the image file to describe.

        Returns:
            The model's description, or ``"图片处理失败"`` on any error.
        """
        try:
            with open(image_path, 'rb') as image_file:
                base64_image = base64.b64encode(image_file.read()).decode('utf-8')

            # Advertise the real image type rather than always claiming JPEG.
            mime_type = self._guess_image_mime(image_path)

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high"
                                }
                            },
                            {
                                "type": "text",
                                "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
                            }
                        ]
                    }
                ],
                "temperature": 0.7,
                "max_tokens": 500
            }

            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code != 200:
                raise Exception(f"图片处理API调用失败: {response.text}")

            return response.json()["choices"][0]["message"]["content"]

        except Exception as e:
            print(f"处理图片时出错: {str(e)}")
            return "图片处理失败"

    @staticmethod
    def _is_decode_error(exc: BaseException) -> bool:
        """Tell whether ``exc`` is, or was directly caused by, a decode error.

        NOTE(review): LangChain's ``TextLoader`` appears to re-raise decode
        failures as ``RuntimeError(...) from UnicodeDecodeError`` — confirm
        against the installed version. Checking ``__cause__`` covers that
        case as well as a bare ``UnicodeDecodeError``.
        """
        return isinstance(exc, UnicodeDecodeError) or isinstance(
            exc.__cause__, UnicodeDecodeError
        )

    def _load_text_with_fallback(self, loader_cls):
        """Load with UTF-8 and retry once with GBK if decoding fails.

        Args:
            loader_cls: Loader class or factory called as
                ``loader_cls(file_path, encoding=...)`` whose result has a
                ``load()`` method.

        Raises:
            Exception: Whatever the loader raised, when the failure is not a
                decode error or when the GBK retry fails as well.
        """
        try:
            return loader_cls(self.file_path, encoding='utf-8').load()
        except Exception as utf8_error:
            if not self._is_decode_error(utf8_error):
                raise
            print(f"文件编码错误,尝试其他编码: {self.file_path}")

        try:
            return loader_cls(self.file_path, encoding='gbk').load()
        except Exception as gbk_error:
            print(f"尝试GBK编码也失败: {str(gbk_error)}")
            raise

    def _image_document(self):
        """Build the single ``Document`` that represents an image file."""
        description = self.process_image(self.file_path)
        # ``langchain.schema`` is a legacy import path; prefer the core
        # package and keep the old path as a fallback for older installs.
        try:
            from langchain_core.documents import Document
        except ImportError:
            from langchain.schema import Document
        return Document(
            page_content=description,
            metadata={
                'source': self.file_path,
                'file_name': self.original_filename,
                # Absolute path so the image can be located later.
                'img_url': os.path.abspath(self.file_path)
            }
        )

    def load(self):
        """Load the file according to its extension.

        Returns:
            A list of ``Document`` objects (exactly one for images).

        Raises:
            ValueError: If the extension is not supported.
            Exception: Any loader error is printed with a traceback and
                re-raised. Decode errors are never swallowed, so this method
                no longer returns ``None`` implicitly.
        """
        try:
            print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")

            if self.extension == '.md':
                return self._load_text_with_fallback(UnstructuredMarkdownLoader)
            if self.extension == '.pdf':
                return PyPDFLoader(self.file_path).load()
            if self.extension == '.txt':
                return self._load_text_with_fallback(TextLoader)
            if self.extension in self.IMAGE_EXTENSIONS:
                return [self._image_document()]

            print(f"不支持的文件扩展名: {self.extension}")
            raise ValueError(f"不支持的文件格式: {self.extension}")
        except Exception as e:
            print(f"加载文件 {self.file_path} 时出错: {str(e)}")
            import traceback
            traceback.print_exc()
            raise
134
+
135
class DocumentProcessor:
    """Turn a file or a directory tree into a flat list of text chunks.

    Files are loaded through ``DocumentLoader``, split with a
    ``RecursiveCharacterTextSplitter`` (1000-character chunks with a
    200-character overlap) and returned as plain dictionaries.
    """

    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name of the form ``rag_<name>`` from a path.

        Directories use the directory name; anything else uses the file name
        without its extension. The name is lower-cased.

        Args:
            path: File or directory path.

        Returns:
            The index name, e.g. ``"rag_report"`` for ``/data/Report.pdf``.
        """
        # Drop trailing separators first: ``basename("docs/")`` is empty,
        # which used to collapse the index name to the bare prefix "rag_".
        separators = os.sep + (os.altsep or '')
        base_name = os.path.basename(path.rstrip(separators))
        if os.path.isdir(path):
            return f"rag_{base_name.lower()}"
        return f"rag_{os.path.splitext(base_name)[0].lower()}"

    def _load_directory(self, path: str, progress_callback: Optional[Callable]):
        """Load every file under ``path``; return ``(documents, processed_size)``.

        Files that fail to load are reported with a warning and skipped, so
        one bad file does not abort the whole directory.
        """
        # Walk the tree once and reuse the listing for the total count.
        file_entries = [
            (root, file_name)
            for root, _, file_names in os.walk(path)
            for file_name in file_names
        ]
        total_files = len(file_entries)

        documents = []
        processed_size = 0
        for position, (root, file_name) in enumerate(file_entries, start=1):
            file_path = os.path.join(root, file_name)
            try:
                if progress_callback:
                    processed_size += os.path.getsize(file_path)
                    progress_callback(processed_size, f"处理文件 {position}/{total_files}: {file_name}")

                # The on-disk name is the original name for directory input.
                docs = DocumentLoader(file_path, original_filename=file_name).load()
                for doc in docs:
                    doc.metadata['file_name'] = file_name
                documents.extend(docs)
            except Exception as e:
                print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
                continue
        return documents, processed_size

    def _load_single_file(self, path: str, progress_callback: Optional[Callable],
                          original_filename: Optional[str]):
        """Load one file; return ``(documents, file_size)``.

        ``file_size`` is only measured when a callback is supplied and is 0
        otherwise. Errors are printed and re-raised.
        """
        display_name = original_filename or os.path.basename(path)
        try:
            file_size = os.path.getsize(path) if progress_callback else 0
            if progress_callback:
                progress_callback(file_size * 0.3, f"加载文件: {display_name}")

            documents = DocumentLoader(path, original_filename=original_filename).load()

            if progress_callback:
                progress_callback(file_size * 0.6, "处理文件内容...")

            # Record the original name rather than the stored file name.
            for doc in documents:
                doc.metadata['file_name'] = display_name
        except Exception as e:
            print(f"加载文件时出错: {str(e)}")
            raise
        return documents, file_size

    def process(self, path: str, progress_callback: Optional[Callable] = None, original_filename: str = None) -> List[Dict]:
        """Load and chunk a document, or every document under a directory.

        Args:
            path: File or directory to process.
            progress_callback: Optional callable invoked as
                ``progress_callback(amount, message)``. For a directory
                ``amount`` is the cumulative size in bytes of the files
                visited so far; for a single file it is 30%, 60% and 90% of
                the file size at successive stages.
            original_filename: Original (possibly non-ASCII) file name. It is
                used for type detection and for the ``file_name`` metadata
                of a single file, and ignored for directories.

        Returns:
            One dict per chunk with keys ``id`` (``doc_<n>``, numbered from 0
            within this call), ``content`` and ``metadata``.

        Raises:
            Exception: Re-raised from loading when ``path`` is a single file.
                Failures inside a directory are logged and skipped.
        """
        is_directory = os.path.isdir(path)
        if is_directory:
            documents, progress_amount = self._load_directory(path, progress_callback)
        else:
            documents, file_size = self._load_single_file(
                path, progress_callback, original_filename
            )
            progress_amount = file_size * 0.9

        chunks = self.text_splitter.split_documents(documents)

        if progress_callback:
            progress_callback(progress_amount, f"文档分块完成,共{len(chunks)}个文档片段")

        # Normalise to plain dictionaries for downstream consumers.
        return [
            {
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata
            }
            for i, chunk in enumerate(chunks)
        ]