From 1943ff6fce7a976ae382e46c9586f2751812ee50 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 3 Jul 2024 15:16:54 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=A4=A7=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=BF=94=E5=9B=9Ejson=E6=97=B6=EF=BC=8C=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=E5=87=BA=E9=94=99=20#656?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/impl/base_chat_step.py | 15 +++++++++------ .../impl/base_generate_human_message_step.py | 6 +++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index b09ce7e9755..f7dbe58350e 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -143,7 +143,8 @@ def reset_message_list(message_list: List[BaseMessage], answer_text): def get_stream_result(message_list: List[BaseMessage], chat_model: BaseChatModel = None, paragraph_list=None, - no_references_setting=None): + no_references_setting=None, + problem_text=None): if paragraph_list is None: paragraph_list = [] directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) @@ -153,7 +154,8 @@ def get_stream_result(message_list: List[BaseMessage], return iter(directly_return_chunk_list), False elif len(paragraph_list) == 0 and no_references_setting.get( 'status') == 'designated_answer': - return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False + return iter( + [AIMessageChunk(content=no_references_setting.get('value').replace('{question}', problem_text))]), False if chat_model is None: return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False else: @@ -170,7 +172,7 @@ def execute_stream(self, message_list: List[BaseMessage], client_id=None, client_type=None, no_references_setting=None): chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, - no_references_setting) + no_references_setting, problem_text) chat_record_id = uuid.uuid1() r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, @@ -185,7 +187,8 @@ def execute_stream(self, message_list: List[BaseMessage], def get_block_result(message_list: List[BaseMessage], chat_model: BaseChatModel = None, paragraph_list=None, - no_references_setting=None): + no_references_setting=None, + problem_text=None): if paragraph_list is None: paragraph_list = [] @@ -196,7 +199,7 @@ def get_block_result(message_list: List[BaseMessage], return directly_return_chunk_list[0], False elif len(paragraph_list) == 0 and no_references_setting.get( 'status') == 'designated_answer': - return AIMessage(no_references_setting.get('value')), False + return AIMessage(no_references_setting.get('value').replace('{question}', problem_text)), False if chat_model is None: return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False else: @@ -215,7 +218,7 @@ def execute_block(self, message_list: List[BaseMessage], # 调用模型 try: chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, - no_references_setting) + no_references_setting, problem_text) if is_ai_chat: request_token = chat_model.get_num_tokens_from_messages(message_list) response_token = chat_model.get_num_tokens(chat_result.content) diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py index 6664a286c54..8b769c77002 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -48,9 +48,9 @@ def to_human_message(prompt: str, if paragraph_list is None or len(paragraph_list) == 0: if no_references_setting.get('status') == 'ai_questioning': return HumanMessage( - content=no_references_setting.get('value').format(**{'question': problem})) + content=no_references_setting.get('value').replace('{question}', problem)) else: - return HumanMessage(content=prompt.format(**{'data': "", 'question': problem})) + return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem)) temp_data = "" data_list = [] for p in paragraph_list: @@ -63,4 +63,4 @@ def to_human_message(prompt: str, else: data_list.append(f"{content}") data = "\n".join(data_list) - return HumanMessage(content=prompt.format(**{'data': data, 'question': problem})) + return HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem))