diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 9c2e1c54429..e6b87f46cf8 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -10,6 +10,7 @@ from abc import abstractmethod from typing import Type, Dict, List +from django.core import cache from django.db.models import QuerySet from rest_framework import serializers @@ -18,7 +19,6 @@ from common.constants.authentication_type import AuthenticationType from common.field.common import InstanceField from common.util.field_message import ErrMessage -from django.core import cache chat_cache = cache.caches['chat_cache'] @@ -27,9 +27,13 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if step_variable is not None: for key in step_variable: node.context[key] = step_variable[key] + if workflow.is_result() and 'answer' in step_variable: + yield step_variable['answer'] + workflow.answer += step_variable['answer'] if global_variable is not None: for key in global_variable: workflow.context[key] = global_variable[key] + node.context['run_time'] = time.time() - node.context['start_time'] class WorkFlowPostHandler: @@ -70,18 +74,14 @@ def handler(self, chat_id, class NodeResult: - def __init__(self, node_variable: Dict, workflow_variable: Dict, _to_response=None, _write_context=write_context): + def __init__(self, node_variable: Dict, workflow_variable: Dict, + _write_context=write_context): self._write_context = _write_context self.node_variable = node_variable self.workflow_variable = workflow_variable - self._to_response = _to_response def write_context(self, node, workflow): - self._write_context(self.node_variable, self.workflow_variable, node, workflow) - - def to_response(self, chat_id, chat_record_id, node, workflow, post_handler: WorkFlowPostHandler): - return self._to_response(chat_id, chat_record_id, self.node_variable, self.workflow_variable, node, workflow, - post_handler) + return self._write_context(self.node_variable, self.workflow_variable, node, workflow) def is_assertion_result(self): return 'branch_id' in self.node_variable diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 3b941922f1b..78cbc462c3f 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -22,6 +22,8 @@ class ChatNodeSerializer(serializers.Serializer): # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + class IChatNode(INode): type = 'ai-chat-node' diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index d8d087b0489..5fb38c12278 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -13,12 +13,25 @@ from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage -from application.flow import tools from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from setting.models_provider.tools import get_model_instance_by_model_user_id +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(): + workflow.answer += answer + + def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ 写入上下文数据 (流式) @@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer = '' for chunk in response: answer += chunk.content - chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - node.context['run_time'] = time.time() - node.context['start_time'] + yield answer + _write_context(node_variable, workflow_variable, node, workflow, answer) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param workflow: 工作流管理器 """ response = node_variable.get('result') - chat_model = node_variable.get('chat_model') answer = response.content - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - - -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - chat_model = node_variable.get('chat_model') - - if status == 200: - answer_tokens = chat_model.get_num_tokens(answer) - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - else: - answer_tokens = 0 - message_tokens = 0 - node.err_message = answer - node.status = status - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['run_time'] = time.time() - node.context['start_time'] - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + _write_context(node_variable, workflow_variable, node, workflow, answer) class BaseChatNode(IChatNode): @@ -132,13 +75,12 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context_stream, - _to_response=to_stream_response) + _write_context=write_context_stream) else: r = chat_model.invoke(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context, _to_response=to_response) + _write_context=write_context) @staticmethod def get_history_message(history_chat_record, dialogue_number): diff --git a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py index 1d5256ac568..3c0f3587547 100644 --- a/apps/application/flow/step_node/direct_reply_node/i_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/i_reply_node.py @@ -20,6 +20,7 @@ class ReplyNodeParamsSerializer(serializers.Serializer): fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段")) content = serializers.CharField(required=False, allow_blank=True, allow_null=True, error_messages=ErrMessage.char("直接回答内容")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 717dce161e7..de79279d932 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -6,69 +6,19 @@ @date:2024/6/11 17:25 @desc: """ -from typing import List, Dict +from typing import List -from langchain_core.messages import AIMessage, AIMessageChunk - -from application.flow import tools -from application.flow.i_step_node import NodeResult, INode +from application.flow.i_step_node import NodeResult from application.flow.step_node.direct_reply_node.i_reply_node import IReplyNode -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - node.context['answer'] = answer - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - class BaseReplyNode(IReplyNode): def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) else: result = self.generate_reply_content(content) - if stream: - return NodeResult({'result': iter([AIMessageChunk(content=result)]), 'answer': result}, {}, - _to_response=to_stream_response) - else: - return NodeResult({'result': AIMessage(content=result), 'answer': result}, {}, _to_response=to_response) + return NodeResult({'answer': result}, {}) def generate_reply_content(self, prompt): return self.workflow_manage.generate_prompt(prompt) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index ede120defce..30790c7c6f9 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -22,6 +22,8 @@ class QuestionNodeSerializer(serializers.Serializer): # 多轮对话数量 dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + class IQuestionNode(INode): type = 'question-node' diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index f5855361e99..5367b1dc255 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -13,12 +13,25 @@ from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage -from application.flow import tools from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.question_node.i_question_node import IQuestionNode from setting.models_provider.tools import get_model_instance_by_model_user_id +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(): + workflow.answer += answer + + def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ 写入上下文数据 (流式) @@ -31,15 +44,8 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer = '' for chunk in response: answer += chunk.content - chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - node.context['run_time'] = time.time() - node.context['start_time'] + yield answer + _write_context(node_variable, workflow_variable, node, workflow, answer) def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -51,71 +57,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor @param workflow: 工作流管理器 """ response = node_variable.get('result') - chat_model = node_variable.get('chat_model') answer = response.content - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] - - -def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer, status=200): - chat_model = node_variable.get('chat_model') - - if status == 200: - answer_tokens = chat_model.get_num_tokens(answer) - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - else: - answer_tokens = 0 - message_tokens = 0 - node.err_message = answer - node.status = status - node.context['message_tokens'] = message_tokens - node.context['answer_tokens'] = answer_tokens - node.context['answer'] = answer - node.context['run_time'] = time.time() - node.context['start_time'] - - return _write_context - - -def to_stream_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将流式数据 转换为 流式响应 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 输出结果后执行 - @return: 流式响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_stream_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) - - -def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable: Dict, node, workflow, - post_handler): - """ - 将结果转换 - @param chat_id: 会话id - @param chat_record_id: 对话记录id - @param node_variable: 节点数据 - @param workflow_variable: 工作流数据 - @param node: 节点 - @param workflow: 工作流管理器 - @param post_handler: 后置处理器 - @return: 响应 - """ - response = node_variable.get('result') - _write_context = get_to_response_write_context(node_variable, node) - return tools.to_response(chat_id, chat_record_id, response, workflow, _write_context, post_handler) + _write_context(node_variable, workflow_variable, node, workflow, answer) class BaseQuestionNode(IQuestionNode): @@ -131,15 +74,13 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record if stream: r = chat_model.stream(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, - 'get_to_response_write_context': get_to_response_write_context, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context_stream, - _to_response=to_stream_response) + _write_context=write_context_stream) else: r = chat_model.invoke(message_list) return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list, 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context, _to_response=to_response) + _write_context=write_context) @staticmethod def get_history_message(history_chat_record, dialogue_number): diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 839aae8daab..b2bf6b1adde 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -85,3 +85,21 @@ def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_ post_handler.handler(chat_id, chat_record_id, answer, workflow) return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, 'content': answer, 'is_end': True}) + + +def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow, + post_handler: WorkFlowPostHandler): + answer = response.content + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_stream_response_simple(stream_event): + r = StreamingHttpResponse( + streaming_content=stream_event, + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 68a2ba0223f..cfc0d6404d9 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -6,10 +6,11 @@ @date:2024/1/9 17:40 @desc: """ +import json from functools import reduce from typing import List, Dict -from langchain_core.messages import AIMessageChunk, AIMessage +from langchain_core.messages import AIMessage from langchain_core.prompts import PromptTemplate from application.flow import tools @@ -63,7 +64,6 @@ def get_start_node(self): def get_search_node(self): return [node for node in self.nodes if node.type == 'search-dataset-node'] - def is_valid(self): """ 校验工作流数据 @@ -140,33 +140,71 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None + self.answer = "" def run(self): - """ - 运行工作流 - """ + if self.params.get('stream'): + return self.run_stream() + return self.run_block() + + def run_block(self): try: while self.has_next_node(self.current_result): self.current_node = self.get_next_node() self.node_context.append(self.current_node) self.current_result = self.current_node.run() - if self.has_next_node(self.current_result): - self.current_result.write_context(self.current_node, self) - else: - r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], - self.current_node, self, - self.work_flow_post_handler) - return r + result = self.current_result.write_context(self.current_node, self) + if result is not None: + list(result) + if not self.has_next_node(self.current_result): + return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(self.answer), self, + self.work_flow_post_handler) except Exception as e: - if self.params.get('stream'): - return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'], - iter([AIMessageChunk(str(e))]), self, - self.current_node.get_write_error_context(e), - self.work_flow_post_handler) - else: - return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], - AIMessage(str(e)), self, self.current_node.get_write_error_context(e), - self.work_flow_post_handler) + return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(str(e)), self, self.current_node.get_write_error_context(e), + self.work_flow_post_handler) + + def run_stream(self): + return tools.to_stream_response_simple(self.stream_event()) + + def stream_event(self): + try: + while self.has_next_node(self.current_result): + self.current_node = self.get_next_node() + self.node_context.append(self.current_node) + self.current_result = self.current_node.run() + result = self.current_result.write_context(self.current_node, self) + if result is not None: + for r in result: + if self.is_result(): + yield self.get_chunk_content(r) + if not self.has_next_node(self.current_result): + yield self.get_chunk_content('', True) + break + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + except Exception as e: + self.current_node.get_write_error_context(e) + self.answer += str(e) + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + yield self.get_chunk_content(str(e), True) + + def is_result(self): + """ + 判断是否是返回节点 + @return: + """ + return self.current_node.node_params.get('is_result', not self.has_next_node( + self.current_result)) if self.current_node.node_params is not None else False + + def get_chunk_content(self, chunk, is_end=False): + return 'data: ' + json.dumps( + {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True, + 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n" def has_next_node(self, node_result: NodeResult | None): """ diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index e4af2236343..c94d50ee1c5 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -170,3 +170,13 @@ export const nodeDict: any = { export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' } + +export function isLastNode(nodeModel: any) { + const incoming = nodeModel.graphModel.getNodeIncomingNode(nodeModel.id) + const outcomming = nodeModel.graphModel.getNodeOutgoingNode(nodeModel.id) + if (incoming.length > 0 && outcomming.length === 0) { + return true + } else { + return false + } +} diff --git a/ui/src/workflow/nodes/ai-chat-node/index.vue b/ui/src/workflow/nodes/ai-chat-node/index.vue index 4eff08f7b3a..bc041998cbe 100644 --- a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -132,6 +132,23 @@ class="w-full" /> + + + + @@ -156,6 +173,7 @@ import applicationApi from '@/api/application' import useStore from '@/stores' import { relatedObject } from '@/utils/utils' import type { Provider } from '@/api/type/model' +import { isLastNode } from '@/workflow/common/data' const { model } = useStore() const isKeyDown = ref(false) @@ -180,7 +198,8 @@ const form = { model_id: '', system: '', prompt: defaultPrompt, - dialogue_number: 1 + dialogue_number: 1, + is_result: false } const chat_data = computed({ @@ -240,6 +259,12 @@ const openCreateModel = (provider?: Provider) => { onMounted(() => { getProvider() getModel() + if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') { + if (isLastNode(props.nodeModel)) { + set(props.nodeModel.properties.node_data, 'is_result', true) + } + } + set(props.nodeModel, 'validate', validate) }) diff --git a/ui/src/workflow/nodes/question-node/index.vue b/ui/src/workflow/nodes/question-node/index.vue index adac82aeadd..688d179765a 100644 --- a/ui/src/workflow/nodes/question-node/index.vue +++ b/ui/src/workflow/nodes/question-node/index.vue @@ -133,6 +133,23 @@ class="w-full" /> + + + + @@ -156,6 +173,8 @@ import applicationApi from '@/api/application' import useStore from '@/stores' import { relatedObject } from '@/utils/utils' import type { Provider } from '@/api/type/model' +import { isLastNode } from '@/workflow/common/data' + const { model } = useStore() const isKeyDown = ref(false) const wheel = (e: any) => { @@ -177,7 +196,8 @@ const form = { model_id: '', system: '你是一个问题优化大师', prompt: defaultPrompt, - dialogue_number: 1 + dialogue_number: 1, + is_result: false } const form_data = computed({ @@ -237,6 +257,11 @@ const openCreateModel = (provider?: Provider) => { onMounted(() => { getProvider() getModel() + if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') { + if (isLastNode(props.nodeModel)) { + set(props.nodeModel.properties.node_data, 'is_result', true) + } + } set(props.nodeModel, 'validate', validate) }) diff --git a/ui/src/workflow/nodes/reply-node/index.vue b/ui/src/workflow/nodes/reply-node/index.vue index 6e395cf4eab..5187b4741ee 100644 --- a/ui/src/workflow/nodes/reply-node/index.vue +++ b/ui/src/workflow/nodes/reply-node/index.vue @@ -46,6 +46,23 @@ v-model="form_data.fields" /> + + + + @@ -64,12 +81,14 @@ import { set } from 'lodash' import NodeContainer from '@/workflow/common/NodeContainer.vue' import NodeCascader from '@/workflow/common/NodeCascader.vue' import { ref, computed, onMounted } from 'vue' +import { isLastNode } from '@/workflow/common/data' const props = defineProps<{ nodeModel: any }>() const form = { reply_type: 'content', content: '', - fields: [] + fields: [], + is_result: false } const footers: any = [null, '=', 0] @@ -111,6 +130,12 @@ const validate = () => { } onMounted(() => { + if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') { + if (isLastNode(props.nodeModel)) { + set(props.nodeModel.properties.node_data, 'is_result', true) + } + } + set(props.nodeModel, 'validate', validate) })