9
9
import datetime
10
10
import logging
11
11
import os
12
+ import threading
12
13
import traceback
13
14
from typing import List
14
15
15
16
import django .db .models
17
+ from django .db import models
16
18
from django .db .models import QuerySet
19
+ from django .db .models .functions import Substr , Reverse
17
20
from langchain_core .embeddings import Embeddings
18
21
19
22
from common .config .embedding_config import VectorStore
20
- from common .db .search import native_search , get_dynamics_model
21
- from common .event . common import embedding_poxy
23
+ from common .db .search import native_search , get_dynamics_model , native_update
24
+ from common .db . sql_execute import sql_execute , update_execute
22
25
from common .util .file_util import get_file_content
23
26
from common .util .lock import try_lock , un_lock
24
- from dataset .models import Paragraph , Status , Document , ProblemParagraphMapping
27
+ from common .util .page_utils import page
28
+ from dataset .models import Paragraph , Status , Document , ProblemParagraphMapping , TaskType , State
25
29
from embedding .models import SourceType , SearchMode
26
30
from smartdoc .conf import PROJECT_DIR
27
31
28
32
max_kb_error = logging .getLogger (__file__ )
29
33
max_kb = logging .getLogger (__file__ )
34
+ lock = threading .Lock ()
30
35
31
36
32
37
class SyncWebDatasetArgs :
@@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
114
119
@param embedding_model: 向量模型
115
120
"""
116
121
max_kb .info (f"开始--->向量化段落:{ paragraph_id } " )
117
- status = Status .success
122
+ # 更新到开始状态
123
+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING , State .STARTED )
118
124
try :
119
125
data_list = native_search (
120
126
{'problem' : QuerySet (get_dynamics_model ({'paragraph.id' : django .db .models .CharField ()})).filter (
@@ -125,23 +131,114 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
125
131
# 删除段落
126
132
VectorStore .get_embedding_vector ().delete_by_paragraph_id (paragraph_id )
127
133
128
- def is_save_function ():
129
- return QuerySet (Paragraph ).filter (id = paragraph_id ).exists ()
134
+ def is_the_task_interrupted ():
135
+ _paragraph = QuerySet (Paragraph ).filter (id = paragraph_id ).first ()
136
+ if _paragraph is None or Status (_paragraph .status )[TaskType .EMBEDDING ] == State .REVOKE :
137
+ return True
138
+ return False
130
139
131
140
# 批量向量化
132
- VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_save_function )
141
+ VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_the_task_interrupted )
142
+ # 更新到开始状态
143
+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING ,
144
+ State .SUCCESS )
133
145
except Exception as e :
134
146
max_kb_error .error (f'向量化段落:{ paragraph_id } 出现错误{ str (e )} { traceback .format_exc ()} ' )
135
- status = Status .error
147
+ ListenerManagement .update_status (QuerySet (Paragraph ).filter (id = paragraph_id ), TaskType .EMBEDDING ,
148
+ State .FAILURE )
136
149
finally :
137
- QuerySet (Paragraph ).filter (id = paragraph_id ).update (** {'status' : status })
138
150
max_kb .info (f'结束--->向量化段落:{ paragraph_id } ' )
139
151
140
152
@staticmethod
141
153
def embedding_by_data_list (data_list : List , embedding_model : Embeddings ):
142
154
# 批量向量化
143
155
VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , lambda : True )
144
156
157
    @staticmethod
    def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
        """
        Build a per-page callback that embeds each paragraph in the page.

        @param embedding_model:         embedding model forwarded to embedding_by_paragraph
        @param is_the_task_interrupted: zero-arg callable; a truthy result aborts the page early
        @param post_apply:              zero-arg callable invoked once after the page is processed
                                        (e.g. re-aggregate the owning document's status meta)
        @return: function(paragraph_list) suitable as the apply argument of page()
        """

        def embedding_paragraph_apply(paragraph_list):
            # paragraph_list items are dicts produced by .values('id') — see the
            # page(...) call in embedding_by_document.
            for paragraph in paragraph_list:
                # Stop embedding as soon as the task has been revoked/interrupted.
                if is_the_task_interrupted():
                    break
                ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
            post_apply()

        return embedding_paragraph_apply
167
+
168
+ @staticmethod
169
+ def get_aggregation_document_status (document_id ):
170
+ def aggregation_document_status ():
171
+ sql = get_file_content (
172
+ os .path .join (PROJECT_DIR , "apps" , "dataset" , 'sql' , 'update_document_status_meta.sql' ))
173
+ native_update ({'document_custom_sql' : QuerySet (Document ).filter (id = document_id ),
174
+ 'default_sql' : QuerySet (Document ).filter (id = document_id )}, sql , with_table_name = True )
175
+
176
+ return aggregation_document_status
177
+
178
+ @staticmethod
179
+ def get_aggregation_document_status_by_dataset_id (dataset_id ):
180
+ def aggregation_document_status ():
181
+ sql = get_file_content (
182
+ os .path .join (PROJECT_DIR , "apps" , "dataset" , 'sql' , 'update_document_status_meta.sql' ))
183
+ native_update ({'document_custom_sql' : QuerySet (Document ).filter (dataset_id = dataset_id ),
184
+ 'default_sql' : QuerySet (Document ).filter (dataset_id = dataset_id )}, sql )
185
+
186
+ return aggregation_document_status
187
+
188
+ @staticmethod
189
+ def get_aggregation_document_status_by_query_set (queryset ):
190
+ def aggregation_document_status ():
191
+ sql = get_file_content (
192
+ os .path .join (PROJECT_DIR , "apps" , "dataset" , 'sql' , 'update_document_status_meta.sql' ))
193
+ native_update ({'document_custom_sql' : queryset , 'default_sql' : queryset }, sql )
194
+
195
+ return aggregation_document_status
196
+
197
    @staticmethod
    def post_update_document_status(document_id, task_type: TaskType):
        """
        Finalize a document's state for ``task_type`` after its task has run.

        The final state is derived from the document's current packed status and
        the per-paragraph aggregation stored in ``status_meta``:
        REVOKE -> REVOKED, otherwise SUCCESS, downgraded to FAILURE when any
        aggregated paragraph bucket reports failures.
        """
        _document = QuerySet(Document).filter(id=document_id).first()

        # Status(...) wraps the packed per-task-type status string and supports
        # indexing/assignment by TaskType (same idiom as embedding_by_paragraph).
        status = Status(_document.status)
        if status[task_type] == State.REVOKE:
            # A revoke was requested while the task ran: acknowledge it.
            status[task_type] = State.REVOKED
        else:
            status[task_type] = State.SUCCESS
        # status_meta['aggs'] carries per-status paragraph counts; any non-empty
        # FAILURE bucket marks the whole document as failed for this task type.
        for item in _document.status_meta.get('aggs', []):
            agg_status = item.get('status')
            agg_count = item.get('count')
            if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
                status[task_type] = State.FAILURE
        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])

        # Flip paragraphs of this document still marked REVOKE for this task
        # type to REVOKED. The packed status string is read right-to-left, hence
        # Reverse + Substr to extract the character for this task type.
        # NOTE(review): Substr's length argument is task_type.value rather than
        # 1 — presumably relies on the exact-match filter below; confirm intent.
        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
            reversed_status=Reverse('status'),
            task_type_status=Substr('reversed_status', task_type.value,
                                    task_type.value),
        ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
            task_type,
            State.REVOKED)
220
+
221
+ @staticmethod
222
+ def update_status (query_set : QuerySet , taskType : TaskType , state : State ):
223
+ exec_sql = get_file_content (
224
+ os .path .join (PROJECT_DIR , "apps" , "dataset" , 'sql' , 'update_paragraph_status.sql' ))
225
+ bit_number = len (TaskType )
226
+ up_index = taskType .value - 1
227
+ next_index = taskType .value + 1
228
+ current_index = taskType .value
229
+ status_number = state .value
230
+ params_dict = {'${bit_number}' : bit_number , '${up_index}' : up_index ,
231
+ '${status_number}' : status_number , '${next_index}' : next_index ,
232
+ '${table_name}' : query_set .model ._meta .db_table , '${current_index}' : current_index }
233
+ for key in params_dict :
234
+ _value_ = params_dict [key ]
235
+ exec_sql = exec_sql .replace (key , str (_value_ ))
236
+ lock .acquire ()
237
+ try :
238
+ native_update (query_set , exec_sql )
239
+ finally :
240
+ lock .release ()
241
+
145
242
@staticmethod
146
243
def embedding_by_document (document_id , embedding_model : Embeddings ):
147
244
"""
@@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
153
250
if not try_lock ('embedding' + str (document_id )):
154
251
return
155
252
max_kb .info (f"开始--->向量化文档:{ document_id } " )
156
- QuerySet (Document ).filter (id = document_id ).update (** {'status' : Status .embedding })
157
- QuerySet (Paragraph ).filter (document_id = document_id ).update (** {'status' : Status .embedding })
158
- status = Status .success
253
+ # 批量修改状态为PADDING
254
+ ListenerManagement .update_status (QuerySet (Document ).filter (id = document_id ), TaskType .EMBEDDING , State .STARTED )
159
255
try :
160
- data_list = native_search (
161
- {'problem' : QuerySet (
162
- get_dynamics_model ({'paragraph.document_id' : django .db .models .CharField ()})).filter (
163
- ** {'paragraph.document_id' : document_id }),
164
- 'paragraph' : QuerySet (Paragraph ).filter (document_id = document_id )},
165
- select_string = get_file_content (
166
- os .path .join (PROJECT_DIR , "apps" , "common" , 'sql' , 'list_embedding_text.sql' )))
167
256
# 删除文档向量数据
168
257
VectorStore .get_embedding_vector ().delete_by_document_id (document_id )
169
258
170
- def is_save_function ():
171
- return QuerySet (Document ).filter (id = document_id ).exists ()
172
-
173
- # 批量向量化
174
- VectorStore .get_embedding_vector ().batch_save (data_list , embedding_model , is_save_function )
259
+ def is_the_task_interrupted ():
260
+ document = QuerySet (Document ).filter (id = document_id ).first ()
261
+ if document is None or Status (document .status )[TaskType .EMBEDDING ] == State .REVOKE :
262
+ return True
263
+ return False
264
+
265
+ # 根据段落进行向量化处理
266
+ page (QuerySet (Paragraph ).filter (document_id = document_id ).values ('id' ), 5 ,
267
+ ListenerManagement .get_embedding_paragraph_apply (embedding_model , is_the_task_interrupted ,
268
+ ListenerManagement .get_aggregation_document_status (
269
+ document_id )),
270
+ is_the_task_interrupted )
175
271
except Exception as e :
176
272
max_kb_error .error (f'向量化文档:{ document_id } 出现错误{ str (e )} { traceback .format_exc ()} ' )
177
- status = Status .error
178
273
finally :
179
- # 修改状态
180
- QuerySet (Document ).filter (id = document_id ).update (
181
- ** {'status' : status , 'update_time' : datetime .datetime .now ()})
182
- QuerySet (Paragraph ).filter (document_id = document_id ).update (** {'status' : status })
274
+ ListenerManagement .post_update_document_status (document_id , TaskType .EMBEDDING )
275
+ ListenerManagement .get_aggregation_document_status (document_id )()
183
276
max_kb .info (f"结束--->向量化文档:{ document_id } " )
184
277
un_lock ('embedding' + str (document_id ))
185
278
0 commit comments