diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
index 8b796a7b45d..91effa82c0e 100644
--- a/apps/application/chat_pipeline/I_base_chat_pipeline.py
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -18,7 +18,8 @@ class ParagraphPipelineModel:
 
     def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
-                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str):
+                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
+                 hit_handling_method: str):
         self.id = _id
         self.document_id = document_id
         self.dataset_id = dataset_id
@@ -30,6 +31,7 @@ def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, ti
         self.similarity = similarity
         self.dataset_name = dataset_name
         self.document_name = document_name
+        self.hit_handling_method = hit_handling_method
 
     def to_dict(self):
         return {
@@ -53,6 +55,7 @@ def __init__(self):
         self.comprehensive_score = None
         self.document_name = None
         self.dataset_name = None
+        self.hit_handling_method = None
 
     def add_paragraph(self, paragraph):
         if isinstance(paragraph, Paragraph):
@@ -76,6 +79,10 @@ def add_document_name(self, document_name):
         self.document_name = document_name
         return self
 
+    def add_hit_handling_method(self, hit_handling_method):
+        self.hit_handling_method = hit_handling_method
+        return self
+
     def add_comprehensive_score(self, comprehensive_score: float):
         self.comprehensive_score = comprehensive_score
         return self
@@ -91,7 +98,7 @@ def build(self):
                                       self.paragraph.get('status'), self.paragraph.get('is_active'),
                                       self.comprehensive_score, self.similarity, self.dataset_name,
-                                      self.document_name)
+                                      self.document_name, self.hit_handling_method)
 
 
 class IBaseChatPipelineStep:
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 85485706dda..f5c50bc541d 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -146,8 +146,17 @@ def execute_stream(self, message_list: List[BaseMessage],
                 'status') == 'designated_answer':
             chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
         else:
-            chat_result = chat_model.stream(message_list)
-            is_ai_chat = True
+            if paragraph_list is not None and len(paragraph_list) > 0:
+                directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
+                                              for paragraph in paragraph_list if
+                                              paragraph.hit_handling_method == 'directly_return']
+                if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
+                    chat_result = iter(directly_return_chunk_list)
+                else:
+                    chat_result = chat_model.stream(message_list)
+            else:
+                chat_result = chat_model.stream(message_list)
+            is_ai_chat = True
         chat_record_id = uuid.uuid1()
         r = StreamingHttpResponse(
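Note on the chat-step change above: when any retrieved paragraph is marked `directly_return`, `execute_stream` now streams the paragraph text itself instead of calling the model. Below is a minimal, self-contained sketch of that branching logic; `Paragraph` is a stand-in dataclass (not the project's `ParagraphPipelineModel`), and plain strings stand in for `AIMessageChunk`:

```python
from dataclasses import dataclass
from typing import Callable, Iterator, List


@dataclass
class Paragraph:
    """Stand-in for the pipeline's paragraph model."""
    title: str
    content: str
    hit_handling_method: str  # 'optimization' or 'directly_return'


def build_chat_result(paragraph_list: List[Paragraph],
                      stream_from_model: Callable[[], Iterator[str]]) -> Iterator[str]:
    """If any hit paragraph is marked 'directly_return', stream its text verbatim;
    otherwise fall back to the model stream (mirrors the new else-branch above)."""
    if paragraph_list:
        directly_return = [p.title + "\n" + p.content
                           for p in paragraph_list
                           if p.hit_handling_method == 'directly_return']
        if directly_return:
            return iter(directly_return)
    return stream_from_model()


# One paragraph is marked directly_return, so the model is never consulted.
chunks = build_chat_result(
    [Paragraph("FAQ", "Answer text", "directly_return")],
    stream_from_model=lambda: iter(["model answer"]))
print(list(chunks))  # ['FAQ\nAnswer text']
```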
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
index dcd375ce4ec..bfc1118e848 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -35,8 +35,9 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_documen
                                       exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
         if embedding_list is None:
             return []
-        paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
-        return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+        paragraph_list = self.list_paragraph(embedding_list, vector)
+        result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
+        return result
 
     @staticmethod
     def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
@@ -50,10 +51,21 @@ def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineM
                 .add_comprehensive_score(find_embedding.get('comprehensive_score'))
                 .add_dataset_name(paragraph.get('dataset_name'))
                 .add_document_name(paragraph.get('document_name'))
+                .add_hit_handling_method(paragraph.get('hit_handling_method'))
                 .build())
 
     @staticmethod
-    def list_paragraph(paragraph_id_list: List, vector):
+    def get_similarity(paragraph, embedding_list: List):
+        filter_embedding_list = [embedding for embedding in embedding_list if
+                                 str(embedding.get('paragraph_id')) == str(paragraph.get('id'))]
+        if filter_embedding_list is not None and len(filter_embedding_list) > 0:
+            find_embedding = filter_embedding_list[-1]
+            return find_embedding.get('comprehensive_score')
+        return 0
+
+    @staticmethod
+    def list_paragraph(embedding_list: List, vector):
+        paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
         if paragraph_id_list is None or len(paragraph_id_list) == 0:
             return []
         paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
@@ -67,6 +79,13 @@ def list_paragraph(paragraph_id_list: List, vector):
         for paragraph_id in paragraph_id_list:
             if not exist_paragraph_list.__contains__(paragraph_id):
                 vector.delete_by_paragraph_id(paragraph_id)
+        # 如果存在直接返回的则取直接返回段落
+        hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
+                                         paragraph.get('hit_handling_method') == 'directly_return']
+        if len(hit_handling_method_paragraph) > 0:
+            # 找到评分最高的
+            return [sorted(hit_handling_method_paragraph,
+                           key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
         return paragraph_list
 
     def get_details(self, manage, **kwargs):
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index ddfea22af61..47b905a77b9 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -196,9 +196,11 @@ def chat(self):
         exclude_paragraph_id_list = []
         # 相同问题是否需要排除已经查询到的段落
         if re_chat:
-            paragraph_id_list = flat_map([row.paragraph_id_list for row in
-                                          filter(lambda chat_record: chat_record == message,
-                                                 chat_info.chat_record_list)])
+            paragraph_id_list = flat_map(
+                [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
+                 chat_record in chat_info.chat_record_list if
+                 chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
+                 chat_record.details['search_step']])
             exclude_paragraph_id_list = list(set(paragraph_id_list))
         # 构建运行参数
         params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
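The search-step change in base_search_dataset_step.py above means that when any retrieved paragraph is flagged `directly_return`, `list_paragraph` returns only the single such paragraph with the highest comprehensive score. A small sketch of that selection, with plain dicts standing in for the rows returned by `native_search` and the vector query (the real code uses `sorted(...)[-1]`; `max` makes the same choice):

```python
from typing import Dict, List


def get_similarity(paragraph: Dict, embedding_list: List[Dict]) -> float:
    """Comprehensive score recorded for this paragraph id, or 0 if absent."""
    matches = [e for e in embedding_list
               if str(e.get('paragraph_id')) == str(paragraph.get('id'))]
    return matches[-1].get('comprehensive_score') if matches else 0


def pick_paragraphs(paragraph_list: List[Dict], embedding_list: List[Dict]) -> List[Dict]:
    """Keep only the best-scoring directly_return paragraph if one exists,
    otherwise return the full hit list."""
    directly_return = [p for p in paragraph_list
                       if p.get('hit_handling_method') == 'directly_return']
    if directly_return:
        return [max(directly_return, key=lambda p: get_similarity(p, embedding_list))]
    return paragraph_list


paragraphs = [
    {'id': 1, 'hit_handling_method': 'optimization'},
    {'id': 2, 'hit_handling_method': 'directly_return'},
    {'id': 3, 'hit_handling_method': 'directly_return'},
]
embeddings = [{'paragraph_id': 2, 'comprehensive_score': 0.7},
              {'paragraph_id': 3, 'comprehensive_score': 0.9}]
print(pick_paragraphs(paragraphs, embeddings))  # paragraph 3 wins with score 0.9
```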
--- a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
+++ b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
@@ -1,7 +1,8 @@
 SELECT
     paragraph.*,
     dataset."name" AS "dataset_name",
-    "document"."name" AS "document_name"
+    "document"."name" AS "document_name",
+    "document"."hit_handling_method" AS "hit_handling_method"
 FROM
     paragraph paragraph
     LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
diff --git a/apps/dataset/migrations/0003_document_hit_handling_method.py b/apps/dataset/migrations/0003_document_hit_handling_method.py
new file mode 100644
index 00000000000..e1746d61156
--- /dev/null
+++ b/apps/dataset/migrations/0003_document_hit_handling_method.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.1.13 on 2024-04-24 15:36
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('dataset', '0002_image'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='document',
+            name='hit_handling_method',
+            field=models.CharField(choices=[('optimization', '模型优化'), ('directly_return', '直接返回')], default='optimization', max_length=20, verbose_name='命中处理方式'),
+        ),
+    ]
diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py
index 9ee76ffe220..ab1cfa6f166 100644
--- a/apps/dataset/models/data_set.py
+++ b/apps/dataset/models/data_set.py
@@ -27,6 +27,11 @@ class Type(models.TextChoices):
     web = 1, 'web站点类型'
 
 
+class HitHandlingMethod(models.TextChoices):
+    optimization = 'optimization', '模型优化'
+    directly_return = 'directly_return', '直接返回'
+
+
 class DataSet(AppModelMixin):
     """
     数据集表
@@ -58,6 +63,9 @@ class Document(AppModelMixin):
     type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
                             default=Type.base)
+    hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
+                                           choices=HitHandlingMethod.choices,
+                                           default=HitHandlingMethod.optimization)
     meta = models.JSONField(verbose_name="元数据", default=dict)
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index c3b41e802b5..9fead4a3c51 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -8,11 +8,13 @@
 """
 import logging
 import os
+import re
 import traceback
 import uuid
 from functools import reduce
 from typing import List, Dict
 
+from django.core import validators
 from django.db import transaction
 from django.db.models import QuerySet
 from drf_yasg import openapi
@@ -42,6 +44,12 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
     name = serializers.CharField(required=False, max_length=128, min_length=1,
                                  error_messages=ErrMessage.char(
                                      "文档名称"))
+    hit_handling_method = serializers.CharField(required=False, validators=[
+        validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
+                                  message="类型只支持optimization|directly_return",
+                                  code=500)
+    ], error_messages=ErrMessage.char("命中处理方式"))
+
     is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char(
         "文档是否可用"))
@@ -116,12 +124,15 @@ class Query(ApiMixin, serializers.Serializer):
                                          min_length=1,
                                          error_messages=ErrMessage.char(
                                              "文档名称"))
+        hit_handling_method = serializers.CharField(required=False, error_messages=ErrMessage.char("命中处理方式"))
 
         def get_query_set(self):
             query_set = QuerySet(model=Document)
             query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
             if 'name' in self.data and self.data.get('name') is not None:
                 query_set = query_set.filter(**{'name__icontains': self.data.get('name')})
+            if 'hit_handling_method' in self.data and self.data.get('hit_handling_method') is not None:
+                query_set = query_set.filter(**{'hit_handling_method': self.data.get('hit_handling_method')})
             query_set = query_set.order_by('-create_time')
             return query_set
@@ -143,7 +154,11 @@ def get_request_params_api():
                               in_=openapi.IN_QUERY,
                               type=openapi.TYPE_STRING,
                               required=False,
-                              description='文档名称')]
+                              description='文档名称'),
+            openapi.Parameter(name='hit_handling_method', in_=openapi.IN_QUERY,
+                              type=openapi.TYPE_STRING,
+                              required=False,
+                              description='文档命中处理方式')]
 
         @staticmethod
         def get_response_body_api():
@@ -252,7 +267,7 @@ def edit(self, instance: Dict, with_valid=False):
             _document = QuerySet(Document).get(id=self.data.get("document_id"))
             if with_valid:
                 DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
-            update_keys = ['name', 'is_active', 'meta']
+            update_keys = ['name', 'is_active', 'hit_handling_method', 'meta']
             for update_key in update_keys:
                 if update_key in instance and instance.get(update_key) is not None:
                     _document.__setattr__(update_key, instance.get(update_key))
@@ -320,6 +335,8 @@ def get_request_body_api():
                 properties={
                     'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                     'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
+                    'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
+                                                          description="ai优化:optimization,直接返回:directly_return"),
                     'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="文档元数据",
                                            description="文档元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
                 }
diff --git a/ui/src/styles/app.scss b/ui/src/styles/app.scss
index bf03e7fde6b..6ee91caafab 100644
--- a/ui/src/styles/app.scss
+++ b/ui/src/styles/app.scss
@@ -189,8 +189,8 @@ h4 {
   padding-bottom: 0;
 }
 
-.float-right{
-  float:right;
+.float-right {
+  float: right;
 }
 
 .flex {
@@ -217,6 +217,10 @@ h4 {
   align-items: baseline;
 }
 
+.justify-center {
+  justify-content: center;
+}
+
 .text-left {
   text-align: left;
 }
@@ -565,4 +569,4 @@ h4 {
   .title {
     color: var(--app-text-color);
   }
-}
\ No newline at end of file
+}
diff --git a/ui/src/styles/element-plus.scss b/ui/src/styles/element-plus.scss
index b85ad4f7729..8aa2ef73473 100644
--- a/ui/src/styles/element-plus.scss
+++ b/ui/src/styles/element-plus.scss
@@ -24,6 +24,10 @@
     background-color: var(--el-button-bg-color);
     border-color: var(--el-button-border-color);
   }
+  &.is-link:focus {
+    background: none;
+    border: none;
+  }
 }
 .el-button--large {
   font-size: 16px;
@@ -137,10 +141,16 @@
   color: var(--app-text-color);
   font-weight: 400;
   padding: 5px 11px;
-  &:not(.is-disabled):focus {
+  &:not(.is-disabled):focus,
+  &:not(.is-active):focus {
     background-color: var(--app-text-color-light-1);
     color: var(--app-text-color);
   }
+  &.is-active,
+  &.is-active:hover {
+    color: var(--el-menu-active-color);
+    background: var(--el-color-primary-light-9);
+  }
 }
 
 .el-tag {
diff --git a/ui/src/views/document/component/ImportDocumentDialog.vue b/ui/src/views/document/component/ImportDocumentDialog.vue
index 5a921202e98..f0ddc29eb70 100644
--- a/ui/src/views/document/component/ImportDocumentDialog.vue
+++ b/ui/src/views/document/component/ImportDocumentDialog.vue
@@ -21,15 +21,34 @@
       type="textarea"
     />
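On the API side, `DocumentEditInstanceSerializer` above accepts `hit_handling_method` guarded by a regex validator, and `edit()` copies any supplied whitelisted key onto the model. A standalone sketch of that validate-and-copy step, with a plain object standing in for the Django `Document` model (Django's `RegexValidator` uses `re.search` semantics, which the sketch reproduces):

```python
import re

HIT_HANDLING_PATTERN = re.compile("^optimization|directly_return$")


def validate_hit_handling_method(value: str) -> str:
    """Reject values that match neither supported handling method."""
    if HIT_HANDLING_PATTERN.search(value) is None:
        raise ValueError("hit_handling_method must be optimization or directly_return")
    return value


class FakeDocument:
    """Stand-in for the Document model, for illustration only."""
    def __init__(self):
        self.name = 'faq.md'
        self.is_active = True
        self.hit_handling_method = 'optimization'
        self.meta = {}


def edit(document, instance: dict):
    """Mirror of the serializer's edit(): copy only whitelisted, non-None keys."""
    update_keys = ['name', 'is_active', 'hit_handling_method', 'meta']
    for key in update_keys:
        if key in instance and instance.get(key) is not None:
            setattr(document, key, instance.get(key))
    return document


doc = edit(FakeDocument(),
           {'hit_handling_method': validate_hit_handling_method('directly_return')})
print(doc.hit_handling_method)  # directly_return
```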