Skip to content

Pr@main@hit handling method #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 25, 2024
Merged
11 changes: 9 additions & 2 deletions apps/application/chat_pipeline/I_base_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
class ParagraphPipelineModel:

def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str):
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str):
self.id = _id
self.document_id = document_id
self.dataset_id = dataset_id
Expand All @@ -30,6 +31,7 @@ def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, ti
self.similarity = similarity
self.dataset_name = dataset_name
self.document_name = document_name
self.hit_handling_method = hit_handling_method

def to_dict(self):
return {
Expand All @@ -53,6 +55,7 @@ def __init__(self):
self.comprehensive_score = None
self.document_name = None
self.dataset_name = None
self.hit_handling_method = None

def add_paragraph(self, paragraph):
if isinstance(paragraph, Paragraph):
Expand All @@ -76,6 +79,10 @@ def add_document_name(self, document_name):
self.document_name = document_name
return self

# Builder setter: stores the hit handling method that build() later passes to
# ParagraphPipelineModel. Returns the builder itself so calls can be chained.
def add_hit_handling_method(self, hit_handling_method):
self.hit_handling_method = hit_handling_method
return self

# Builder setter: stores the comprehensive score that build() later passes to
# ParagraphPipelineModel. Returns the builder itself so calls can be chained.
def add_comprehensive_score(self, comprehensive_score: float):
self.comprehensive_score = comprehensive_score
return self
Expand All @@ -91,7 +98,7 @@ def build(self):
self.paragraph.get('status'),
self.paragraph.get('is_active'),
self.comprehensive_score, self.similarity, self.dataset_name,
self.document_name)
self.document_name, self.hit_handling_method)


class IBaseChatPipelineStep:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,17 @@ def execute_stream(self, message_list: List[BaseMessage],
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True

chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_documen
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
return result

@staticmethod
def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel:
Expand All @@ -50,10 +51,21 @@ def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineM
.add_comprehensive_score(find_embedding.get('comprehensive_score'))
.add_dataset_name(paragraph.get('dataset_name'))
.add_document_name(paragraph.get('document_name'))
.add_hit_handling_method(paragraph.get('hit_handling_method'))
.build())

@staticmethod
def list_paragraph(paragraph_id_list: List, vector):
def get_similarity(paragraph, embedding_list: List):
    """Return the comprehensive score of the embedding row matching ``paragraph``.

    Rows are matched by comparing ``str(row['paragraph_id'])`` with
    ``str(paragraph['id'])``. When several rows match, the last one wins.

    :param paragraph: mapping with an ``id`` entry.
    :param embedding_list: iterable of mappings with ``paragraph_id`` and
        ``comprehensive_score`` entries.
    :return: the matched row's ``comprehensive_score``, or ``0`` when no row
        matches or the matched row carries no score.
    """
    # The paragraph id does not change while scanning, so stringify it once.
    paragraph_id = str(paragraph.get('id'))
    score = None
    for embedding in embedding_list:
        if str(embedding.get('paragraph_id')) == paragraph_id:
            # Keep overwriting so the last matching row decides the result.
            score = embedding.get('comprehensive_score')
    # This value is used as a sort key; returning None would make sorting
    # fail when compared against numeric scores, so fall back to 0.
    return 0 if score is None else score

@staticmethod
def list_paragraph(embedding_list: List, vector):
paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
Expand All @@ -67,6 +79,13 @@ def list_paragraph(paragraph_id_list: List, vector):
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
# 如果存在直接返回的则取直接返回段落
hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if
paragraph.get('hit_handling_method') == 'directly_return']
if len(hit_handling_method_paragraph) > 0:
# 找到评分最高的
return [sorted(hit_handling_method_paragraph,
key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]]
return paragraph_list

