Skip to content

Commit b0f443f

Browse files
authored
feat: 简易应用页面参数优化 (#1182)
1 parent 0b64e7a commit b0f443f

File tree

16 files changed

+306
-150
lines changed

16 files changed

+306
-150
lines changed

apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class InstanceSerializer(serializers.Serializer):
3737
"最大携带知识库段落长度"))
3838
# 模板
3939
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
40+
system = serializers.CharField(required=False, allow_null=True, allow_blank=True,
41+
error_messages=ErrMessage.char("系统提示词(角色)"))
4042
# 补齐问题
4143
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
4244
# 未查询到引用分段
@@ -59,6 +61,7 @@ def execute(self,
5961
prompt: str,
6062
padding_problem_text: str = None,
6163
no_references_setting=None,
64+
system=None,
6265
**kwargs) -> List[BaseMessage]:
6366
"""
6467
@@ -71,6 +74,7 @@ def execute(self,
7174
:param padding_problem_text 用户修改文本
7275
:param kwargs: 其他参数
7376
:param no_references_setting: 无引用分段设置
77+
:param system 系统提示词
7478
:return:
7579
"""
7680
pass

apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import List, Dict
1010

1111
from langchain.schema import BaseMessage, HumanMessage
12+
from langchain_core.messages import SystemMessage
1213

1314
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
1415
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
@@ -27,6 +28,7 @@ def execute(self, problem_text: str,
2728
prompt: str,
2829
padding_problem_text: str = None,
2930
no_references_setting=None,
31+
system=None,
3032
**kwargs) -> List[BaseMessage]:
3133
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
3234
'value')
@@ -35,6 +37,11 @@ def execute(self, problem_text: str,
3537
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
3638
for index in
3739
range(start_index if start_index > 0 else 0, len(history_chat_record))]
40+
if system is not None and len(system) > 0:
41+
return [SystemMessage(system), *flat_map(history_message),
42+
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
43+
no_references_setting)]
44+
3845
return [*flat_map(history_message),
3946
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
4047
no_references_setting)]

apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class InstanceSerializer(serializers.Serializer):
2929
error_messages=ErrMessage.list("历史对答"))
3030
# 大语言模型
3131
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
32+
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
33+
error_messages=ErrMessage.char("问题补全提示词"))
3234

3335
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
3436
return self.InstanceSerializer
@@ -47,5 +49,6 @@ def _run(self, manage: PipelineManage):
4749

4850
@abstractmethod
4951
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
52+
problem_optimization_prompt=None,
5053
**kwargs):
5154
pass

apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
class BaseResetProblemStep(IResetProblemStep):
2323
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
24+
problem_optimization_prompt=None,
2425
**kwargs) -> str:
2526
if chat_model is None:
2627
self.context['message_tokens'] = 0
@@ -30,15 +31,19 @@ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = Non
3031
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
3132
for index in
3233
range(start_index if start_index > 0 else 0, len(history_chat_record))]
34+
reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt
3335
message_list = [*flat_map(history_message),
34-
HumanMessage(content=prompt.format(**{'question': problem_text}))]
36+
HumanMessage(content=reset_prompt.replace('{question}', problem_text))]
3537
response = chat_model.invoke(message_list)
3638
padding_problem = problem_text
3739
if response.content.__contains__("<data>") and response.content.__contains__('</data>'):
3840
padding_problem_data = response.content[
3941
response.content.index('<data>') + 6:response.content.index('</data>')]
4042
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
4143
padding_problem = padding_problem_data
44+
elif len(response.content) > 0:
45+
padding_problem = response.content
46+
4247
try:
4348
request_token = chat_model.get_num_tokens_from_messages(message_list)
4449
response_token = chat_model.get_num_tokens(padding_problem)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 4.2.15 on 2024-09-13 18:57
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('application', '0013_application_tts_type'),
10+
]
11+
12+
operations = [
13+
migrations.AddField(
14+
model_name='application',
15+
name='problem_optimization_prompt',
16+
field=models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中', max_length=102400, null=True, verbose_name='问题优化提示词'),
17+
),
18+
]

apps/application/models/application.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_dataset_setting_dict():
3535

3636

3737
def get_model_setting_dict():
38-
return {'prompt': Application.get_default_model_prompt()}
38+
return {'prompt': Application.get_default_model_prompt(), 'no_references_prompt': '{question}'}
3939

4040

4141
class Application(AppModelMixin):
@@ -54,8 +54,13 @@ class Application(AppModelMixin):
5454
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
5555
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
5656
default=ApplicationTypeChoices.SIMPLE, max_length=256)
57-
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
58-
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
57+
problem_optimization_prompt = models.CharField(verbose_name="问题优化提示词", max_length=102400, blank=True,
58+
null=True,
59+
default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中")
60+
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False,
61+
blank=True, null=True)
62+
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False,
63+
blank=True, null=True)
5964
tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False)
6065
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)
6166
tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER")

apps/application/serializers/application_serializers.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,12 @@ class DatasetSettingSerializer(serializers.Serializer):
120120

121121

122122
class ModelSettingSerializer(serializers.Serializer):
123-
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
123+
prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
124+
error_messages=ErrMessage.char("提示词"))
125+
system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
126+
error_messages=ErrMessage.char("角色提示词"))
127+
no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True,
128+
error_messages=ErrMessage.char("无引用分段提示词"))
124129

