Skip to content

Commit d48b51c

Browse files
committed
feat: 支持用户输入变量
--story=1016155 --user=刘瑞斌 【应用编排】-支持设置用户输入变量 https://www.tapd.cn/57709429/s/1576480
1 parent ba023d2 commit d48b51c

File tree

9 files changed

+438
-45
lines changed

9 files changed

+438
-45
lines changed

apps/application/flow/workflow_manage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,10 @@ def is_valid_base_node(self):
166166

167167
class WorkflowManage:
168168
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
169-
base_to_response: BaseToResponse = SystemToResponse()):
169+
base_to_response: BaseToResponse = SystemToResponse(), form_data = {}):
170170
self.params = params
171171
self.flow = flow
172-
self.context = {}
172+
self.context = form_data
173173
self.node_context = []
174174
self.work_flow_post_handler = work_flow_post_handler
175175
self.current_node = None

apps/application/serializers/application_serializers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,7 @@ def profile(self, with_valid=True):
694694
'tts_model_id': application.tts_model_id,
695695
'stt_model_enable': application.stt_model_enable,
696696
'tts_model_enable': application.tts_model_enable,
697+
'work_flow': application.work_flow,
697698
'show_source': application_access_token.show_source})
698699

699700
@transaction.atomic
@@ -855,10 +856,15 @@ def get_work_flow_model(instance):
855856
nodes = instance.get('work_flow')['nodes']
856857
for node in nodes:
857858
if node['id'] == 'base-node':
858-
instance['stt_model_id'] = node['properties']['node_data']['stt_model_id']
859-
instance['tts_model_id'] = node['properties']['node_data']['tts_model_id']
860-
instance['stt_model_enable'] = node['properties']['node_data']['stt_model_enable']
861-
instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable']
859+
node_data = node['properties']['node_data']
860+
if 'stt_model_id' in node_data:
861+
instance['stt_model_id'] = node_data['stt_model_id']
862+
if 'tts_model_id' in node_data:
863+
instance['tts_model_id'] = node_data['tts_model_id']
864+
if 'stt_model_enable' in node_data:
865+
instance['stt_model_enable'] = node_data['stt_model_enable']
866+
if 'tts_model_enable' in node_data:
867+
instance['tts_model_enable'] = node_data['tts_model_enable']
862868
break
863869

864870
def speech_to_text(self, file, with_valid=True):

apps/application/serializers/chat_message_serializers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ class ChatMessageSerializer(serializers.Serializer):
208208
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
209209
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
210210
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
211+
form_data = serializers.DictField(required=False, error_messages=ErrMessage.char("全局变量"))
211212

212213
def is_valid_application_workflow(self, *, raise_exception=False):
213214
self.is_valid_intraday_access_num()
@@ -284,14 +285,15 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
284285
stream = self.data.get('stream')
285286
client_id = self.data.get('client_id')
286287
client_type = self.data.get('client_type')
288+
form_data = self.data.get('form_data')
287289
user_id = chat_info.application.user_id
288290
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
289291
{'history_chat_record': chat_info.chat_record_list, 'question': message,
290292
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
291293
'stream': stream,
292294
're_chat': re_chat,
293295
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
294-
base_to_response)
296+
base_to_response, form_data)
295297
r = work_flow_manage.run()
296298
return r
297299

apps/application/views/chat_views.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def post(self, request: Request, chat_id: str):
126126
'application_id': (request.auth.keywords.get(
127127
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
128128
'client_id': request.auth.client_id,
129+
'form_data': (request.data.get('form_data') if 'form_data' in request.data else []),
129130
'client_type': request.auth.client_type}).chat()
130131

131132
@action(methods=['GET'], detail=False)

0 commit comments

Comments
 (0)