Support Reasoning Content #2158

Merged
merged 13 commits on Feb 8, 2025
11 changes: 8 additions & 3 deletions apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -65,16 +65,21 @@ class InstanceSerializer(serializers.Serializer):
post_response_handler = InstanceField(model_type=PostResponseHandler,
error_messages=ErrMessage.base(_("Post-processor")))
# Completed question
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base(_("Completion Question")))
padding_problem_text = serializers.CharField(required=False,
error_messages=ErrMessage.base(_("Completion Question")))
# Whether to output as a stream
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
# No reference segments were found
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
no_references_setting = NoReferencesSetting(required=True,
error_messages=ErrMessage.base(_("No reference segment settings")))

user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))

model_setting = serializers.DictField(required=True, allow_null=True,
error_messages=ErrMessage.dict(_("Model settings")))

model_params_setting = serializers.DictField(required=False, allow_null=True,
error_messages=ErrMessage.dict(_("Model parameter settings")))

@@ -101,5 +106,5 @@ def execute(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, model_params_setting=None, **kwargs):
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass
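Note: model_setting is threaded through as a plain dict. Judging from the keys read later in this diff (reasoning_content_enable, reasoning_content_start, reasoning_content_end), a typical payload presumably looks like the sketch below; the values are illustrative assumptions, not taken from the PR.

# Hypothetical model_setting payload; only these three keys are read in this diff.
model_setting = {
    'reasoning_content_enable': True,      # emit reasoning alongside the answer
    'reasoning_content_start': '<think>',  # opening tag, the default used downstream
    'reasoning_content_end': '</think>',   # closing tag, the default used downstream
}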
@@ -24,6 +24,7 @@
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from setting.models_provider.tools import get_model_instance_by_model_user_id
@@ -63,17 +64,37 @@ def event_content(response,
problem_text: str,
padding_problem_text: str = None,
client_id=None, client_type=None,
is_ai_chat: bool = None):
is_ai_chat: bool = None,
model_setting=None):
if model_setting is None:
model_setting = {}
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
all_text = ''
reasoning_content = ''
try:
for chunk in response:
all_text += chunk.content
reasoning_chunk = reasoning.get_reasoning_content(chunk)
content_chunk = reasoning_chunk.get('content')
if 'reasoning_content' in chunk.additional_kwargs:
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
else:
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
all_text += content_chunk
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
reasoning_content += reasoning_content_chunk
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], chunk.content,
[], content_chunk,
False,
0, 0, {'node_is_end': False,
'view_type': 'many_view',
'node_type': 'ai-chat-node'})
'node_type': 'ai-chat-node',
'real_node_id': 'ai-chat-node',
'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
# Get token usage
if is_ai_chat:
try:
@@ -87,7 +108,8 @@
response_token = 0
write_context(step, manage, request_token, response_token, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id)
all_text, manage, step, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else '')
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True,
request_token, response_token,
@@ -122,17 +144,20 @@ def execute(self, message_list: List[BaseMessage],
client_id=None, client_type=None,
no_references_setting=None,
model_params_setting=None,
model_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
**model_params_setting) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting)
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting)
manage, padding_problem_text, client_id, client_type, no_references_setting,
model_setting)

def get_details(self, manage, **kwargs):
return {
@@ -187,14 +212,15 @@ def execute_stream(self, message_list: List[BaseMessage],
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None,
no_references_setting=None):
no_references_setting=None,
model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type, is_ai_chat),
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
content_type='text/event-stream;charset=utf-8')

r['Cache-Control'] = 'no-cache'
@@ -230,7 +256,13 @@ def execute_block(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None):
client_id=None, client_type=None, no_references_setting=None,
model_setting=None):
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
reasoning = Reasoning(reasoning_content_start,
reasoning_content_end)
chat_record_id = uuid.uuid1()
# Call the model
try:
@@ -243,14 +275,23 @@
request_token = 0
response_token = 0
write_context(self, manage, request_token, response_token, chat_result.content)
reasoning.get_reasoning_content(chat_result)
reasoning_result = reasoning.get_reasoning_content(chat_result)
content = reasoning_result.get('content')
if 'reasoning_content' in chat_result.response_metadata:
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
else:
reasoning_content = reasoning_result.get('reasoning_content')
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
chat_result.content, manage, self, padding_problem_text, client_id)
chat_result.content, manage, self, padding_problem_text, client_id,
reasoning_content=reasoning_content if reasoning_content_enable else '')
add_access_num(client_id, client_type, manage.context.get('application_id'))
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
chat_result.content, True,
request_token, response_token)
content, True,
request_token, response_token,
{'reasoning_content': reasoning_content})
except Exception as e:
all_text = '异常' + str(e)
all_text = 'Exception:' + str(e)
write_context(self, manage, 0, 0, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, self, padding_problem_text, client_id)
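Note: application.flow.tools.Reasoning is imported above but is not part of this diff. A minimal sketch of its apparent contract, inferred from the call sites (the constructor takes the start/end tags; get_reasoning_content returns a dict with 'content' and 'reasoning_content'), follows. This is an assumption, not the shipped implementation; in particular, the real class must also cope with tags split across streamed chunks, which this non-streaming sketch ignores.

# Assumed contract of application.flow.tools.Reasoning, simplified to the
# non-streaming case: split message.content into the visible answer and the
# reasoning text enclosed in the configured start/end tags.
class Reasoning:
    def __init__(self, reasoning_content_start='<think>',
                 reasoning_content_end='</think>'):
        self.start_tag = reasoning_content_start
        self.end_tag = reasoning_content_end

    def get_reasoning_content(self, message):
        text = message.content
        start = text.find(self.start_tag)
        end = text.find(self.end_tag, start + len(self.start_tag))
        if start == -1 or end == -1:
            # No tagged reasoning: the whole message is visible content.
            return {'content': text, 'reasoning_content': None}
        reasoning = text[start + len(self.start_tag):end]
        content = text[:start] + text[end + len(self.end_tag):]
        return {'content': content, 'reasoning_content': reasoning}

Both call sites prefer a provider-reported reasoning_content field (chunk.additional_kwargs in the stream path, chat_result.response_metadata in the block path) and fall back to tag parsing only when it is absent. Also worth noting: unlike event_content, execute_block does not guard against model_setting being None before calling .get on it.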
10 changes: 8 additions & 2 deletions apps/application/flow/common.py
@@ -9,16 +9,22 @@


class Answer:
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
reasoning_content):
self.view_type = view_type
self.content = content
self.reasoning_content = reasoning_content
self.runtime_node_id = runtime_node_id
self.chat_record_id = chat_record_id
self.child_node = child_node
self.real_node_id = real_node_id

def to_dict(self):
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
'chat_record_id': self.chat_record_id,
'child_node': self.child_node,
'reasoning_content': self.reasoning_content,
'real_node_id': self.real_node_id}


class NodeChunk:
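The widened Answer constructor now carries the reasoning text and the id of the node that actually produced the answer. A quick usage illustration (all values hypothetical):

# Hypothetical construction of the extended Answer; values are illustrative only.
answer = Answer(content='Paris.', view_type='many_view',
                runtime_node_id='node-1', chat_record_id='rec-42',
                child_node={}, real_node_id='node-1',
                reasoning_content='The user asks for the capital of France.')
print(answer.to_dict()['reasoning_content'])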
8 changes: 6 additions & 2 deletions apps/application/flow/i_step_node.py
@@ -62,7 +62,9 @@ def handler(self, chat_id,
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = workflow.get_answer_text_list()
answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
if workflow.chat_record is not None:
chat_record = workflow.chat_record
chat_record.answer_text = answer_text
@@ -157,8 +159,10 @@ def save_context(self, details, workflow_manage):
def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]

def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
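The nested join in handler() above implies that workflow.get_answer_text_list() now returns one list of answer dicts per answer group rather than a flat list. A small self-contained example of the assumed shape:

# Assumed shape after this change: a list of groups, each a list of dicts
# carrying a 'content' key; the double join flattens groups with blank lines.
answer_text_list = [
    [{'content': 'Intermediate node output'}, {'content': 'Follow-up output'}],
    [{'content': 'Final answer'}],
]
answer_text = '\n\n'.join(
    '\n\n'.join([a.get('content') for a in answer]) for answer in answer_text_list)
print(answer_text)  # three paragraphs separated by blank lines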
@@ -28,8 +28,9 @@ class ChatNodeSerializer(serializers.Serializer):
error_messages=ErrMessage.boolean(_('Whether to return content')))

model_params_setting = serializers.DictField(required=False,
error_messages=ErrMessage.integer(_("Model parameter settings")))

error_messages=ErrMessage.dict(_("Model parameter settings")))
model_setting = serializers.DictField(required=False,
error_messages=ErrMessage.dict(_('Model settings')))
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char(_("Context Type")))

@@ -47,5 +48,6 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record,
chat_record_id,
model_params_setting=None,
dialogue_type=None,
model_setting=None,
**kwargs) -> NodeResult:
pass