diff --git a/apps/common/db/search.py b/apps/common/db/search.py index 76366715439..bef42a1414a 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -12,7 +12,7 @@ from django.db.models import QuerySet from common.db.compiler import AppSQLCompiler -from common.db.sql_execute import select_one, select_list +from common.db.sql_execute import select_one, select_list, update_execute from common.response.result import Page @@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, return select_list(exec_sql, exec_params) +def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, + with_table_name=False): + """ + 复杂更新 + :param with_table_name: 生成sql是否包含表名 + :param queryset: 查询条件构造器 + :param select_string: 更新语句前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :return: 更新执行结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + return update_execute(exec_sql, exec_params) + + def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): """ 分页查询 diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 40ac4884ddb..5e52aa1efc7 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -9,24 +9,29 @@ import datetime import logging import os +import threading import traceback from typing import List import django.db.models +from django.db import models from django.db.models import QuerySet +from django.db.models.functions import Substr, Reverse from langchain_core.embeddings import Embeddings from common.config.embedding_config import VectorStore -from common.db.search import native_search, get_dynamics_model -from common.event.common import embedding_poxy +from common.db.search import native_search, get_dynamics_model, native_update +from common.db.sql_execute import sql_execute, update_execute from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping +from common.util.page_utils import page +from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State from embedding.models import SourceType, SearchMode from smartdoc.conf import PROJECT_DIR max_kb_error = logging.getLogger(__file__) max_kb = logging.getLogger(__file__) +lock = threading.Lock() class SyncWebDatasetArgs: @@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): @param embedding_model: 向量模型 """ max_kb.info(f"开始--->向量化段落:{paragraph_id}") - status = Status.success + # 更新到开始状态 + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED) try: data_list = native_search( {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( @@ -125,16 +131,22 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) - def is_save_function(): - return QuerySet(Paragraph).filter(id=paragraph_id).exists() + def is_the_task_interrupted(): + _paragraph = 
QuerySet(Paragraph).filter(id=paragraph_id).first() + if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE: + return True + return False # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted) + # 更新到成功状态 + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, + State.SUCCESS) except Exception as e: max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') - status = Status.error + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, + State.FAILURE) finally: - QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status}) max_kb.info(f'结束--->向量化段落:{paragraph_id}') @staticmethod @@ -142,6 +154,91 @@ def embedding_by_data_list(data_list: List, embedding_model: Embeddings): # 批量向量化 VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True) + @staticmethod + def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None): + def embedding_paragraph_apply(paragraph_list): + for paragraph in paragraph_list: + if is_the_task_interrupted(): + break + ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model) + post_apply() + + return embedding_paragraph_apply + + @staticmethod + def get_aggregation_document_status(document_id): + def aggregation_document_status(): + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id), + 'default_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True) + + return aggregation_document_status + + @staticmethod + def get_aggregation_document_status_by_dataset_id(dataset_id): + def aggregation_document_status(): + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id), + 'default_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql) + + return aggregation_document_status + + @staticmethod + def get_aggregation_document_status_by_query_set(queryset): + def aggregation_document_status(): + sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) + native_update({'document_custom_sql': queryset, 'default_sql': queryset}, sql) + + return aggregation_document_status + + @staticmethod + def post_update_document_status(document_id, task_type: TaskType): + _document = QuerySet(Document).filter(id=document_id).first() + + status = Status(_document.status) + if status[task_type] == State.REVOKE: + status[task_type] = State.REVOKED + else: + status[task_type] = State.SUCCESS + for item in _document.status_meta.get('aggs', []): + agg_status = item.get('status') + agg_count = item.get('count') + if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0: + status[task_type] = State.FAILURE + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type]) + + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', task_type.value, + task_type.value), + 
).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'), + task_type, + State.REVOKED) + + @staticmethod + def update_status(query_set: QuerySet, taskType: TaskType, state: State): + exec_sql = get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql')) + bit_number = len(TaskType) + up_index = taskType.value - 1 + next_index = taskType.value + 1 + current_index = taskType.value + status_number = state.value + params_dict = {'${bit_number}': bit_number, '${up_index}': up_index, + '${status_number}': status_number, '${next_index}': next_index, + '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index} + for key in params_dict: + _value_ = params_dict[key] + exec_sql = exec_sql.replace(key, str(_value_)) + lock.acquire() + try: + native_update(query_set, exec_sql) + finally: + lock.release() + @staticmethod def embedding_by_document(document_id, embedding_model: Embeddings): """ @@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings): if not try_lock('embedding' + str(document_id)): return max_kb.info(f"开始--->向量化文档:{document_id}") - QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) - QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding}) - status = Status.success + # 修改文档向量化状态为执行中 + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED) try: - data_list = native_search( - {'problem': QuerySet( - get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter( - **{'paragraph.document_id': document_id}), - 'paragraph': QuerySet(Paragraph).filter(document_id=document_id)}, - select_string=get_file_content( - os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) # 删除文档向量数据 VectorStore.get_embedding_vector().delete_by_document_id(document_id) - def is_save_function(): - return QuerySet(Document).filter(id=document_id).exists() - - # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE: + return True + return False + + # 根据段落进行向量化处理 + page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5, + ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, + ListenerManagement.get_aggregation_document_status( + document_id)), + is_the_task_interrupted) except Exception as e: max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') - status = Status.error finally: - # 修改状态 - QuerySet(Document).filter(id=document_id).update( - **{'status': status, 'update_time': datetime.datetime.now()}) - QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) + ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING) + ListenerManagement.get_aggregation_document_status(document_id)() max_kb.info(f"结束--->向量化文档:{document_id}") un_lock('embedding' + str(document_id)) diff --git a/apps/common/util/page_utils.py b/apps/common/util/page_utils.py new file mode 100644 index 00000000000..7fc176b6895 --- /dev/null +++ b/apps/common/util/page_utils.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: page_utils.py + @date:2024/11/21 10:32 + @desc: +""" +from 
math import ceil + + +def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False): + """ + + @param query_set: 查询query_set + @param page_size: 每次查询大小 + @param handler: 数据处理器 + @param is_the_task_interrupted: 任务是否被中断 + @return: + """ + count = query_set.count() + for i in range(0, ceil(count / page_size)): + if is_the_task_interrupted(): + return + offset = i * page_size + paragraph_list = query_set[offset: offset + page_size] + handler(paragraph_list) diff --git a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py new file mode 100644 index 00000000000..02c9addd0f4 --- /dev/null +++ b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 4.2.15 on 2024-11-22 14:44 + +import dataset.models.data_set +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0010_file_meta'), + ] + + operations = [ + migrations.AddField( + model_name='document', + name='status_meta', + field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'), + ), + migrations.AddField( + model_name='paragraph', + name='status_meta', + field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'), + ), + migrations.AlterField( + model_name='document', + name='status', + field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), + ), + migrations.AlterField( + model_name='paragraph', + name='status', + field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), + ), + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index cd91b6d18c8..4f46eda2a09 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -7,6 +7,7 @@ @desc: 数据集 """ import uuid +from enum import Enum from django.db import models from django.db.models.signals import pre_delete @@ -18,13 +19,62 @@ from users.models import User -class Status(models.TextChoices): - """订单类型""" - embedding = 0, '导入中' - success = 1, '已完成' - error = 2, '导入失败' - queue_up = 3, '排队中' - generating = 4, '生成问题中' +class TaskType(Enum): + # 向量 + EMBEDDING = 1 + # 生成问题 + GENERATE_PROBLEM = 2 + # 同步 + SYNC = 3 + + +class State(Enum): + # 等待 + PENDING = '0' + # 执行中 + STARTED = '1' + # 成功 + SUCCESS = '2' + # 失败 + FAILURE = '3' + # 取消任务 + REVOKE = '4' + # 取消成功 + REVOKED = '5' + # 忽略 + IGNORED = 'n' + + +class Status: + type_cls = TaskType + state_cls = State + + def __init__(self, status: str = None): + self.task_status = {} + status_list = list(status[::-1] if status is not None else '') + for _type in self.type_cls: + index = _type.value - 1 + _state = self.state_cls(status_list[index] if len(status_list) > index else 'n') + self.task_status[_type] = _state + + @staticmethod + def of(status: str): + return Status(status) + + def __str__(self): + result = [] + for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True): + result.insert(len(self.type_cls) - _type.value, self.task_status[_type].value) + return ''.join(result) + + def __setitem__(self, key, value): + self.task_status[key] = value + + def __getitem__(self, item): + return self.task_status[item] + + def update_status(self, task_type: TaskType, state: State): + self.task_status[task_type] = state class Type(models.TextChoices): 
@@ -42,6 +92,10 @@ def default_model(): return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') +def default_status_meta(): + return {"state_time": {}} + + class DataSet(AppModelMixin): """ 数据集表 @@ -68,8 +122,8 @@ class Document(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) name = models.CharField(max_length=150, verbose_name="文档名称") char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") - status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, - default=Status.queue_up) + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta) is_active = models.BooleanField(default=True) type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, @@ -94,8 +148,8 @@ class Paragraph(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) content = models.CharField(max_length=102400, verbose_name="段落内容") title = models.CharField(max_length=256, verbose_name="标题", default="") - status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, - default=Status.embedding) + status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__) + status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta) hit_num = models.IntegerField(verbose_name="命中次数", default=0) is_active = models.BooleanField(default=True) @@ -145,7 +199,6 @@ class File(AppModelMixin): meta = models.JSONField(verbose_name="文件关联数据", default=dict) - class Meta: db_table = "file" @@ -161,7 +214,6 @@ def get_byte(self): return result['data'] - @receiver(pre_delete, sender=File) def on_delete_file(sender, instance, **kwargs): select_one(f'SELECT lo_unlink({instance.loid})', []) diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 7598432cc5e..85e73ee3243 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -27,6 +27,7 @@ from common.config.embedding_config import VectorStore from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list +from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post, flat_map, valid_license @@ -34,7 +35,8 @@ from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \ + TaskType, State from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer @@ -733,9 +735,13 @@ def delete(self): def re_embedding(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - - QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) - QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up}) + 
ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')), + TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(dataset_id=self.data.get('id')), + TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))() embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id')) embedding_by_dataset.delay(self.data.get('id'), embedding_model_id) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 61a6b02c4ff..1ab74ead244 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -19,6 +19,7 @@ from django.core import validators from django.db import transaction from django.db.models import QuerySet +from django.db.models.functions import Substr, Reverse from django.http import HttpResponse from drf_yasg import openapi from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE @@ -26,6 +27,7 @@ from xlwt import Utils from common.db.search import native_search, native_page_search +from common.event import ListenerManagement from common.event.common import work_thread_pool from common.exception.app_exception import AppApiException from common.handle.impl.doc_split_handle import DocSplitHandle @@ -44,7 +46,8 @@ from common.util.file_util import get_file_content from common.util.fork import Fork from common.util.split_model import get_split_model -from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \ + TaskType, State from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer @@ -67,6 +70,19 @@ def get_buffer(self, file): return self.buffer +class CancelInstanceSerializer(serializers.Serializer): + type = serializers.IntegerField(required=True, error_messages=ErrMessage.boolean( + "任务类型")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + _type = self.data.get('type') + try: + TaskType(_type) + except Exception as e: + raise AppApiException(500, '任务类型不支持') + + class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer): meta = serializers.DictField(required=False) name = serializers.CharField(required=False, max_length=128, min_length=1, @@ -278,7 +294,9 @@ def migrate(self, with_valid=True): # 修改向量信息 if model_id: delete_embedding_by_paragraph_ids(pid_list) - QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up) + ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list), + TaskType.EMBEDDING, + State.PENDING) embedding_by_document_list.delay(document_id_list, model_id) else: update_embedding_dataset_id(pid_list, target_dataset_id) @@ -404,11 +422,13 @@ def sync(self, with_valid=True, with_embedding=True): self.is_valid(raise_exception=True) document_id = self.data.get('document_id') document = QuerySet(Document).filter(id=document_id).first() + state = State.SUCCESS if document.type != Type.web: return True try: - document.status = Status.queue_up - document.save() + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + 
TaskType.SYNC, + State.PENDING) source_url = document.meta.get('source_url') selector_list = document.meta.get('selector').split( " ") if 'selector' in document.meta and document.meta.get('selector') is not None else [] @@ -442,13 +462,18 @@ def sync(self, with_valid=True, with_embedding=True): if with_embedding: embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id) embedding_by_document.delay(document_id, embedding_model_id) + else: - document.status = Status.error - document.save() + state = State.FAILURE except Exception as e: logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') - document.status = Status.error - document.save() + state = State.FAILURE + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.SYNC, + state) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + TaskType.SYNC, + state) return True class Operate(ApiMixin, serializers.Serializer): @@ -586,14 +611,35 @@ def refresh(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get("document_id") - QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up}) - QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up}) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) try: embedding_by_document.delay(document_id, embedding_model_id) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") + def cancel(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CancelInstanceSerializer(data=instance).is_valid() + document_id = self.data.get("document_id") + ListenerManagement.update_status(QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, + TaskType(instance.get('type')).value), + ).filter(task_type_status__in=[State.PENDING.value, State.STARTED.value]).filter( + document_id=document_id).values('id'), + TaskType(instance.get('type')), + State.REVOKE) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType(instance.get('type')), + State.REVOKE) + + return True + @transaction.atomic def delete(self): document_id = self.data.get("document_id") @@ -955,15 +1001,13 @@ def batch_refresh(self, instance: Dict, with_valid=True): self.is_valid(raise_exception=True) document_id_list = instance.get("id_list") with transaction.atomic(): - Document.objects.filter(id__in=document_id_list).update(status=Status.queue_up) - Paragraph.objects.filter(document_id__in=document_id_list).update(status=Status.queue_up) dataset_id = self.data.get('dataset_id') - embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=dataset_id) for document_id in document_id_list: try: - embedding_by_document.delay(document_id, embedding_model_id) + DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': document_id}).refresh() except AlreadyQueued as e: - raise AppApiException(500, "任务正在执行中,请勿重复下发") + pass class GenerateRelated(ApiMixin, serializers.Serializer): document_id = 
serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) @@ -978,7 +1022,13 @@ def generate_related(self, model_id, prompt, with_valid=True): if with_valid: self.is_valid(raise_exception=True) document_id = self.data.get('document_id') - QuerySet(Document).filter(id=document_id).update(status=Status.queue_up) + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() try: generate_related_by_document_id.delay(document_id, model_id, prompt) except AlreadyQueued as e: diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 6614d712ade..82aacc79d0b 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -16,11 +16,12 @@ from rest_framework import serializers from common.db.search import page_search +from common.event import ListenerManagement from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.field_message import ErrMessage -from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet +from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet, TaskType, State from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ ProblemParagraphManage, get_embedding_model_id_by_dataset_id from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers @@ -722,7 +723,6 @@ def get_response_body_api(): } ) - class BatchGenerateRelated(ApiMixin, serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) @@ -734,10 +734,16 @@ def batch_generate_related(self, instance: Dict, with_valid=True): paragraph_id_list = instance.get("paragraph_id_list") model_id = instance.get("model_id") prompt = instance.get("prompt") + document_id = self.data.get('document_id') + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.update_status(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + TaskType.GENERATE_PROBLEM, + State.PENDING) + ListenerManagement.get_aggregation_document_status(document_id)() try: - generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt) + generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, + prompt) except AlreadyQueued as e: raise AppApiException(500, "任务正在执行中,请勿重复下发") - - - diff --git a/apps/dataset/sql/list_document.sql b/apps/dataset/sql/list_document.sql index 818d783c834..c1e3a90370e 100644 --- a/apps/dataset/sql/list_document.sql +++ b/apps/dataset/sql/list_document.sql @@ -1,6 +1,7 @@ SELECT "document".* , to_json("document"."meta") as meta, + to_json("document"."status_meta") as status_meta, (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count" FROM "document" "document" diff --git a/apps/dataset/sql/update_document_status_meta.sql 
b/apps/dataset/sql/update_document_status_meta.sql new file mode 100644 index 00000000000..ced642b83e6 --- /dev/null +++ b/apps/dataset/sql/update_document_status_meta.sql @@ -0,0 +1,25 @@ +UPDATE "document" "document" +SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta ) +FROM + ( + SELECT COALESCE + ( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta, + document_id AS document_id + FROM + ( + SELECT + "paragraph".status, + "count" ( "paragraph"."id" ), + "document"."id" AS document_id + FROM + "document" "document" + LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id + ${document_custom_sql} + GROUP BY + "paragraph".status, + "document"."id" + ) record + GROUP BY + document_id + ) tmp +${default_sql} \ No newline at end of file diff --git a/apps/dataset/sql/update_paragraph_status.sql b/apps/dataset/sql/update_paragraph_status.sql new file mode 100644 index 00000000000..45f9c674fec --- /dev/null +++ b/apps/dataset/sql/update_paragraph_status.sql @@ -0,0 +1,13 @@ +UPDATE "${table_name}" +SET status = reverse ( + SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} ) +), +status_meta = jsonb_set ( + "${table_name}".status_meta, + '{state_time,${current_index}}', + jsonb_set ( + COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', now( ) ) ), + '{${status_number}}', + CONCAT ( '"', now( ), '"' ) :: JSONB + ) + ) \ No newline at end of file diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py index 637a7e5098a..8fe588b7011 100644 --- a/apps/dataset/swagger_api/document_api.py +++ b/apps/dataset/swagger_api/document_api.py @@ -26,3 +26,14 @@ def get_request_body_api(): 'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回相似度") } ) + + class Cancel(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'type': openapi.Schema(type=openapi.TYPE_INTEGER, title="任务类型", + description="1|2|3 1:向量化|2:生成问题|3:同步文档") + } + ) diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py index 860425978f5..e8103974461 100644 --- a/apps/dataset/task/generate.py +++ b/apps/dataset/task/generate.py @@ -1,12 +1,14 @@ import logging -from math import ceil +import traceback from celery_once import QueueOnce from django.db.models import QuerySet from langchain_core.messages import HumanMessage from common.config.embedding_config import ModelManage -from dataset.models import Paragraph, Document, Status +from common.event import ListenerManagement +from common.util.page_utils import page +from dataset.models import Paragraph, Document, Status, TaskType, State from dataset.task.tools import save_problem from ops import celery_app from setting.models import Model @@ -21,44 +23,79 @@ def get_llm_model(model_id): return ModelManage.get_model(model_id, lambda _id: get_model(model)) +def generate_problem_by_paragraph(paragraph, llm_model, prompt): + try: + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.STARTED) + res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) + if (res.content is None) or (len(res.content) == 0): + return + 
problems = res.content.split('\n') + for problem in problems: + save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.SUCCESS) + except Exception as e: + ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM, + State.FAILURE) + + +def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False): + def generate_problem(paragraph_list): + for paragraph in paragraph_list: + if is_the_task_interrupted(): + return + generate_problem_by_paragraph(paragraph, llm_model, prompt) + post_apply() + + return generate_problem + + @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document') def generate_related_by_document_id(document_id, model_id, prompt): - llm_model = get_llm_model(model_id) - offset = 0 - page_size = 10 - QuerySet(Document).filter(id=document_id).update(status=Status.generating) - - count = QuerySet(Paragraph).filter(document_id=document_id).count() - for i in range(0, ceil(count / page_size)): - paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size] - offset += page_size - for paragraph in paragraph_list: - res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) - if (res.content is None) or (len(res.content) == 0): - continue - problems = res.content.split('\n') - for problem in problems: - save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) + try: + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.STARTED) + llm_model = get_llm_model(model_id) - QuerySet(Document).filter(id=document_id).update(status=Status.success) + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE: + return True + return False + # 生成问题函数 + generate_problem = get_generate_problem(llm_model, prompt, + ListenerManagement.get_aggregation_document_status( + document_id), is_the_task_interrupted) + page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted) + except Exception as e: + max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}') + finally: + ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM) + max_kb.info(f"结束--->生成问题:{document_id}") @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:generate_related_by_paragraph_list') -def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt): - llm_model = get_llm_model(model_id) - offset = 0 - page_size = 10 - count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count() - for i in range(0, ceil(count / page_size)): - paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size] - offset += page_size - for paragraph in paragraph_list: - res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))]) - if (res.content is None) or (len(res.content) == 0): - continue - problems = res.content.split('\n') - for problem in problems: - save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem) +def 
generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt): + try: + ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), + TaskType.GENERATE_PROBLEM, + State.STARTED) + llm_model = get_llm_model(model_id) + # 生成问题函数 + generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status( + document_id)) + + def is_the_task_interrupted(): + document = QuerySet(Document).filter(id=document_id).first() + if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE: + return True + return False + + page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted) + finally: + ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM) diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index b2246355601..9e5835318d6 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -37,6 +37,7 @@ name="document_export"), path('dataset//document//sync', views.Document.SyncWeb.as_view()), path('dataset//document//refresh', views.Document.Refresh.as_view()), + path('dataset//document//cancel_task', views.Document.CancelTask.as_view()), path('dataset//document//paragraph', views.Paragraph.as_view()), path('dataset//document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()), path( @@ -45,7 +46,8 @@ path('dataset//document//paragraph/_batch', views.Paragraph.Batch.as_view()), path('dataset//document//paragraph//', views.Paragraph.Page.as_view(), name='paragraph_page'), - path('dataset//document//paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()), + path('dataset//document//paragraph/batch_generate_related', + views.Paragraph.BatchGenerateRelated.as_view()), path('dataset//document//paragraph/', views.Paragraph.Operate.as_view()), path('dataset//document//paragraph//problem', diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index d911d0de867..4a98fb08bd2 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -218,6 +218,26 @@ def put(self, request: Request, dataset_id: str, document_id: str): DocumentSerializers.Sync(data={'document_id': document_id, 'dataset_id': dataset_id}).sync( )) + class CancelTask(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="取消任务", + operation_id="取消任务", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + request_body=DocumentApi.Cancel.get_request_body_api(), + responses=result.get_default_response(), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str): + return result.success( + DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).cancel( + request.data + )) + class Refresh(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index ab5ab4103a4..9a6eaff6b0b 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -86,20 +86,20 @@ def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, for child_array in result: self._batch_save(child_array, embedding, lambda: True) - def batch_save(self, data_list: List[Dict], embedding: 
Embeddings, is_save_function): + def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): """ 批量插入 @param data_list: 数据列表 @param embedding: 向量化处理器 - @param is_save_function: + @param is_the_task_interrupted: 判断是否中断任务 :return: bool """ self.save_pre_handler() chunk_list = chunk_data_list(data_list) result = sub_array(chunk_list) for child_array in result: - if is_save_function(): - self._batch_save(child_array, embedding, is_save_function) + if not is_the_task_interrupted(): + self._batch_save(child_array, embedding, is_the_task_interrupted) else: break @@ -110,7 +110,7 @@ def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str pass @abstractmethod - def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): pass def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 8cd2146ad9d..906da0cbdc9 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -57,7 +57,7 @@ def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str embedding.save() return True - def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) embedding_list = [Embedding(id=uuid.uuid1(), @@ -70,7 +70,7 @@ def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_func embedding=embeddings[index], search_vector=to_ts_vector(text_list[index]['text'])) for index in range(0, len(texts))] - if is_save_function(): + if not is_the_task_interrupted(): QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True diff --git a/apps/ops/celery/logger.py b/apps/ops/celery/logger.py index bdadc5685f9..1b2843c2b85 100644 --- a/apps/ops/celery/logger.py +++ b/apps/ops/celery/logger.py @@ -208,6 +208,7 @@ def flush(self): f.flush() def handle_task_start(self, task_id): + print('handle_task_start') log_path = get_celery_task_log_path(task_id) thread_id = self.get_current_thread_id() self.task_id_thread_id_mapper[task_id] = thread_id @@ -215,6 +216,7 @@ def handle_task_start(self, task_id): self.thread_id_fd_mapper[thread_id] = f def handle_task_end(self, task_id): + print('handle_task_end') ident_id = self.task_id_thread_id_mapper.get(task_id, '') f = self.thread_id_fd_mapper.pop(ident_id, None) if f and not f.closed: diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py index 90ed62405f5..46671a0d8fa 100644 --- a/apps/ops/celery/signal_handler.py +++ b/apps/ops/celery/signal_handler.py @@ -5,7 +5,7 @@ from celery import subtask from celery.signals import ( - worker_ready, worker_shutdown, after_setup_logger + worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun ) from django.core.cache import cache from django_celery_beat.models import PeriodicTask @@ -61,3 +61,15 @@ def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=No formatter = logging.Formatter(format) task_handler.setFormatter(formatter) logger.addHandler(task_handler) + + +@task_revoked.connect +def on_task_revoked(request, terminated, signum, expired, **kwargs): + 
print('task_revoked', terminated) + + +@task_prerun.connect +def on_taskaa_start(sender, task_id, **kwargs): + pass + # sender.update_state(state='REVOKED', +# meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': '暂停任务', 'exc_message': ''}) diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 28954d0cc6b..7bd42546cc9 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -322,8 +322,17 @@ const batchGenerateRelated: ( data: any, loading?: Ref ) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/document/batch_generate_related`, data, undefined, loading) +} + +const cancelTask: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { return put( - `${prefix}/${dataset_id}/document/batch_generate_related`, + `${prefix}/${dataset_id}/document/${document_id}/cancel_task`, data, undefined, loading @@ -352,5 +361,6 @@ export default { postTableDocument, exportDocument, batchRefresh, - batchGenerateRelated + batchGenerateRelated, + cancelTask } diff --git a/ui/src/components/icons/index.ts b/ui/src/components/icons/index.ts index 9c1eb8c63bd..adbc6d8dbb8 100644 --- a/ui/src/components/icons/index.ts +++ b/ui/src/components/icons/index.ts @@ -1307,5 +1307,26 @@ export const iconMap: any = { ) ]) } + }, + 'app-close': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M7.96141 6.98572L12.4398 2.50738C12.5699 2.3772 12.781 2.3772 12.9112 2.50738L13.3826 2.97878C13.5127 3.10895 13.5127 3.32001 13.3826 3.45018L8.90422 7.92853L13.3826 12.4069C13.5127 12.537 13.5127 12.7481 13.3826 12.8783L12.9112 13.3497C12.781 13.4799 12.5699 13.4799 12.4398 13.3497L7.96141 8.87134L3.48307 13.3497C3.35289 13.4799 3.14184 13.4799 3.01166 13.3497L2.54026 12.8783C2.41008 12.7481 2.41008 12.537 2.54026 12.4069L7.0186 7.92853L2.54026 3.45018C2.41008 3.32001 2.41008 3.10895 2.54026 2.97878L3.01166 2.50738C3.14184 2.3772 3.35289 2.3772 3.48307 2.50738L7.96141 6.98572Z', + fill: 'currentColor' + }) + ] + ) + ]) + } } } diff --git a/ui/src/utils/status.ts b/ui/src/utils/status.ts new file mode 100644 index 00000000000..abfbd2e4ea8 --- /dev/null +++ b/ui/src/utils/status.ts @@ -0,0 +1,68 @@ +import { type Dict } from '@/api/type/common' +interface TaskTypeInterface { + // 向量化 + EMBEDDING: number + // 生成问题 + GENERATE_PROBLEM: number + // 同步 + SYNC: number +} +interface StateInterface { + // 等待 + PENDING: '0' + // 执行中 + STARTED: '1' + // 成功 + SUCCESS: '2' + // 失败 + FAILURE: '3' + // 取消任务 + REVOKE: '4' + // 取消成功 + REVOKED: '5' + IGNORED: 'n' +} +const TaskType: TaskTypeInterface = { + EMBEDDING: 1, + GENERATE_PROBLEM: 2, + SYNC: 3 +} +const State: StateInterface = { + // 等待 + PENDING: '0', + // 执行中 + STARTED: '1', + // 成功 + SUCCESS: '2', + // 失败 + FAILURE: '3', + // 取消任务 + REVOKE: '4', + // 取消成功 + REVOKED: '5', + IGNORED: 'n' +} +class Status { + task_status: Dict + constructor(status?: string) { + if (!status) { + status = '' + } + status = status.split('').reverse().join('') + this.task_status = {} + for (let key in TaskType) { + const value = TaskType[key as keyof TaskTypeInterface] + const index = value - 1 + this.task_status[value] = status[index] ? 
status[index] : 'n' + } + } + toString() { + const r = [] + for (let key in TaskType) { + const value = TaskType[key as keyof TaskTypeInterface] + r.push(this.task_status[value]) + } + return r.reverse().join('') + } +} +export { Status, State, TaskType, TaskTypeInterface, StateInterface } diff --git a/ui/src/views/document/component/Status.vue b/ui/src/views/document/component/Status.vue new file mode 100644 index 00000000000..8bbab6784aa --- /dev/null +++ b/ui/src/views/document/component/Status.vue @@ -0,0 +1,167 @@ + + + diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index cd301600894..86042a3fa6b 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -134,21 +134,7 @@ @@ -249,7 +235,7 @@
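
Note: the core of this patch is the per-task status string introduced in apps/dataset/models/data_set.py. The snippet below is an illustrative sketch only, not part of the patch; it mirrors the Status/TaskType/State definitions above so the encoding can be exercised standalone.

# Illustrative sketch — mirrors Status/TaskType/State from apps/dataset/models/data_set.py.
# Each TaskType owns one character of the status string, stored right-to-left:
# the last character belongs to EMBEDDING (value 1), the one before it to GENERATE_PROBLEM, and so on.
from enum import Enum

class TaskType(Enum):
    EMBEDDING = 1          # 向量化
    GENERATE_PROBLEM = 2   # 生成问题
    SYNC = 3               # 同步

class State(Enum):
    PENDING = '0'
    STARTED = '1'
    SUCCESS = '2'
    FAILURE = '3'
    REVOKE = '4'
    REVOKED = '5'
    IGNORED = 'n'

class Status:
    def __init__(self, status: str = None):
        self.task_status = {}
        status_list = list(status[::-1] if status is not None else '')
        for _type in TaskType:
            index = _type.value - 1
            self.task_status[_type] = State(status_list[index] if len(status_list) > index else 'n')

    def __str__(self):
        result = []
        for _type in sorted(TaskType, key=lambda item: item.value, reverse=True):
            result.insert(len(TaskType) - _type.value, self.task_status[_type].value)
        return ''.join(result)

    def __getitem__(self, item):
        return self.task_status[item]

    def __setitem__(self, key, value):
        self.task_status[key] = value

# Example round-trip: a new row decodes as all-IGNORED; marking embedding as started and
# problem generation as pending yields 'n01', which decodes back to the same states.
status = Status('')
status[TaskType.EMBEDDING] = State.STARTED
status[TaskType.GENERATE_PROBLEM] = State.PENDING
assert str(status) == 'n01'
assert Status(str(status))[TaskType.EMBEDDING] == State.STARTED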
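
Note: document-level tasks no longer load every paragraph at once; they page through paragraphs and re-check for revocation between batches. A minimal sketch of that contract follows — the page helper is copied from apps/common/util/page_utils.py, while the fake queryset and handler are stand-ins for illustration only.

from math import ceil

def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    # Process the queryset page_size rows at a time, stopping as soon as the interruption
    # check (e.g. the document's EMBEDDING state has been set to REVOKE) returns True.
    count = query_set.count()
    for i in range(0, ceil(count / page_size)):
        if is_the_task_interrupted():
            return
        offset = i * page_size
        handler(query_set[offset: offset + page_size])

class FakeQuerySet(list):
    # Stand-in for a Django QuerySet: only count() and slicing are needed here.
    def count(self):
        return len(self)

processed = []
page(FakeQuerySet(range(12)), 5, processed.extend)                 # never interrupted: all 12 items handled
assert len(processed) == 12
page(FakeQuerySet(range(12)), 5, processed.extend, lambda: True)   # revoked before the first batch: nothing added
assert len(processed) == 12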