125130

126131
class ApplicationWorkflowSerializer(serializers.Serializer):
@@ -174,7 +179,7 @@ class ApplicationSerializer(serializers.Serializer):
174179
error_messages=ErrMessage.char("应用描述"))
175180
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
176181
error_messages=ErrMessage.char("模型"))
177-
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
182+
# dialogue_number is the number of history rounds to carry; it is stored as-is on the
# Application model and the Edit serializer validates it as an integer, so it must be
# an IntegerField here too (a BooleanField would reject or coerce values such as 3).
# NOTE(review): ErrMessage.char is kept unchanged; switch to an integer-specific
# helper if ErrMessage provides one — confirm against common.util.field_message.
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.char("会话次数"))
178183
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
179184
error_messages=ErrMessage.char("开场白"))
180185
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
@@ -185,6 +190,8 @@ class ApplicationSerializer(serializers.Serializer):
185190
model_setting = ModelSettingSerializer(required=True)
186191
# 问题补全
187192
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
193+
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
194+
error_messages=ErrMessage.char("问题补全提示词"))
188195
# 应用类型
189196
type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"),
190197
validators=[
@@ -364,8 +371,8 @@ class Edit(serializers.Serializer):
364371
error_messages=ErrMessage.char("应用描述"))
365372
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
366373
error_messages=ErrMessage.char("模型"))
367-
multiple_rounds_dialogue = serializers.BooleanField(required=False,
368-
error_messages=ErrMessage.boolean("多轮会话"))
374+
dialogue_number = serializers.IntegerField(required=False,
375+
error_messages=ErrMessage.boolean("多轮会话"))
369376
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
370377
error_messages=ErrMessage.char("开场白"))
371378
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
@@ -430,13 +437,14 @@ def insert_simple(self, application: Dict):
430437
def to_application_model(user_id: str, application: Dict):
431438
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
432439
prologue=application.get('prologue'),
433-
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
440+
dialogue_number=application.get('dialogue_number', 0),
434441
user_id=user_id, model_id=application.get('model_id'),
435442
dataset_setting=application.get('dataset_setting'),
436443
model_setting=application.get('model_setting'),
437444
problem_optimization=application.get('problem_optimization'),
438445
type=ApplicationTypeChoices.SIMPLE,
439446
model_params_setting=application.get('model_params_setting', {}),
447+
problem_optimization_prompt=application.get('problem_optimization_prompt', None),
440448
work_flow={}
441449
)
442450

apps/application/serializers/chat_message_serializers.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ def __init__(self,
6060
self.chat_record_list: List[ChatRecord] = []
6161
self.work_flow_version = work_flow_version
6262

63+
@staticmethod
def get_no_references_setting(dataset_setting, model_setting):
    """
    Build the "no referenced paragraphs" setting passed to the chat pipeline.

    :param dataset_setting: dataset settings dict; may hold a 'no_references_setting'
                            dict with 'status' and 'value' keys
    :param model_setting:   model settings dict; may hold 'no_references_prompt'
    :return: a new dict with the effective setting; the dict stored inside
             ``dataset_setting`` is never modified
    """
    default_prompt = '{question}'
    stored_setting = dataset_setting.get('no_references_setting')
    if stored_setting is None:
        no_references_setting = {'status': 'ai_questioning', 'value': default_prompt}
    else:
        # Copy so that overriding 'value' below does not leak into the
        # application's stored dataset_setting.
        no_references_setting = dict(stored_setting)
    if no_references_setting.get('status') == 'ai_questioning':
        # The prompt may be missing, None or blank; all of these fall back to
        # the plain question template instead of raising on len(None).
        no_references_prompt = model_setting.get('no_references_prompt')
        no_references_setting['value'] = no_references_prompt if no_references_prompt else default_prompt
    return no_references_setting
73+
6374
def to_base_pipeline_manage_params(self):
6475
dataset_setting = self.application.dataset_setting
6576
model_setting = self.application.model_setting
@@ -80,20 +91,21 @@ def to_base_pipeline_manage_params(self):
8091
'history_chat_record': self.chat_record_list,
8192
'chat_id': self.chat_id,
8293
'dialogue_number': self.application.dialogue_number,
94+
'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
95+
self.application.problem_optimization_prompt) > 0 else '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中',
8396
'prompt': model_setting.get(
84-
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
97+
'prompt') if 'prompt' in model_setting and len(model_setting.get(
98+
'prompt')) > 0 else Application.get_default_model_prompt(),
99+
'system': model_setting.get(
100+
'system', None),
85101
'model_id': model_id,
86102
'problem_optimization': self.application.problem_optimization,
87103
'stream': True,
88104
'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
89105
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
90106
'search_mode': self.application.dataset_setting.get(
91107
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
92-
'no_references_setting': self.application.dataset_setting.get(
93-
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
94-
'status': 'ai_questioning',
95-
'value': '{question}',
96-
},
108+
'no_references_setting': self.get_no_references_setting(self.application.dataset_setting, model_setting),
97109
'user_id': self.application.user_id
98110
}
99111

0 commit comments

Comments
 (0)