diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index be89f35ef2f..15117c0b700 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -24,7 +24,7 @@ class ChatNodeSerializer(serializers.Serializer): is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) class IChatNode(INode): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 3581b5cfd9d..f8ec8a0118d 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -10,11 +10,14 @@ from functools import reduce from typing import List, Dict +from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode +from setting.models import Model +from setting.models_provider import get_model_credential from setting.models_provider.tools import get_model_instance_by_model_user_id @@ -61,11 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, answer) +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + class BaseChatNode(IChatNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting, **kwargs) -> NodeResult: - + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number) diff --git a/apps/application/flow/step_node/question_node/i_question_node.py b/apps/application/flow/step_node/question_node/i_question_node.py index 8fda34a5d78..9b8c125621b 100644 --- a/apps/application/flow/step_node/question_node/i_question_node.py +++ b/apps/application/flow/step_node/question_node/i_question_node.py @@ -23,7 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer): dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) - model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置")) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) class IQuestionNode(INode): diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 397a0bfc181..33e1c0fe468 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -10,11 +10,14 @@ from functools import reduce from typing import List, Dict +from django.db.models import QuerySet from langchain.schema import HumanMessage, SystemMessage from langchain_core.messages import BaseMessage from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.question_node.i_question_node import IQuestionNode +from setting.models import Model +from setting.models_provider import get_model_credential from setting.models_provider.tools import get_model_instance_by_model_user_id @@ -61,10 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, answer) +def get_default_model_params_setting(model_id): + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form( + model.model_name).get_default_form_data() + return model_params_setting + + class BaseQuestionNode(IQuestionNode): def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting, **kwargs) -> NodeResult: + if model_params_setting is None: + model_params_setting = get_default_model_params_setting(model_id) chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) history_message = self.get_history_message(history_chat_record, dialogue_number)