Skip to content

Commit 021aa4a

Browse files
committed
feat: 文档状态
1 parent 119bba0 commit 021aa4a

24 files changed

+829
-136
lines changed

apps/common/db/search.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from django.db.models import QuerySet
1313

1414
from common.db.compiler import AppSQLCompiler
15-
from common.db.sql_execute import select_one, select_list
15+
from common.db.sql_execute import select_one, select_list, update_execute
1616
from common.response.result import Page
1717

1818

@@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
109109
return select_list(exec_sql, exec_params)
110110

111111

112+
def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
                  field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
                  with_table_name=False):
    """
    Generate a native SQL statement from the given queryset(s) and execute it as an update.

    :param queryset: a single QuerySet, or a dict of QuerySets, used to build the conditions
    :param select_string: statement prefix, without where / limit and similar clauses
    :param field_replace_dict: fields that need to be replaced while generating the SQL
    :param with_table_name: whether the generated SQL includes the table name
    :return: the result of ``update_execute`` for the generated statement
    """
    # A dict of querysets and a single queryset go through different SQL generators.
    sql_builder = generate_sql_by_query_dict if isinstance(queryset, Dict) else generate_sql_by_query
    exec_sql, exec_params = sql_builder(queryset, select_string, field_replace_dict, with_table_name)
    return update_execute(exec_sql, exec_params)
128+
129+
112130
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
113131
"""
114132
分页查询

apps/common/event/listener_manage.py

Lines changed: 122 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,29 @@
99
import datetime
1010
import logging
1111
import os
12+
import threading
1213
import traceback
1314
from typing import List
1415

1516
import django.db.models
17+
from django.db import models
1618
from django.db.models import QuerySet
19+
from django.db.models.functions import Substr, Reverse
1720
from langchain_core.embeddings import Embeddings
1821

1922
from common.config.embedding_config import VectorStore
20-
from common.db.search import native_search, get_dynamics_model
21-
from common.event.common import embedding_poxy
23+
from common.db.search import native_search, get_dynamics_model, native_update
24+
from common.db.sql_execute import sql_execute, update_execute
2225
from common.util.file_util import get_file_content
2326
from common.util.lock import try_lock, un_lock
24-
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
27+
from common.util.page_utils import page
28+
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
2529
from embedding.models import SourceType, SearchMode
2630
from smartdoc.conf import PROJECT_DIR
2731

2832
max_kb_error = logging.getLogger(__file__)
2933
max_kb = logging.getLogger(__file__)
34+
lock = threading.Lock()
3035

3136

3237
class SyncWebDatasetArgs:
@@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
114119
@param embedding_model: 向量模型
115120
"""
116121
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
117-
status = Status.success
122+
# 更新到开始状态
123+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
118124
try:
119125
data_list = native_search(
120126
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
@@ -125,23 +131,114 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
125131
# 删除段落
126132
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
127133

128-
def is_save_function():
129-
return QuerySet(Paragraph).filter(id=paragraph_id).exists()
134+
def is_the_task_interrupted():
135+
_paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
136+
if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
137+
return True
138+
return False
130139

131140
# 批量向量化
132-
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
141+
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
142+
# Mark the paragraph's embedding task as SUCCESS
143+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
144+
State.SUCCESS)
133145
except Exception as e:
134146
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
135-
status = Status.error
147+
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
148+
State.FAILURE)
136149
finally:
137-
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
138150
max_kb.info(f'结束--->向量化段落:{paragraph_id}')
139151

140152
@staticmethod
141153
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """
    Vectorize and store every entry of ``data_list`` in one batch.

    :param data_list: entries to vectorize
    :param embedding_model: embedding model used to vectorize the entries
    """
    # The callback slot of batch_save is used elsewhere in this class as
    # "is the task interrupted" (True means stop). This job has no revoke
    # path, so it must always answer False; the previous `lambda: True`
    # dates from when the slot meant "should save".
    # NOTE(review): confirm against the batch_save signature.
    VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: False)
144156

157+
@staticmethod
158+
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
    """
    Build a handler that embeds a batch of paragraphs one by one.

    :param embedding_model: embedding model handed to ``embedding_by_paragraph``
    :param is_the_task_interrupted: zero-argument callable; a truthy result stops the batch
    :param post_apply: zero-argument callable invoked once after each batch
    :return: a function taking an iterable of mappings that expose ``get('id')``
    """

    def embedding_paragraph_apply(paragraph_list):
        for row in paragraph_list:
            # Re-check before every paragraph so a revoke takes effect mid-batch.
            if is_the_task_interrupted():
                break
            paragraph_id = str(row.get('id'))
            ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model)
        # Runs even when the batch was cut short.
        post_apply()

    return embedding_paragraph_apply
