From 24a9fd36ca5e026fcc79ca2097a661bdcb21fedc Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Mon, 17 Feb 2025 15:48:10 +0800
Subject: [PATCH] feat: Generate problem support for generating unfinished
 paragraphs

---
 .../serializers/document_serializers.py            | 26 +++++++++++------
 apps/dataset/task/generate.py                      | 16 +++++++++--
 .../xinference_model_provider/model/image.py       | 17 ++++++++++-
 .../xinference_model_provider/model/llm.py         | 17 ++++++++++-
 .../generate-related-dialog/index.vue              | 28 +++++++++++++++++--
 5 files changed, 88 insertions(+), 16 deletions(-)

diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 811c5562410..265903c33fc 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -243,13 +243,16 @@ def export(self, with_valid=True):
             self.is_valid(raise_exception=True)
         language = get_language()
         if self.data.get('type') == 'csv':
-            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'), "rb")
+            file = open(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'csv_template_{to_locale(language)}.csv'),
+                "rb")
             content = file.read()
             file.close()
             return HttpResponse(content, status=200, headers={'Content-Type': 'text/csv',
                                                               'Content-Disposition': 'attachment; filename="csv_template.csv"'})
         elif self.data.get('type') == 'excel':
-            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'excel_template_{to_locale(language)}.xlsx'), "rb")
+            file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+                                     f'excel_template_{to_locale(language)}.xlsx'), "rb")
             content = file.read()
             file.close()
             return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
@@ -261,7 +264,8 @@ def table_export(self, with_valid=True):
         language = get_language()
         if self.data.get('type') == 'csv':
             file = open(
-                os.path.join(PROJECT_DIR, "apps", "dataset", 'template', f'table_template_{to_locale(language)}.csv'),
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'template',
+                             f'table_template_{to_locale(language)}.csv'),
                 "rb")
             content = file.read()
             file.close()
@@ -1180,7 +1184,7 @@ def is_valid(self, *, raise_exception=False):
         if not QuerySet(Document).filter(id=document_id).exists():
             raise AppApiException(500, _('document id not exist'))

-    def generate_related(self, model_id, prompt, with_valid=True):
+    def generate_related(self, model_id, prompt, state_list=None, with_valid=True):
         if with_valid:
             self.is_valid(raise_exception=True)
         document_id = self.data.get('document_id')
@@ -1192,7 +1196,7 @@
                                          State.PENDING)
         ListenerManagement.get_aggregation_document_status(document_id)()
         try:
-            generate_related_by_document_id.delay(document_id, model_id, prompt)
+            generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
         except AlreadyQueued as e:
             raise AppApiException(500, _('The task is being executed, please do not send it again.'))

@@ -1205,17 +1209,23 @@ def batch_generate_related(self, instance: Dict, with_valid=True):
         document_id_list = instance.get("document_id_list")
         model_id = instance.get("model_id")
         prompt = instance.get("prompt")
+        state_list = instance.get('state_list')
         ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
                                          TaskType.GENERATE_PROBLEM,
                                          State.PENDING)
-        ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
-                                         TaskType.GENERATE_PROBLEM,
+        ListenerManagement.update_status(QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+                                    1),
+        ).filter(task_type_status__in=state_list, document_id__in=document_id_list)
+                                         .values('id'),
+                                         TaskType.GENERATE_PROBLEM,
                                          State.PENDING)
         ListenerManagement.get_aggregation_document_status_by_query_set(
             QuerySet(Document).filter(id__in=document_id_list))()
         try:
             for document_id in document_id_list:
-                generate_related_by_document_id.delay(document_id, model_id, prompt)
+                generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
         except AlreadyQueued as e:
             pass
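Note on the Paragraph query above: the status column is evidently a string in
which each TaskType owns one character, addressed from the right-hand end,
which is why the patch reverses the column and then takes the single character
at 1-based position TaskType.GENERATE_PROBLEM.value. A minimal pure-Python
sketch of that selection logic follows; the enum value and status strings are
illustrative assumptions, not values taken from the project.

    # Hypothetical stand-in for TaskType.GENERATE_PROBLEM.value (1-based
    # position, counted from the right of the status string).
    GENERATE_PROBLEM = 2

    def task_state(status: str, task_type: int) -> str:
        # Reverse('status') + Substr(reversed, task_type, 1) == status[-task_type]
        return status[::-1][task_type - 1]

    paragraphs = [
        {'id': 1, 'status': '10'},  # generate-problem state '1'
        {'id': 2, 'status': '30'},  # generate-problem state '3'
    ]
    state_list = ['0', '1']  # e.g. keep only still-unfinished paragraphs
    selected = [p['id'] for p in paragraphs
                if task_state(p['status'], GENERATE_PROBLEM) in state_list]
    print(selected)  # [1]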
diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py
index 5ffcd1bec65..bf9e53869a1 100644
--- a/apps/dataset/task/generate.py
+++ b/apps/dataset/task/generate.py
@@ -3,11 +3,12 @@

 from celery_once import QueueOnce
 from django.db.models import QuerySet
+from django.db.models.functions import Reverse, Substr
 from langchain_core.messages import HumanMessage

 from common.config.embedding_config import ModelManage
 from common.event import ListenerManagement
-from common.util.page_utils import page
+from common.util.page_utils import page, page_desc
 from dataset.models import Paragraph, Document, Status, TaskType, State
 from dataset.task.tools import save_problem
 from ops import celery_app
@@ -64,7 +65,11 @@ def is_the_task_interrupted():


 @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document')
-def generate_related_by_document_id(document_id, model_id, prompt):
+def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
+    if state_list is None:
+        state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
+                      State.REVOKE.value,
+                      State.REVOKED.value, State.IGNORED.value]
     try:
         is_the_task_interrupted = get_is_the_task_interrupted(document_id)
         if is_the_task_interrupted():
@@ -78,7 +83,12 @@
         generate_problem = get_generate_problem(llm_model, prompt,
                                                 ListenerManagement.get_aggregation_document_status(
                                                     document_id), is_the_task_interrupted)

-        page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
+        query_set = QuerySet(Paragraph).annotate(
+            reversed_status=Reverse('status'),
+            task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value,
+                                    1),
+        ).filter(task_type_status__in=state_list, document_id=document_id)
+        page_desc(query_set, 10, generate_problem, is_the_task_interrupted)
     except Exception as e:
         max_kb_error.error(f'Generating problems from document: {document_id} encountered error {str(e)}{traceback.format_exc()}')
         max_kb_error.error(_('Generate issue based on document: {document_id} error {error}{traceback}').format(
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
index f51a64ec41d..a195b86491b 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
+
+from langchain_core.messages import BaseMessage, get_buffer_string

+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@@ -18,3 +21,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             stream_usage=True,
             **optional_params,
         )
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
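Note on page vs. page_desc above: once a paragraph has been handled, its
generate-problem state changes and it drops out of the state_list filter, so
ascending pagination over the live query set would skip rows as the earlier
pages empty out. Walking the pages in descending order, which page_desc
presumably does, avoids that: consuming rows from the tail pages never shifts
the rows in the pages still to be visited. A small sketch of the idea, with an
illustrative in-memory page_desc rather than the project's real helper:

    def page_desc(rows, page_size, handler):
        # Visit pages from the last one back to the first.
        end = len(rows)
        while end > 0:
            start = max(0, end - page_size)
            handler(rows[start:end])
            end = start

    processed = []
    page_desc(list(range(25)), 10, processed.extend)
    print(processed)  # [15..24], then [5..14], then [0..4]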
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
index 42e098aa393..d76979bd3a3 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
@@ -1,8 +1,11 @@
 # coding=utf-8
-from typing import Dict
+from typing import Dict, List
 from urllib.parse import urlparse, ParseResult

+from langchain_core.messages import BaseMessage, get_buffer_string
+
+from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
@@ -33,3 +36,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             openai_api_key=model_credential.get('api_key'),
             **optional_params
         )
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
diff --git a/ui/src/components/generate-related-dialog/index.vue b/ui/src/components/generate-related-dialog/index.vue
index 8c3a3e1d370..4d485eae0bb 100644
--- a/ui/src/components/generate-related-dialog/index.vue
+++ b/ui/src/components/generate-related-dialog/index.vue
@@ -48,6 +48,16 @@
           type="textarea"
         />
       </el-form-item>
+      <el-form-item>
+        <el-radio-group v-model="form.state_list">
+          <el-radio value="error">{{
+            $t('views.document.form.selectVectorization.error')
+          }}</el-radio>
+          <el-radio value="all">{{
+            $t('views.document.form.selectVectorization.all')
+          }}</el-radio>
+        </el-radio-group>
+      </el-form-item>
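Both Xinference wrappers above share the same token-accounting pattern: trust
the provider-reported usage_metadata when it is present, and fall back to a
local tokenizer otherwise. A minimal self-contained sketch of that pattern;
SimpleTokenizer here is a hypothetical stand-in for the tokenizer returned by
TokenizerManage.get_tokenizer():

    from typing import List, Optional

    class SimpleTokenizer:
        def encode(self, text: str) -> List[str]:
            # Crude stand-in: one "token" per whitespace-separated word.
            return text.split()

    def num_tokens(text: str, usage_metadata: Optional[dict]) -> int:
        # Prefer the provider's own accounting; count locally only as a fallback.
        if not usage_metadata:
            return len(SimpleTokenizer().encode(text))
        return usage_metadata.get('input_tokens', 0)

    print(num_tokens("count these four tokens", None))  # 4, via local fallback
    print(num_tokens("ignored", {'input_tokens': 42}))  # 42, provider-reported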