diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index e6a1b411726..1d7a7c0ab7f 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -166,10 +166,10 @@ def is_valid_base_node(self): class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, - base_to_response: BaseToResponse = SystemToResponse()): + base_to_response: BaseToResponse = SystemToResponse(), form_data = {}): self.params = params self.flow = flow - self.context = {} + self.context = form_data self.node_context = [] self.work_flow_post_handler = work_flow_post_handler self.current_node = None diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 4dd65d12267..7eab38f90c2 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -694,6 +694,7 @@ def profile(self, with_valid=True): 'tts_model_id': application.tts_model_id, 'stt_model_enable': application.stt_model_enable, 'tts_model_enable': application.tts_model_enable, + 'work_flow': application.work_flow, 'show_source': application_access_token.show_source}) @transaction.atomic @@ -855,10 +856,15 @@ def get_work_flow_model(instance): nodes = instance.get('work_flow')['nodes'] for node in nodes: if node['id'] == 'base-node': - instance['stt_model_id'] = node['properties']['node_data']['stt_model_id'] - instance['tts_model_id'] = node['properties']['node_data']['tts_model_id'] - instance['stt_model_enable'] = node['properties']['node_data']['stt_model_enable'] - instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable'] + node_data = node['properties']['node_data'] + if 'stt_model_id' in node_data: + instance['stt_model_id'] = node_data['stt_model_id'] + if 'tts_model_id' in node_data: + instance['tts_model_id'] = node_data['tts_model_id'] + if 'stt_model_enable' in node_data: + instance['stt_model_enable'] = node_data['stt_model_enable'] + if 'tts_model_enable' in node_data: + instance['tts_model_enable'] = node_data['tts_model_enable'] break def speech_to_text(self, file, with_valid=True): diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index a570f1ff77f..8fbf0dbbc65 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -208,6 +208,7 @@ class ChatMessageSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) + form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -284,6 +285,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): stream = self.data.get('stream') client_id = self.data.get('client_id') client_type = self.data.get('client_type') + form_data = self.data.get('form_data') user_id = chat_info.application.user_id work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': chat_info.chat_record_list, 'question': message, @@ -291,7 +293,7 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): 'stream': stream, 're_chat': re_chat, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), - base_to_response) + base_to_response, form_data) r = work_flow_manage.run() return r diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 288e8a1fc86..48648809c48 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -126,6 +126,7 @@ def post(self, request: Request, chat_id: str): 'application_id': (request.auth.keywords.get( 'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None), 'client_id': request.auth.client_id, + 'form_data': (request.data.get('form_data') if 'form_data' in request.data else []), 'client_type': request.auth.client_type}).chat() @action(methods=['GET'], detail=False) diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 1d5a9426924..7a75f994f3c 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -17,7 +17,9 @@ class="problem-button ellipsis-2 mb-8" :class="log ? 'disabled' : 'cursor'" > - + + + {{ item.str }} +
+
+ + +
+
+ + + +
+
+