Skip to content

fix: 修复工作流编排历史数据模型参数必填 #1071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ChatNodeSerializer(serializers.Serializer):

is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))


class IChatNode(INode):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from functools import reduce
from typing import List, Dict

from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id


Expand Down Expand Up @@ -61,11 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer)


def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting


class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:

if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class QuestionNodeSerializer(serializers.Serializer):
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))

is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.DictField(required=True, error_messages=ErrMessage.integer("模型参数相关设置"))
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))


class IQuestionNode(INode):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from functools import reduce
from typing import List, Dict

from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id


Expand Down Expand Up @@ -61,10 +64,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
_write_context(node_variable, workflow_variable, node, workflow, answer)


def get_default_model_params_setting(model_id):
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(
model.model_name).get_default_form_data()
return model_params_setting


class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting,
**kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
Expand Down
Loading