167+
168+
@staticmethod
169+
def get_aggregation_document_status(document_id):
    """
    Build a zero-argument callable that runs ``update_document_status_meta.sql``
    restricted to the document with the given id.

    :param document_id: id of the document to filter on
    :return: the callable performing the update
    """

    def aggregation_document_status():
        def documents():
            # A fresh queryset per slot, as the SQL template takes two of them.
            return QuerySet(Document).filter(id=document_id)

        sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        sql = get_file_content(sql_path)
        query_sets = {'document_custom_sql': documents(),
                      'default_sql': documents()}
        native_update(query_sets, sql, with_table_name=True)

    return aggregation_document_status
177+
178+
@staticmethod
179+
def get_aggregation_document_status_by_dataset_id(dataset_id):
    """
    Build a zero-argument callable that runs ``update_document_status_meta.sql``
    restricted to the documents of the given dataset.

    :param dataset_id: id of the dataset whose documents are filtered on
    :return: the callable performing the update
    """

    def aggregation_document_status():
        def documents():
            # A fresh queryset per slot, as the SQL template takes two of them.
            return QuerySet(Document).filter(dataset_id=dataset_id)

        sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        sql = get_file_content(sql_path)
        query_sets = {'document_custom_sql': documents(),
                      'default_sql': documents()}
        # NOTE(review): unlike get_aggregation_document_status, with_table_name
        # is left at its default here - confirm this is intended.
        native_update(query_sets, sql)

    return aggregation_document_status
187+
188+
@staticmethod
189+
def get_aggregation_document_status_by_query_set(queryset):
    """
    Build a zero-argument callable that runs ``update_document_status_meta.sql``
    restricted by the supplied queryset.

    :param queryset: queryset used for both condition slots of the SQL template
    :return: the callable performing the update
    """

    def aggregation_document_status():
        sql_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        sql = get_file_content(sql_path)
        # The same queryset object feeds both slots.
        query_sets = {'document_custom_sql': queryset, 'default_sql': queryset}
        native_update(query_sets, sql)

    return aggregation_document_status
