Skip to content

feat: 建议应用页面参数优化 #1182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class InstanceSerializer(serializers.Serializer):
"最大携带知识库段落长度"))
# 模板
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("系统提示词(角色)"))
# 补齐问题
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
# 未查询到引用分段
Expand All @@ -59,6 +61,7 @@ def execute(self,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
"""

Expand All @@ -71,6 +74,7 @@ def execute(self,
:param padding_problem_text 用户修改文本
:param kwargs: 其他参数
:param no_references_setting: 无引用分段设置
:param system 系统提示称
:return:
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Dict

from langchain.schema import BaseMessage, HumanMessage
from langchain_core.messages import SystemMessage

from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
Expand All @@ -27,6 +28,7 @@ def execute(self, problem_text: str,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
system=None,
**kwargs) -> List[BaseMessage]:
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
'value')
Expand All @@ -35,6 +37,11 @@ def execute(self, problem_text: str,
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
if system is not None and len(system) > 0:
return [SystemMessage(system), *flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]

return [*flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class InstanceSerializer(serializers.Serializer):
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char("问题补全提示词"))

def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
Expand All @@ -47,5 +49,6 @@ def _run(self, manage: PipelineManage):

@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
problem_optimization_prompt=None,
**kwargs):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
problem_optimization_prompt=None,
**kwargs) -> str:
if chat_model is None:
self.context['message_tokens'] = 0
Expand All @@ -30,15 +31,19 @@ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = Non
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
message_list = [*flat_map(history_message),
HumanMessage(content=prompt.format(**{'question': problem_text}))]
HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
response = chat_model.invoke(message_list)
padding_problem = problem_text
if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
padding_problem_data = response.content[
response.content.index('<data>') + 6:response.content.index('</data>')]
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
padding_problem = padding_problem_data
elif len(response.content) > 0:
padding_problem = response.content

try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(padding_problem)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.2.15 on 2024-09-13 18:57

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('application', '0013_application_tts_type'),
]

operations = [
migrations.AddField(
model_name='application',
name='problem_optimization_prompt',
field=models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中', max_length=102400, null=True, verbose_name='问题优化提示词'),
),
]
11 changes: 8 additions & 3 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_dataset_setting_dict():


def get_model_setting_dict():
return {'prompt': Application.get_default_model_prompt()}
return {'prompt': Application.get_default_model_prompt(), 'no_references_prompt': '{question}'}


class Application(AppModelMixin):
Expand All @@ -54,8 +54,13 @@ class Application(AppModelMixin):
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
default=ApplicationTypeChoices.SIMPLE, max_length=256)
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
problem_optimization_prompt = models.CharField(verbose_name="问题优化提示词", max_length=102400, blank=True,
null=True,
default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中")
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False,
blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False,
blank=True, null=True)
tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False)
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")
Expand Down
18 changes: 13 additions & 5 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ class DatasetSettingSerializer(serializers.Serializer):


class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
error_messages=ErrMessage.char("提示词"))
system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
error_messages=ErrMessage.char("角色提示词"))
no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("无引用分段提示词"))


class ApplicationWorkflowSerializer(serializers.Serializer):
Expand Down Expand Up @@ -174,7 +179,7 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
dialogue_number = serializers.BooleanField(required=True, error_messages=ErrMessage.char("会话次数"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
Expand All @@ -185,6 +190,8 @@ class ApplicationSerializer(serializers.Serializer):
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
error_messages=ErrMessage.char("问题补全提示词"))
# 应用类型
type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"),
validators=[
Expand Down Expand Up @@ -364,8 +371,8 @@ class Edit(serializers.Serializer):
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
dialogue_number = serializers.IntegerField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
error_messages=ErrMessage.char("开场白"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
Expand Down Expand Up @@ -430,13 +437,14 @@ def insert_simple(self, application: Dict):
def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
prologue=application.get('prologue'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
dialogue_number=application.get('dialogue_number', 0),
user_id=user_id, model_id=application.get('model_id'),
dataset_setting=application.get('dataset_setting'),
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization'),
type=ApplicationTypeChoices.SIMPLE,
model_params_setting=application.get('model_params_setting', {}),
problem_optimization_prompt=application.get('problem_optimization_prompt', None),
work_flow={}
)

Expand Down
24 changes: 18 additions & 6 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ def __init__(self,
self.chat_record_list: List[ChatRecord] = []
self.work_flow_version = work_flow_version

@staticmethod
def get_no_references_setting(dataset_setting, model_setting):
no_references_setting = dataset_setting.get(
'no_references_setting', {
'status': 'ai_questioning',
'value': '{question}'})
if no_references_setting.get('status') == 'ai_questioning':
no_references_prompt = model_setting.get('no_references_prompt', '{question}')
no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
return no_references_setting

def to_base_pipeline_manage_params(self):
dataset_setting = self.application.dataset_setting
model_setting = self.application.model_setting
Expand All @@ -80,20 +91,21 @@ def to_base_pipeline_manage_params(self):
'history_chat_record': self.chat_record_list,
'chat_id': self.chat_id,
'dialogue_number': self.application.dialogue_number,
'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
self.application.problem_optimization_prompt) > 0 else '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中',
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'prompt') if 'prompt' in model_setting and len(model_setting.get(
'prompt')) > 0 else Application.get_default_model_prompt(),
'system': model_setting.get(
'system', None),
'model_id': model_id,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}',
},
'no_references_setting': self.get_no_references_setting(self.application.dataset_setting, model_setting),
'user_id': self.application.user_id
}

Expand Down
Loading
Loading