diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 61de73c8689f8b..07ef6121cc5040 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -27,7 +27,6 @@ class DifyConfig( # read from dotenv format config file env_file=".env", env_file_encoding="utf-8", - frozen=True, # ignore extra attributes extra="ignore", ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index a91b9f0f020073..cdc82860c6cc4d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,7 +11,7 @@ class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. :param app_config: app config @@ -38,27 +38,23 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ) if model_credentials is None: - if not skip_check: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - else: - model_credentials = {} - - if not skip_check: - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = model_config.model - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, model_type=ModelType.LLM + ) + + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = model_config.parameters @@ -76,7 +72,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) - if not skip_check and not model_schema: + if not model_schema: raise ValueError(f"Model {model_name} not exist.") return ModelConfigWithCredentialsEntity( diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 46b86092770976..9d776f6337bccf 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -217,9 +217,12 @@ def 
_handle_workflow_run_failed( ).total_seconds() db.session.commit() - db.session.refresh(workflow_run) db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + session.add(workflow_run) + session.refresh(workflow_run) + if trace_manager: trace_manager.add_trace_task( TraceTask( diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index eb260a8f84fbbd..6d8086435d5b29 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -3,7 +3,12 @@ from configs import dify_config from core.file import file_repository from core.helper import ssrf_proxy -from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent +from core.model_runtime.entities import ( + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + VideoPromptMessageContent, +) from extensions.ext_database import db from extensions.ext_storage import storage @@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute): return file.remote_url case FileAttribute.EXTENSION: return file.extension - case _: - raise ValueError(f"Invalid file attribute: {attr}") def to_prompt_message_content( f: File, /, *, - image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ): - """ - Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. - - This function takes a File object and converts it to an appropriate PromptMessageContent - object, which can be used as a prompt for image or audio-based AI models. - - Args: - f (File): The File object to convert. - detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. - If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. - - Returns: - Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level - - Raises: - ValueError: If the file type is not supported or if required data is missing. 
- """ match f.type: case FileType.IMAGE: + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": data = _to_url(f) else: @@ -65,7 +52,7 @@ def to_prompt_message_content( return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: - encoded_string = _file_to_encoded_string(f) + encoded_string = _get_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) @@ -74,9 +61,20 @@ def to_prompt_message_content( data = _to_url(f) else: data = _to_base64_data_string(f) + if f.extension is None: + raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) + case FileType.DOCUMENT: + data = _get_encoded_string(f) + if f.mime_type is None: + raise ValueError("Missing file mime_type") + return DocumentPromptMessageContent( + encode_format="base64", + mime_type=f.mime_type, + data=data, + ) case _: - raise ValueError("file type f.type is not supported") + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): @@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /): case FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() - content = response.content - encoded_string = base64.b64encode(content).decode("utf-8") - return encoded_string + data = response.content case FileTransferMethod.LOCAL_FILE: upload_file = file_repository.get_upload_file(session=db.session(), file=f) data = _download_file_content(upload_file.key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string case FileTransferMethod.TOOL_FILE: tool_file = file_repository.get_tool_file(session=db.session(), file=f) data = _download_file_content(tool_file.file_key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string - case _: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string def _to_base64_data_string(f: File, /): @@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /): return f"data:{f.mime_type};base64,{encoded_string}" -def _file_to_encoded_string(f: File, /): - match f.type: - case FileType.IMAGE: - return _to_base64_data_string(f) - case FileType.VIDEO: - return _to_base64_data_string(f) - case FileType.AUDIO: - return _get_encoded_string(f) - case _: - raise ValueError(f"file type {f.type} is not supported") - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 688fb4776a86e1..282cd9b36fcf3b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -27,7 +28,7 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> def get_history_prompt_messages( self, max_token_limit: int = 2000, message_limit: Optional[int] = None - ) -> list[PromptMessage]: + ) -> Sequence[PromptMessage]: """ Get history prompt messages. 
:param max_token_limit: max token limit diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 059ba6c3d1f26e..1986688551b601 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -100,10 +100,10 @@ def _get_load_balancing_manager( def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 6bd9325785a2da..8870b3443536a9 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -31,7 +32,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -60,7 +61,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -90,7 +91,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -120,7 +121,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index f5d4427e3e7a72..5e52f10b4c6ee2 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -2,6 +2,7 @@ from .message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContent, @@ -37,4 +38,5 @@ "LLMResultChunk", "LLMResultChunkDelta", "AudioPromptMessageContent", + "DocumentPromptMessageContent", ] diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 3c244d368ef78b..a7e3db0032626e 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,6 +1,7 @@ from abc import ABC +from collections.abc import Sequence from enum import Enum -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -57,6 +58,7 @@ class PromptMessageContentType(Enum): IMAGE = "image" AUDIO = "audio" VIDEO = "video" + DOCUMENT = "document" class PromptMessageContent(BaseModel): @@ -101,13 +103,20 @@ class DETAIL(str, Enum): detail: DETAIL = DETAIL.LOW +class 
DocumentPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.DOCUMENT + encode_format: Literal["base64"] + mime_type: str + data: str + + class PromptMessage(ABC, BaseModel): """ Model class for prompt message. """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContent]] = None + content: Optional[str | Sequence[PromptMessageContent]] = None name: Optional[str] = None def is_empty(self) -> bool: diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 52ea787c3ad572..4e1ce17533bb95 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -87,6 +87,9 @@ class ModelFeature(Enum): AGENT_THOUGHT = "agent-thought" VISION = "vision" STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" class DefaultParameterName(str, Enum): diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 5b6f96129bde25..8faeffa872b40f 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import re import time from abc import abstractmethod -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -169,7 +169,7 @@ def _code_block_mode_wrapper( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -212,7 +212,7 @@ def _code_block_mode_wrapper( ) model_parameters.pop("response_format") - stop = stop or [] + stop = list(stop) if stop is not None else [] stop.extend(["\n```", "```\n"]) block_prompts = block_prompts.replace("{{block}}", code_block) @@ -408,7 +408,7 @@ def _invoke_result_generator( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -479,7 +479,7 @@ def _invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: @@ -601,7 +601,7 @@ def _trigger_before_invoke_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -647,7 +647,7 @@ def _trigger_new_chunk_callbacks( prompt_messages: list[PromptMessage], 
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -694,7 +694,7 @@ def _trigger_after_invoke_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -742,7 +742,7 @@ def _trigger_invoke_error_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml index e02c5517fe1f3c..4eb56bbc0e916e 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml index e20b8c4960734c..81822b162e6a16 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 4e7faab891e37a..79701e4ea4f547 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,7 +1,7 @@ import base64 import io import json -from collections.abc import Generator +from collections.abc import Generator, Sequence from typing import Optional, Union, cast import anthropic @@ -21,9 +21,9 @@ from PIL import Image from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -33,6 +33,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -86,10 +87,10 @@ def _chat_generate( self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) 
-> Union[LLMResult, Generator]: @@ -130,9 +131,17 @@ def _chat_generate( # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get("max_tokens") > 4096: + if model_parameters.get("max_tokens", 0) > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" + if any( + isinstance(content, DocumentPromptMessageContent) + for prompt_message in prompt_messages + if isinstance(prompt_message.content, list) + for content in prompt_message.content + ): + extra_headers["anthropic-beta"] = "pdfs-2024-09-25" + if tools: extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( @@ -504,6 +513,21 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) + elif isinstance(message_content, DocumentPromptMessageContent): + if message_content.mime_type != "application/pdf": + raise ValueError( + f"Unsupported document type {message_content.mime_type}, " + "only support application/pdf" + ) + sub_message_dict = { + "type": "document", + "source": { + "type": message_content.encode_format, + "media_type": message_content.mime_type, + "data": message_content.data, + }, + } + sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml index e07dea2ee17bce..6571cd094fc36b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -7,6 +7,7 @@ features: - multi-tool-call - agent-thought - stream-tool-call + - audio model_properties: mode: chat context_size: 128000 diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5eec5e3c99a00f..aa175153bc633f 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import cast from core.model_runtime.entities import ( @@ -14,7 +15,7 @@ class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. 
:param model_mode: model mode diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index b71882b043ecdf..69bd5567a46a99 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -118,11 +118,11 @@ def markdown(self) -> str: @property def log(self) -> str: - return str(self.value) + return "" @property def text(self) -> str: - return str(self.value) + return "" class ArrayAnySegment(ArraySegment): @@ -155,3 +155,11 @@ def markdown(self) -> str: for item in self.value: items.append(item.markdown) return "\n".join(items) + + @property + def log(self) -> str: + return "" + + @property + def text(self) -> str: + return "" diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index a25d563fe0b809..19a66087f7d175 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -39,7 +39,14 @@ def convert_none_configs(cls, v: Any): class PromptConfig(BaseModel): - jinja2_variables: Optional[list[VariableSelector]] = None + jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) + + @field_validator("jinja2_variables", mode="before") + @classmethod + def convert_none_jinja2_variables(cls, v: Any): + if v is None: + return [] + return v class LLMNodeChatModelMessage(ChatModelMessage): @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: Optional[PromptConfig] = None + prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + + @field_validator("prompt_config", mode="before") + @classmethod + def convert_none_prompt_config(cls, v: Any): + if v is None: + return PromptConfig() + return v diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index f858be25156951..b5207d5573e454 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError): class NoPromptFoundError(LLMNodeError): """Raised when no prompt is found in the LLM configuration.""" + + +class NotSupportedPromptTypeError(LLMNodeError): + """Raised when the prompt type is not supported.""" + + +class MemoryRolePrefixRequiredError(LLMNodeError): + """Raised when memory role prefix is required for completion model.""" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87aa6a..0cb53ee9d32168 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast @@ -6,21 +7,26 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file import FileType, file_manager +from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( - AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, 
PromptMessageContentType, TextPromptMessageContent, - VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables import ( @@ -32,8 +38,9 @@ ObjectSegment, StringSegment, ) -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -62,14 +69,18 @@ InvalidVariableTypeError, LLMModeRequiredError, LLMNodeError, + MemoryRolePrefixRequiredError, ModelNotExistError, NoPromptFoundError, + NotSupportedPromptTypeError, VariableNotFoundError, ) if TYPE_CHECKING: from core.file.models import File +logger = logging.getLogger(__name__) + class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData @@ -123,17 +134,13 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] # fetch prompt messages if self.node_data.memory: - query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - if not query: - raise VariableNotFoundError("Query not found") - query = query.text + query = self.node_data.memory.query_prompt_template else: query = None prompt_messages, stop = self._fetch_prompt_messages( - system_query=query, - inputs=inputs, - files=files, + user_query=query, + user_files=files, context=context, memory=memory, model_config=model_config, @@ -141,6 +148,8 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) process_data = { @@ -181,6 +190,17 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] ) ) return + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -203,8 +223,8 @@ def _invoke_llm( self, node_data_model: ModelConfig, model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, + prompt_messages: Sequence[PromptMessage], + stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: 
db.session.close() @@ -519,9 +539,8 @@ def _fetch_memory( def _fetch_prompt_messages( self, *, - system_query: str | None = None, - inputs: dict[str, str] | None = None, - files: Sequence["File"], + user_query: str | None = None, + user_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -529,58 +548,146 @@ def _fetch_prompt_messages( memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - inputs = inputs or {} - - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs=inputs, - query=system_query or "", - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - ) - stop = model_config.stop + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], + ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + prompt_messages = [] + + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + _handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if user_query: + message = LLMNodeChatModelMessage( + text=user_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + _handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + + # Add current query to the prompt message + if user_query: + prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_messages[0].content = prompt_content + else: + errmsg = f"Prompt type {type(prompt_template)} is not supported" + logger.warning(errmsg) + raise NotSupportedPromptTypeError(errmsg) + + if vision_enabled and user_files: + file_prompts = [] + for file in user_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = 
UserPromptMessage(content=prompt_messages[-1].content + file_prompts) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Filter prompt messages filtered_prompt_messages = [] for prompt_message in prompt_messages: - if prompt_message.is_empty(): - continue - - if not isinstance(prompt_message.content, str): + if isinstance(prompt_message.content, list): prompt_message_content = [] - for content_item in prompt_message.content or []: - # Skip image if vision is disabled - if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_config.model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) continue - if isinstance(content_item, ImagePromptMessageContent): - # Override vision config if LLM node has vision config, - # cuz vision detail is related to the configuration from FileUpload feature. - content_item.detail = vision_detail - prompt_message_content.append(content_item) - elif isinstance( - content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_config.model_schema.features + ) ): - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif ( - len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT - ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data - + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue filtered_prompt_messages.append(prompt_message) - if not filtered_prompt_messages: + if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." 
) + stop = model_config.stop return filtered_prompt_messages, stop @classmethod @@ -715,3 +822,198 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } }, } + + +def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinjia2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +): + if not template: + return "" + + jinjia2_inputs = {} + for jinja2_variable in jinjia2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + code_execute_resp = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=jinjia2_inputs, + ) + result_text = code_execute_resp["result"] + return result_text + + +def _handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, +) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=file_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + +def _calculate_rest_token( + *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity +) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( 
+ provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. 
+ + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 744dfd3d8d656b..e855ab2d2b0659 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,12 +86,14 @@ def _run(self): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - system_query=query, + user_query=query, memory=memory, model_config=model_config, - files=files, + user_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, + variable_pool=variable_pool, + jinja2_variables=[], ) # handle invoke result diff --git a/api/poetry.lock b/api/poetry.lock index 6021ae5c740ab7..6d3d2d5a7fa11d 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2423,6 +2423,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "32.1.0" +description = "Faker is a Python package that generates fake data for you." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, + {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" + [[package]] name = "fal-client" version = "0.5.6" @@ -11041,4 +11056,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" +content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0d87c1b1c8988f..2547dab7a021b9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -268,6 +268,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" +faker = "~32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py index 85a4f7734dc47c..b995077984e910 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -11,7 +11,6 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py index 466facc5fffcf6..4d72327c0ec43c 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -4,29 +4,21 @@ from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel def test_validate_credentials(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="azure-ai-studio-rerank-v1", credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - score_threshold=0.8, ) def test_invoke_model(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() result = model.invoke( model="azure-ai-studio-rerank-v1", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index def6c2a2325a0f..a1f9ece0d10d37 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,125 +1,484 @@ +from collections.abc import Sequence +from typing import Optional + import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.file import File, FileTransferMethod, FileType -from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam -from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, + VisionConfigOptions, +) from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom +from models.provider import ProviderType from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario -class TestLLMNode: - @pytest.fixture - def llm_node(self): - data = LLMNodeData( - title="Test LLM", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[], - memory=None, - context=ContextConfig(enabled=False), - vision=VisionConfig( - enabled=True, - configs=VisionConfigOptions( - variable_selector=["sys", "files"], - detail=ImagePromptMessageContent.DETAIL.HIGH, - ), - ), - ) - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - node = LLMNode( - id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, +class MockTokenBufferMemory: + 
def __init__(self, history_messages=None): + self.history_messages = history_messages or [] + + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + if message_limit is not None: + return self.history_messages[-message_limit * 2 :] + return self.history_messages + + +@pytest.fixture +def llm_node(): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, ), - ) - return node + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +@pytest.fixture +def model_config(): + # Create actual provider and model type instances + model_provider_factory = ModelProviderFactory() + provider_instance = model_provider_factory.get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + # Create a ProviderModelBundle + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=None), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) - def test_fetch_files_with_file_segment(self, llm_node): - file = File( + # Create and return a ModelConfigWithCredentialsEntity + return ModelConfigWithCredentialsEntity( + provider="openai", + model="gpt-3.5-turbo", + model_schema=AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ), + mode="chat", + credentials={}, + parameters={}, + provider_model_bundle=provider_model_bundle, + ) + + +def test_fetch_files_with_file_segment(llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + + result = 
llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + +def test_fetch_files_with_array_file_segment(llm_node): + files = [ + File( id="1", tenant_id="test", type=FileType.IMAGE, - filename="test.jpg", + filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + +def test_fetch_files_with_none_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_array_any_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_non_existent_variable(llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): + prompt_template = [] + llm_node.node_data.prompt_template = prompt_template + + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) - - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [file] - - def test_fetch_files_with_array_file_segment(self, llm_node): - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1", - ), - File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="test2.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="2", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + ] + + fake_query = faker.sentence() + + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=None, + memory=None, + model_config=model_config, + prompt_template=prompt_template, + memory_config=None, + vision_enabled=False, + vision_detail=fake_vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content=fake_query)] + + +def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + # Setup dify config + dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" + dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" + + # Generate fake values for prompt template + fake_assistant_prompt = faker.sentence() + fake_query = faker.sentence() + fake_context = faker.sentence() + fake_window_size = faker.random_int(min=1, max=3) + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + + # Setup mock memory with history messages + mock_history = [ + 
UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + ] - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == files + # Setup memory configuration + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), + query_prompt_template=None, + ) - def test_fetch_files_with_none_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + memory = MockTokenBufferMemory(history_messages=mock_history) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Test scenarios covering different file input combinations + test_scenarios = [ + LLMNodeTestScenario( + description="No files", + user_query=fake_query, + user_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage(content=fake_query), + ], + ), + LLMNodeTestScenario( + description="User files", + user_query=fake_query, + user_files=[ + File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + ], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + expected_messages=[ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), + ] + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ], + ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File", + user_query=fake_query, + user_files=[], + vision_enabled=False, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=[ + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ] + + mock_history[fake_window_size * -2 :] + + 
[UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File without vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File with video file and vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.VIDEO, + filename="test1.mp4", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + extension="mp4", + ) + }, + ), + ] - def test_fetch_files_with_array_any_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + for scenario in test_scenarios: + model_config.model_schema.features = scenario.features - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + for k, v in scenario.file_variables.items(): + selector = k.split(".") + llm_node.graph_runtime_state.variable_pool.add(selector, v) + + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=scenario.user_query, + user_files=scenario.user_files, + context=fake_context, + memory=memory, + model_config=model_config, + prompt_template=scenario.prompt_template, + memory_config=memory_config, + vision_enabled=scenario.vision_enabled, + vision_detail=scenario.vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) - def test_fetch_files_with_non_existent_variable(self, llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Verify the result + assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" + assert ( + prompt_messages == scenario.expected_messages + ), f"Message content mismatch in scenario: {scenario.description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py new file mode 100644 index 00000000000000..8e39445baf5490 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -0,0 +1,25 @@ +from collections.abc import Mapping, Sequence + +from pydantic import BaseModel, Field + +from core.file import File +from core.model_runtime.entities.message_entities import PromptMessage +from 
core.model_runtime.entities.model_entities import ModelFeature +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage + + +class LLMNodeTestScenario(BaseModel): + """Test scenario for LLM node testing.""" + + description: str = Field(..., description="Description of the test scenario") + user_query: str = Field(..., description="User query input") + user_files: Sequence[File] = Field(default_factory=list, description="List of user files") + vision_enabled: bool = Field(default=False, description="Whether vision is enabled") + vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") + features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") + window_size: int = Field(..., description="Window size for memory") + prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + file_variables: Mapping[str, File | Sequence[File]] = Field( + default_factory=dict, description="List of file variables" + ) + expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx b/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx index 6ca3af958a18bf..db2425c958b54e 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx @@ -160,6 +160,7 @@ const CodeEditor: FC = ({ hideSearch vars={availableVars} onChange={handleSelectVar} + isSupportFileVar={false} /> )} diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx index c447bb463b2856..fe2bb209877c86 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx @@ -18,6 +18,7 @@ type Props = { isSupportConstantValue?: boolean onlyLeafNodeVar?: boolean filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean + isSupportFileVar?: boolean } const VarList: FC = ({ @@ -29,6 +30,7 @@ const VarList: FC = ({ isSupportConstantValue, onlyLeafNodeVar, filterVar, + isSupportFileVar = true, }) => { const { t } = useTranslation() @@ -94,6 +96,7 @@ const VarList: FC = ({ defaultVarKindType={item.variable_type} onlyLeafNodeVar={onlyLeafNodeVar} filterVar={filterVar} + isSupportFileVar={isSupportFileVar} /> {!readonly && ( void typePlaceHolder?: string + isSupportFileVar?: boolean } const VarReferencePicker: FC = ({ @@ -81,6 +82,7 @@ const VarReferencePicker: FC = ({ isInTable, onRemove, typePlaceHolder, + isSupportFileVar = true, }) => { const { t } = useTranslation() const store = useStoreApi() @@ -382,6 +384,7 @@ const VarReferencePicker: FC = ({ vars={outputVars} onChange={handleVarReferenceChange} itemWidth={isAddBtnTrigger ? 
260 : triggerWidth} + isSupportFileVar={isSupportFileVar} /> )} diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx index 8ee9698745d096..cd03da1556f075 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx @@ -8,11 +8,13 @@ type Props = { vars: NodeOutPutVar[] onChange: (value: ValueSelector, varDetail: Var) => void itemWidth?: number + isSupportFileVar?: boolean } const VarReferencePopup: FC = ({ vars, onChange, itemWidth, + isSupportFileVar = true, }) => { // max-h-[300px] overflow-y-auto todo: use portal to handle long list return ( @@ -24,7 +26,7 @@ const VarReferencePopup: FC = ({ vars={vars} onChange={onChange} itemWidth={itemWidth} - isSupportFileVar + isSupportFileVar={isSupportFileVar} /> ) diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx index 08fc565836b3a8..a0027daf53ae06 100644 --- a/web/app/components/workflow/nodes/code/panel.tsx +++ b/web/app/components/workflow/nodes/code/panel.tsx @@ -89,6 +89,7 @@ const Panel: FC> = ({ list={inputs.variables} onChange={handleVarListChange} filterVar={filterVar} + isSupportFileVar={false} /> diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx index c8d4d92fda9eeb..d8d47a157f2977 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx @@ -144,6 +144,7 @@ const ConfigPromptItem: FC = ({ onEditionTypeChange={onEditionTypeChange} varList={varList} handleAddVariable={handleAddVariable} + isSupportFileVar /> ) } diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 76607b29b12e04..ca54d4e48729dd 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -67,6 +67,7 @@ const Panel: FC> = ({ handleStop, varInputs, runResult, + filterJinjia2InputVar, } = useConfig(id, data) const model = inputs.model @@ -194,7 +195,8 @@ const Panel: FC> = ({ list={inputs.prompt_config?.jinja2_variables || []} onChange={handleVarListChange} onVarNameChange={handleVarNameChange} - filterVar={filterVar} + filterVar={filterJinjia2InputVar} + isSupportFileVar={false} /> )} @@ -233,6 +235,7 @@ const Panel: FC> = ({ hasSetBlockStatus={hasSetBlockStatus} nodesOutputVars={availableVars} availableNodes={availableNodesWithParent} + isSupportFileVar /> {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 33742b072618e2..ee9f2ca9153005 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, [inputs, setInputs]) const filterInputVar = useCallback((varPayload: Var) => { + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) + }, []) + + const filterJinjia2InputVar = useCallback((varPayload: Var) => 
{
     return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
   }, [])
 
   const filterMemoryPromptVar = useCallback((varPayload: Var) => {
-    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
+    return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type)
   }, [])
 
   const {
@@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
     handleRun,
     handleStop,
     runResult,
+    filterJinjia2InputVar,
   }
 }
 
diff --git a/web/app/components/workflow/nodes/template-transform/panel.tsx b/web/app/components/workflow/nodes/template-transform/panel.tsx
index c02b895840b341..9eb52c55724eb3 100644
--- a/web/app/components/workflow/nodes/template-transform/panel.tsx
+++ b/web/app/components/workflow/nodes/template-transform/panel.tsx
@@ -64,6 +64,7 @@ const Panel: FC> = ({
         onChange={handleVarListChange}
         onVarNameChange={handleVarNameChange}
         filterVar={filterVar}
+        isSupportFileVar={false}
       />