196+
197+
@staticmethod
198+
def post_update_document_status(document_id, task_type: TaskType):
    """
    Write the final state of ``task_type`` for a document once its task has ended.

    The document's state for the task becomes REVOKED if it was REVOKE, otherwise
    SUCCESS; it is then overridden with FAILURE when any entry of
    ``status_meta['aggs']`` reports a FAILURE state for the task with a positive
    count. Paragraphs of the document still in REVOKE for the task are moved to
    REVOKED.

    :param document_id: id of the document to finalize
    :param task_type: task whose state is being finalized
    """
    _document = QuerySet(Document).filter(id=document_id).first()
    if _document is None:
        # The document was deleted while the task was running. There is nothing
        # to finalize, and raising here would abort the caller's cleanup.
        return

    status = Status(_document.status)
    if status[task_type] == State.REVOKE:
        status[task_type] = State.REVOKED
    else:
        status[task_type] = State.SUCCESS
    for item in _document.status_meta.get('aggs', []):
        agg_status = item.get('status')
        agg_count = item.get('count')
        if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
            status[task_type] = State.FAILURE
    ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])

    # The status string is read reversed, so task_type.value is the 1-based
    # position of this task's state character. Exactly one character must be
    # extracted: Substr's third argument is a length, not an end position.
    revoking_paragraphs = QuerySet(Paragraph).annotate(
        reversed_status=Reverse('status'),
        task_type_status=Substr('reversed_status', task_type.value, 1),
    ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id')
    ListenerManagement.update_status(revoking_paragraphs, task_type, State.REVOKED)
220+
221+
@staticmethod
222+
def update_status(query_set: QuerySet, taskType: TaskType, state: State):
    """
    Set the state of one task type on every row selected by ``query_set``.

    The statement comes from ``update_paragraph_status.sql``; its ``${...}``
    placeholders are filled in by plain text substitution before execution.

    :param query_set: rows to update; its model's table name is substituted into the SQL
    :param taskType: task whose state is written; ``taskType.value`` supplies the index placeholders
    :param state: new state; ``state.value`` is substituted as the status number
    """
    exec_sql = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
    current_index = taskType.value
    placeholders = {
        '${bit_number}': len(TaskType),
        '${up_index}': current_index - 1,
        '${status_number}': state.value,
        '${next_index}': current_index + 1,
        '${table_name}': query_set.model._meta.db_table,
        '${current_index}': current_index,
    }
    for placeholder, value in placeholders.items():
        exec_sql = exec_sql.replace(placeholder, str(value))
    # Status updates issued from this process run one at a time under the
    # module-level lock; the context manager releases it on error too.
    with lock:
        native_update(query_set, exec_sql)
241+
145242
@staticmethod
146243
def embedding_by_document(document_id, embedding_model: Embeddings):
147244
"""
@@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
153250
if not try_lock('embedding' + str(document_id)):
154251
return
155252
max_kb.info(f"开始--->向量化文档:{document_id}")
156-
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
157-
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
158-
status = Status.success
253+
# Mark the document's embedding task as STARTED
254+
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
159255
try:
160-
data_list = native_search(
161-
{'problem': QuerySet(
162-
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
163-
**{'paragraph.document_id': document_id}),
164-
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
165-
select_string=get_file_content(
166-
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
167256
# 删除文档向量数据
168257
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
169258

170-
def is_save_function():
171-
return QuerySet(Document).filter(id=document_id).exists()
172-
173-
# 批量向量化
174-
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
259+
def is_the_task_interrupted():
260+
document = QuerySet(Document).filter(id=document_id).first()
261+
if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
262+
return True
263+
return False
264+
265+
# 根据段落进行向量化处理
266+
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
267+
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
268+
ListenerManagement.get_aggregation_document_status(
269+
document_id)),
270+
is_the_task_interrupted)
175271
except Exception as e:
176272
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
177-
status = Status.error
178273
finally:
179-
# 修改状态
180-
QuerySet(Document).filter(id=document_id).update(
181-
**{'status': status, 'update_time': datetime.datetime.now()})
182-
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
274+
ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
275+
ListenerManagement.get_aggregation_document_status(document_id)()
183276
max_kb.info(f"结束--->向量化文档:{document_id}")
184277
un_lock('embedding' + str(document_id))
185278

apps/common/util/page_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: page_utils.py
6+
@date:2024/11/21 10:32
7+
@desc:
8+
"""
9+
from math import ceil
10+
11+
12+
def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Hand ``query_set`` to ``handler`` in consecutive slices of ``page_size`` rows.

    The row count is read once up front. Before each slice the interruption
    callback is consulted, and a truthy answer ends the paging immediately.

    @param query_set: object exposing ``count()`` and slice access
    @param page_size: number of rows per slice
    @param handler: callable that receives each slice
    @param is_the_task_interrupted: zero-argument callable telling whether to stop
    @return: None
    """
    total_pages = ceil(query_set.count() / page_size)
    for page_index in range(total_pages):
        if is_the_task_interrupted():
            return
        start = page_index * page_size
        handler(query_set[start: start + page_size])
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Generated by Django 4.2.15 on 2024-11-22 14:44
2+
3+
import dataset.models.data_set
4+
from django.db import migrations, models
5+
6+
7+
class Migration(migrations.Migration):
    """Add ``status_meta`` to document and paragraph, and redefine their ``status`` fields."""

    dependencies = [
        ('dataset', '0010_file_meta'),
    ]

    operations = [
        # New JSON column on document.
        migrations.AddField(
            model_name='document',
            name='status_meta',
            field=models.JSONField(
                default=dataset.models.data_set.default_status_meta,
                verbose_name='状态统计数据',
            ),
        ),
        # New JSON column on paragraph.
        migrations.AddField(
            model_name='paragraph',
            name='status_meta',
            field=models.JSONField(
                default=dataset.models.data_set.default_status_meta,
                verbose_name='状态数据',
            ),
        ),
        # status on document becomes a CharField of at most 20 characters.
        migrations.AlterField(
            model_name='document',
            name='status',
            field=models.CharField(
                default=dataset.models.data_set.Status.__str__,
                max_length=20,
                verbose_name='状态',
            ),
        ),
        # status on paragraph becomes a CharField of at most 20 characters.
        migrations.AlterField(
            model_name='paragraph',
            name='status',
            field=models.CharField(
                default=dataset.models.data_set.Status.__str__,
                max_length=20,
                verbose_name='状态',
            ),
        ),
    ]

0 commit comments

Comments
 (0)