def get_details(self, manage, **kwargs):
Expand Down
8 changes: 5 additions & 3 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ def chat(self):
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
if re_chat:
paragraph_id_list = flat_map([row.paragraph_id_list for row in
filter(lambda chat_record: chat_record == message,
chat_info.chat_record_list)])
paragraph_id_list = flat_map(
[[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
chat_record in chat_info.chat_record_list if
chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
chat_record.details['search_step']])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
SELECT
paragraph.*,
dataset."name" AS "dataset_name",
"document"."name" AS "document_name"
"document"."name" AS "document_name",
"document"."hit_handling_method" AS "hit_handling_method"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
Expand Down
18 changes: 18 additions & 0 deletions apps/dataset/migrations/0003_document_hit_handling_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-24 15:36

from django.db import migrations, models


# Schema migration for the 'dataset' app: adds the `hit_handling_method` field
# to the `document` model.
class Migration(migrations.Migration):

# Applied after the 'dataset' app's '0002_image' migration.
dependencies = [
('dataset', '0002_image'),
]

# A single AddField operation: a CharField (max_length=20) whose choices are
# 'optimization' and 'directly_return', defaulting to 'optimization'.
operations = [
migrations.AddField(
model_name='document',
name='hit_handling_method',
field=models.CharField(choices=[('optimization', '模型优化'), ('directly_return', '直接返回')], default='optimization', max_length=20, verbose_name='命中处理方式'),
),
]
8 changes: 8 additions & 0 deletions apps/dataset/models/data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class Type(models.TextChoices):
web = 1, 'web站点类型'


# Allowed values for Document.hit_handling_method. Each member is declared as
# (stored value, display label).
class HitHandlingMethod(models.TextChoices):
# Label reads "model optimization".
optimization = 'optimization', '模型优化'
# Label reads "return directly".
directly_return = 'directly_return', '直接返回'


class DataSet(AppModelMixin):
"""
数据集表
Expand Down Expand Up @@ -58,6 +63,9 @@ class Document(AppModelMixin):

type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20,
choices=HitHandlingMethod.choices,
default=HitHandlingMethod.optimization)

meta = models.JSONField(verbose_name="元数据", default=dict)

Expand Down
21 changes: 19 additions & 2 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
"""
import logging
import os
import re
import traceback
import uuid
from functools import reduce
from typing import List, Dict

from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
Expand Down Expand Up @@ -42,6 +44,12 @@ class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
name = serializers.CharField(required=False, max_length=128, min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
# Optional field; only the exact values "optimization" and "directly_return" are
# accepted. The alternatives are grouped so that both anchors apply to each one:
# the ungrouped pattern "^optimization|directly_return$" means "starts with
# optimization" OR "ends with directly_return" and would accept values such as
# "optimizationXYZ". \Z (rather than $) also rejects a trailing newline.
hit_handling_method = serializers.CharField(required=False, validators=[
    validators.RegexValidator(regex=re.compile(r"^(?:optimization|directly_return)\Z"),
                              message="类型只支持optimization|directly_return",
                              code=500)
], error_messages=ErrMessage.char("命中处理方式"))

is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char(
"文档是否可用"))

Expand Down Expand Up @@ -116,12 +124,15 @@ class Query(ApiMixin, serializers.Serializer):
min_length=1,
error_messages=ErrMessage.char(
"文档名称"))
hit_handling_method = serializers.CharField(required=False, error_messages=ErrMessage.char("命中处理方式"))

def get_query_set(self):
    """Build the Document queryset described by this serializer's data.

    Always restricts to the requested ``dataset_id``. The ``name`` (matched
    case-insensitively as a substring) and ``hit_handling_method`` (matched
    exactly) filters are applied only when a non-None value was supplied.
    Results are ordered by ``create_time``, newest first.
    """
    data = self.data
    documents = QuerySet(model=Document).filter(dataset_id=data.get("dataset_id"))
    # (key in the request data, ORM lookup applied when the key has a value)
    optional_lookups = (
        ('name', 'name__icontains'),
        ('hit_handling_method', 'hit_handling_method'),
    )
    for field_name, lookup in optional_lookups:
        value = data.get(field_name)
        if value is not None:
            documents = documents.filter(**{lookup: value})
    return documents.order_by('-create_time')

Expand All @@ -143,7 +154,11 @@ def get_request_params_api():
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档名称')]
description='文档名称'),
openapi.Parameter(name='hit_handling_method', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='文档命中处理方式')]

@staticmethod
def get_response_body_api():
Expand Down Expand Up @@ -252,7 +267,7 @@ def edit(self, instance: Dict, with_valid=False):
_document = QuerySet(Document).get(id=self.data.get("document_id"))
if with_valid:
DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
update_keys = ['name', 'is_active', 'meta']
update_keys = ['name', 'is_active', 'hit_handling_method', 'meta']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
Expand Down Expand Up @@ -320,6 +335,8 @@ def get_request_body_api():
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
description="ai优化:optimization,直接返回:directly_return"),
'meta': openapi.Schema(type=openapi.TYPE_OBJECT, title="文档元数据",
description="文档元数据->web:{source_url:xxx,selector:'xxx'},base:{}"),
}
Expand Down
10 changes: 7 additions & 3 deletions ui/src/styles/app.scss
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ h4 {
padding-bottom: 0;
}

.float-right{
float:right;
.float-right {
float: right;
}

.flex {
Expand All @@ -217,6 +217,10 @@ h4 {
align-items: baseline;
}

.justify-center {
justify-content: center;
}

.text-left {
text-align: left;
}
Expand Down Expand Up @@ -565,4 +569,4 @@ h4 {
.title {
color: var(--app-text-color);
}
}
}
12 changes: 11 additions & 1 deletion ui/src/styles/element-plus.scss
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
background-color: var(--el-button-bg-color);
border-color: var(--el-button-border-color);
}
&.is-link:focus {
background: none;
border: none;
}
}
.el-button--large {
font-size: 16px;
Expand Down Expand Up @@ -137,10 +141,16 @@
color: var(--app-text-color);
font-weight: 400;
padding: 5px 11px;
&:not(.is-disabled):focus {
&:not(.is-disabled):focus,
&:not(.is-active):focus {
background-color: var(--app-text-color-light-1);
color: var(--app-text-color);
}
&.is-active,
&.is-active:hover {
color: var(--el-menu-active-color);
background: var(--el-color-primary-light-9);
}
}

.el-tag {
Expand Down
Loading