diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index 62273818cb3..23537f289be 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -7,6 +7,7 @@ @desc: """ from .ai_chat_step_node import * +from .application_node import BaseApplicationNode from .condition_node import * from .question_node import * from .search_dataset_node import * @@ -17,7 +18,7 @@ from .reranker_node import * node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/application_node/__init__.py b/apps/application/flow/step_node/application_node/__init__.py new file mode 100644 index 00000000000..d1ea91ca7f8 --- /dev/null +++ b/apps/application/flow/step_node/application_node/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .impl import * diff --git a/apps/application/flow/step_node/application_node/i_application_node.py b/apps/application/flow/step_node/application_node/i_application_node.py new file mode 100644 index 00000000000..b11fa00232f --- /dev/null +++ b/apps/application/flow/step_node/application_node/i_application_node.py @@ -0,0 +1,40 @@ +# coding=utf-8 +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class ApplicationNodeSerializer(serializers.Serializer): + application_id = serializers.CharField(required=True, error_messages=ErrMessage.char("应用id")) + question_reference_address = serializers.ListField(required=True, error_messages=ErrMessage.list("用户问题")) + api_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.list("api输入字段")) + user_input_field_list = serializers.ListField(required=False, error_messages=ErrMessage.uuid("用户输入字段")) + + +class IApplicationNode(INode): + type = 'application-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ApplicationNodeSerializer + + def _run(self): + question = self.workflow_manage.get_reference_field( + self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + kwargs = {} + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + kwargs[api_input_field['variable']] = self.workflow_manage.get_reference_field(api_input_field['value'][0], + api_input_field['value'][1:]) + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + kwargs[user_input_field['field']] = self.workflow_manage.get_reference_field(user_input_field['value'][0], + user_input_field['value'][1:]) + + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data, + message=str(question), **kwargs) + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/application_node/impl/__init__.py b/apps/application/flow/step_node/application_node/impl/__init__.py new file mode 100644 index 00000000000..e31a8d885cd --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/__init__.py @@ -0,0 +1,2 @@ +# coding=utf-8 +from .base_application_node import BaseApplicationNode diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py new file mode 100644 index 00000000000..7f4644a5815 --- /dev/null +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -0,0 +1,124 @@ +# coding=utf-8 +import json +import time +import uuid +from typing import List, Dict +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.application_node.i_application_node import IApplicationNode +from application.models import Chat +from common.handle.impl.response.openai_to_response import OpenaiToResponse + + +def string_to_uuid(input_str): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, input_str)) + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + result = node_variable.get('result') + node.context['message_tokens'] = result.get('usage', {}).get('prompt_tokens', 0) + node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0) + node.context['answer'] = answer + node.context['question'] = node_variable['question'] + node.context['run_time'] = time.time() - node.context['start_time'] + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + workflow.answer += answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + answer = '' + usage = {} + for chunk in response: + # 先把流转成字符串 + response_content = chunk.decode('utf-8')[6:] + response_content = json.loads(response_content) + choices = response_content.get('choices') + if choices and isinstance(choices, list) and len(choices) > 0: + content = choices[0].get('delta', {}).get('content', '') + answer += content + yield content + usage = response_content.get('usage', {}) + node_variable['result'] = {'usage': usage} + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result')['choices'][0]['message'] + answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。" + _write_context(node_variable, workflow_variable, node, workflow, answer) + + +class BaseApplicationNode(IApplicationNode): + + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, + **kwargs) -> NodeResult: + from application.serializers.chat_message_serializers import ChatMessageSerializer + # 生成嵌入应用的chat_id + current_chat_id = string_to_uuid(chat_id + application_id) + Chat.objects.get_or_create(id=current_chat_id, defaults={ + 'application_id': application_id, + 'abstract': message + }) + response = ChatMessageSerializer( + data={'chat_id': current_chat_id, 'message': message, + 're_chat': re_chat, + 'stream': stream, + 'application_id': application_id, + 'client_id': client_id, + 'client_type': client_type, 'form_data': kwargs}).chat(base_to_response=OpenaiToResponse()) + if response.status_code == 200: + if stream: + content_generator = response.streaming_content + return NodeResult({'result': content_generator, 'question': message}, {}, + _write_context=write_context_stream) + else: + data = json.loads(response.content) + return NodeResult({'result': data, 'question': message}, {}, + _write_context=write_context) + + def get_details(self, index: int, **kwargs): + global_fields = [] + for api_input_field in self.node_params_serializer.data.get('api_input_field_list', []): + global_fields.append({ + 'label': api_input_field['variable'], + 'key': api_input_field['variable'], + 'value': self.workflow_manage.get_reference_field( + api_input_field['value'][0], + api_input_field['value'][1:]) + }) + for user_input_field in self.node_params_serializer.data.get('user_input_field_list', []): + global_fields.append({ + 'label': user_input_field['label'], + 'key': user_input_field['field'], + 'value': self.workflow_manage.get_reference_field( + user_input_field['value'][0], + user_input_field['value'][1:]) + }) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "info": self.node.properties.get('node_data'), + 'run_time': self.context.get('run_time'), + 'question': self.context.get('question'), + 'answer': self.context.get('answer'), + 'type': self.node.type, + 'message_tokens': self.context.get('message_tokens'), + 'answer_tokens': self.context.get('answer_tokens'), + 'status': self.status, + 'err_message': self.err_message, + 'global_fields': global_fields + } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d2e99bce85d..0e5cb15fbcb 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -53,7 +53,7 @@ def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwa self.__setattr__(keyword, kwargs.get(keyword)) -end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node'] +end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node'] class Flow: diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 0232313e824..86460f093bc 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -15,14 +15,11 @@ from functools import reduce from typing import Dict, List -from django.conf import settings from django.contrib.postgres.fields import ArrayField from django.core import cache, validators from django.core import signing -from django.core.paginator import Paginator from django.db import transaction, models from django.db.models import QuerySet, Q -from django.forms import CharField from django.http import HttpResponse from django.template import Template, Context from rest_framework import serializers @@ -46,10 +43,9 @@ from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list from embedding.models import SearchMode from function_lib.serializers.function_lib_serializer import FunctionLibSerializer -from setting.models import AuthOperate, TeamMember, TeamMemberPermission +from setting.models import AuthOperate from setting.models.model_management import Model from setting.models_provider import get_model_credential -from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.serializers.provider_serializers import ModelSerializer from smartdoc.conf import PROJECT_DIR @@ -979,6 +975,17 @@ def play_demo_text(self, form_data, with_valid=True): model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data) return model.text_to_speech(text) + def application_list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + application_id = self.data.get('application_id') + application = Application.objects.filter(user_id=user_id).exclude(id=application_id) + # 把应用的type为WORK_FLOW的应用放到最上面 然后再按名称排序 + serialized_data = ApplicationSerializerModel(application, many=True).data + application = sorted(serialized_data, key=lambda x: (x['type'] != 'WORK_FLOW', x['name'])) + return list(application) + class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: model = ApplicationApiKey diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 488c244f973..61051f96d05 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -305,6 +305,8 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), 'stream': stream, 're_chat': re_chat, + 'client_id': client_id, + 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), base_to_response, form_data) r = work_flow_manage.run() diff --git a/apps/application/urls.py b/apps/application/urls.py index b3df23d73a2..5bd551b7b58 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -22,6 +22,7 @@ path('application//function_lib', views.Application.FunctionLib.as_view()), path('application//function_lib/', views.Application.FunctionLib.Operate.as_view()), + path('application//application', views.Application.Application.as_view()), path('application//model_params_form/', views.Application.ModelParamsForm.as_view()), path('application//hit_test', views.Application.HitTest.as_view()), diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 64b6c367b0a..f0873d62c74 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -243,6 +243,24 @@ def get(self, request: Request, application_id: str, function_lib_id: str): data={'application_id': application_id, 'user_id': request.user.id}).get_function_lib(function_lib_id)) + class Application(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取当前人创建的应用列表", + operation_id="获取当前人创建的应用列表", + tags=["应用/会话"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).application_list()) + class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 922bbfc3c34..78eded77348 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -43,7 +43,7 @@ class ChatView(APIView): class Export(APIView): authentication_classes = [TokenAuth] - @action(methods=['GET'], detail=False) + @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="导出对话", operation_id="导出对话", manual_parameters=ChatApi.get_request_params_api(), diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 92f9aaae49b..f41268e15ac 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -309,6 +309,18 @@ const listFunctionLib: (application_id: String, loading?: Ref) => Promi ) => { return get(`${prefix}/${application_id}/function_lib`, undefined, loading) } +/** + * 获取当前人的所有应用列表 + * @param application_id 应用id + * @param loading + * @returns + */ +export const getApplicationList: ( + application_id: string, + loading?: Ref +) => Promise> = (application_id, loading) => { + return get(`${prefix}/${application_id}/application`, undefined, loading) +} /** * 获取应用所属的函数库 * @param application_id @@ -500,5 +512,6 @@ export default { getWorkFlowVersionDetail, putWorkFlowVersion, playDemoText, - getUserList + getUserList, + getApplicationList } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index ebee3cc628f..72dfbcb6cb2 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -17,7 +17,12 @@ - +

{{ item.name }}

@@ -37,7 +42,11 @@