From 3c33c395267eda7bb2e8faa0c8fb10bb5a258fb8 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 14:50:36 +0800 Subject: [PATCH 01/36] chore(deps): add faker --- api/poetry.lock | 17 ++++++++++++++++- api/pyproject.toml | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/api/poetry.lock b/api/poetry.lock index 6021ae5c740ab7..56c2010c36330a 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2423,6 +2423,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "32.1.0" +description = "Faker is a Python package that generates fake data for you." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, + {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" + [[package]] name = "fal-client" version = "0.5.6" @@ -11041,4 +11056,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" +content-hash = "883ea70052db8eecb4f790123d228984b9a2d8fcafe328d9e21259917e75456f" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0d87c1b1c8988f..e2d3bb6591237e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -268,6 +268,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" +faker = "^32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" From c8330e079d304d73457adfe75f7efa76eade0f0f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:07:22 +0800 Subject: [PATCH 02/36] refactor(converter): simplify model credentials validation logic --- .../model_config/converter.py | 42 +++++++++---------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index a91b9f0f020073..cdc82860c6cc4d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -11,7 +11,7 @@ class ModelConfigConverter: @classmethod - def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: + def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity: """ Convert app model config dict to entity. 
:param app_config: app config @@ -38,27 +38,23 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ) if model_credentials is None: - if not skip_check: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - else: - model_credentials = {} - - if not skip_check: - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_config.model, model_type=ModelType.LLM - ) - - if provider_model is None: - model_name = model_config.model - raise ValueError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=model_config.model, model_type=ModelType.LLM + ) + + if provider_model is None: + model_name = model_config.model + raise ValueError(f"Model {model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = model_config.parameters @@ -76,7 +72,7 @@ def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) - if not skip_check and not model_schema: + if not model_schema: raise ValueError(f"Model {model_name} not exist.") return ModelConfigWithCredentialsEntity( From 61ea2dda254d2b18b104a3170f7ba9138e62448f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:21:29 +0800 Subject: [PATCH 03/36] refactor: update stop parameter type to use Sequence instead of list --- api/core/model_manager.py | 2 +- .../model_runtime/callbacks/base_callback.py | 9 +++++---- .../__base/large_language_model.py | 20 +++++++++---------- api/core/workflow/nodes/llm/node.py | 2 +- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 059ba6c3d1f26e..3424a7fa780b62 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -103,7 +103,7 @@ def invoke_llm( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 6bd9325785a2da..8870b3443536a9 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections.abc 
import Sequence from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -31,7 +32,7 @@ def on_before_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -60,7 +61,7 @@ def on_new_chunk( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ): @@ -90,7 +91,7 @@ def on_after_invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: @@ -120,7 +121,7 @@ def on_invoke_error( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> None: diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 5b6f96129bde25..8faeffa872b40f 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import re import time from abc import abstractmethod -from collections.abc import Generator, Mapping +from collections.abc import Generator, Mapping, Sequence from typing import Optional, Union from pydantic import ConfigDict @@ -48,7 +48,7 @@ def invoke( prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -169,7 +169,7 @@ def _code_block_mode_wrapper( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -212,7 +212,7 @@ def _code_block_mode_wrapper( ) model_parameters.pop("response_format") - stop = stop or [] + stop = list(stop) if stop is not None else [] stop.extend(["\n```", "```\n"]) block_prompts = block_prompts.replace("{{block}}", code_block) @@ -408,7 +408,7 @@ def _invoke_result_generator( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -479,7 +479,7 @@ def _invoke( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: @@ -601,7 +601,7 @@ def _trigger_before_invoke_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: 
Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -647,7 +647,7 @@ def _trigger_new_chunk_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -694,7 +694,7 @@ def _trigger_after_invoke_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, @@ -742,7 +742,7 @@ def _trigger_invoke_error_callbacks( prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None, diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87aa6a..7634b90dfffb83 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -204,7 +204,7 @@ def _invoke_llm( node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() From 3687ea65d2f261d6c50c2ec75ed605e2dcb5b4b1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:22:01 +0800 Subject: [PATCH 04/36] refactor: update jinja2_variables and prompt_config to use Sequence and add validators for None handling --- api/core/workflow/nodes/llm/entities.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index a25d563fe0b809..19a66087f7d175 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -39,7 +39,14 @@ def convert_none_configs(cls, v: Any): class PromptConfig(BaseModel): - jinja2_variables: Optional[list[VariableSelector]] = None + jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) + + @field_validator("jinja2_variables", mode="before") + @classmethod + def convert_none_jinja2_variables(cls, v: Any): + if v is None: + return [] + return v class LLMNodeChatModelMessage(ChatModelMessage): @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: Optional[PromptConfig] = None + prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + + @field_validator("prompt_config", mode="before") + @classmethod + def convert_none_prompt_config(cls, v: Any): + if v is None: + return PromptConfig() + return v From 223e03a6fdd7ed03eda30c019a9679ec46b4e8d3 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:22:15 +0800 Subject: [PATCH 05/36] feat(errors): add new error classes 
for unsupported prompt types and memory role prefix requirements --- api/core/workflow/nodes/llm/exc.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index f858be25156951..b5207d5573e454 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError): class NoPromptFoundError(LLMNodeError): """Raised when no prompt is found in the LLM configuration.""" + + +class NotSupportedPromptTypeError(LLMNodeError): + """Raised when the prompt type is not supported.""" + + +class MemoryRolePrefixRequiredError(LLMNodeError): + """Raised when memory role prefix is required for completion model.""" From bd60d0f1e521245a6096589f3e248bd0b8744afa Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 13 Nov 2024 16:32:43 +0800 Subject: [PATCH 06/36] fix(tests): update Azure Rerank Model usage and clean imports --- .../model_runtime/azure_ai_studio/test_llm.py | 1 - .../model_runtime/azure_ai_studio/test_rerank.py | 14 +++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py index 85a4f7734dc47c..b995077984e910 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -11,7 +11,6 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py index 466facc5fffcf6..4d72327c0ec43c 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -4,29 +4,21 @@ from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel def test_validate_credentials(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="azure-ai-studio-rerank-v1", credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - score_threshold=0.8, ) def test_invoke_model(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() result = model.invoke( model="azure-ai-studio-rerank-v1", From 37e0a3803c500062a890e592addc750cd5c3a372 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 12:53:51 +0800 Subject: [PATCH 07/36] refactor(prompt): enhance type flexibility for prompt messages - Changed input type from list to Sequence for prompt messages to allow more flexible input types. - Improved compatibility with functions expecting different iterable types. --- api/core/prompt/utils/prompt_message_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5eec5e3c99a00f..aa175153bc633f 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import cast from core.model_runtime.entities import ( @@ -14,7 +15,7 @@ class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. :param model_mode: model mode From 9819825a43471244eeebf78168e2331de61a2483 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 13:01:57 +0800 Subject: [PATCH 08/36] refactor(model_runtime): use Sequence for content in PromptMessage - Replaced list with Sequence for more flexible content type. - Improved type consistency by importing from collections.abc. --- api/core/model_runtime/entities/message_entities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 3c244d368ef78b..fc37227bc99f68 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,4 +1,5 @@ from abc import ABC +from collections.abc import Sequence from enum import Enum from typing import Optional @@ -107,7 +108,7 @@ class PromptMessage(ABC, BaseModel): """ role: PromptMessageRole - content: Optional[str | list[PromptMessageContent]] = None + content: Optional[str | Sequence[PromptMessageContent]] = None name: Optional[str] = None def is_empty(self) -> bool: From 062c4954bd0a1bc910cf0d32aa9561d9536220e7 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 13:14:52 +0800 Subject: [PATCH 09/36] chore(config): remove unnecessary 'frozen' parameter for test - Simplified app configuration by removing the 'frozen' parameter since it is no longer needed. - Ensures more flexible handling of config attributes. 
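For context on why the flag had to go: with frozen=True, pydantic-settings rejects attribute assignment, so tests cannot override settings at runtime (the tests later in this series do exactly that, e.g. dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"). A minimal sketch of that behaviour, assuming pydantic v2 / pydantic-settings semantics; DemoConfig is a hypothetical stand-in, not Dify's DifyConfig:

    from pydantic import ValidationError
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class DemoConfig(BaseSettings):
        # Same pattern as the removed line: frozen=True forbids mutation after init.
        model_config = SettingsConfigDict(frozen=True)
        MULTIMODAL_SEND_IMAGE_FORMAT: str = "base64"

    config = DemoConfig()
    try:
        config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"  # the kind of override the tests need
    except ValidationError:
        print("assignment rejected while frozen=True")

    # Dropping frozen=True (False is the default) lets the same assignment succeed,
    # which is what this commit relies on.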
--- api/configs/app_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/configs/app_config.py b/api/configs/app_config.py index 61de73c8689f8b..07ef6121cc5040 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -27,7 +27,6 @@ class DifyConfig( # read from dotenv format config file env_file=".env", env_file_encoding="utf-8", - frozen=True, # ignore extra attributes extra="ignore", ) From 37b1347ba0874e8915e196c5b19db59738d282ab Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:05:42 +0800 Subject: [PATCH 10/36] fix(dependencies): update Faker version constraint - Changed the Faker version from caret constraint to tilde constraint for compatibility. - Updated poetry.lock for changes in pyproject.toml content. --- api/poetry.lock | 2 +- api/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index 56c2010c36330a..6d3d2d5a7fa11d 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -11056,4 +11056,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "883ea70052db8eecb4f790123d228984b9a2d8fcafe328d9e21259917e75456f" +content-hash = "d149b24ce7a203fa93eddbe8430d8ea7e5160a89c8d348b1b747c19899065639" diff --git a/api/pyproject.toml b/api/pyproject.toml index e2d3bb6591237e..2547dab7a021b9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -268,7 +268,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" -faker = "^32.1.0" +faker = "~32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" From a018002ca6804e719dbccb9d96a862a3491b3e24 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:54:28 +0800 Subject: [PATCH 11/36] refactor(memory): use Sequence instead of list for prompt messages - Improved flexibility by using Sequence instead of list, allowing for broader compatibility with different types of sequences. - Helps future-proof the method signature by leveraging the more generic Sequence type. --- api/core/memory/token_buffer_memory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 688fb4776a86e1..282cd9b36fcf3b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Optional from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -27,7 +28,7 @@ def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> def get_history_prompt_messages( self, max_token_limit: int = 2000, message_limit: Optional[int] = None - ) -> list[PromptMessage]: + ) -> Sequence[PromptMessage]: """ Get history prompt messages. :param max_token_limit: max token limit From 6810529ed5838b6d7fe7b43f08886c6a12de3091 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 14:57:03 +0800 Subject: [PATCH 12/36] refactor(model_manager): update parameter type for flexibility - Changed 'prompt_messages' parameter from list to Sequence for broader input type compatibility. 
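Why Sequence rather than list keeps recurring in this series (patches 03, 07, 08, 11 and this one): Sequence is read-only and covariant, so callers can pass lists, tuples, or lists of PromptMessage subclasses without type-checker friction, which matters now that TokenBufferMemory.get_history_prompt_messages returns a Sequence. A minimal sketch of the typing rationale; summarize() and the message classes below are hypothetical stand-ins, not Dify's actual API:

    from collections.abc import Sequence

    class PromptMessage:
        def __init__(self, content: str) -> None:
            self.content = content

    class UserPromptMessage(PromptMessage):
        pass

    def summarize(prompt_messages: Sequence[PromptMessage]) -> int:
        # The callee only iterates, so a read-only Sequence is enough.
        return sum(len(m.content) for m in prompt_messages)

    history: list[UserPromptMessage] = [UserPromptMessage("hi"), UserPromptMessage("hello")]
    print(summarize(history))         # accepted: Sequence is covariant; a list[PromptMessage] parameter would reject this
    print(summarize(tuple(history)))  # accepted: tuples satisfy Sequence but never list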
--- api/core/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 3424a7fa780b62..1986688551b601 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -100,7 +100,7 @@ def _get_load_balancing_manager( def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[Sequence[str]] = None, From 070dc2d6d2a9ed8295faa0839d205c0b06e23833 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:33:32 +0800 Subject: [PATCH 13/36] Remove unnecessary data from log and text properties Updated the log and text properties in segments to return empty strings instead of the segment value. This change prevents potential leakage of sensitive data by ensuring only non-sensitive information is logged or transformed into text. Addresses potential security and privacy concerns. --- api/core/variables/segments.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index b71882b043ecdf..69bd5567a46a99 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -118,11 +118,11 @@ def markdown(self) -> str: @property def log(self) -> str: - return str(self.value) + return "" @property def text(self) -> str: - return str(self.value) + return "" class ArrayAnySegment(ArraySegment): @@ -155,3 +155,11 @@ def markdown(self) -> str: for item in self.value: items.append(item.markdown) return "\n".join(items) + + @property + def log(self) -> str: + return "" + + @property + def text(self) -> str: + return "" From fb506be94a5c458caa6e50a42d1f34c47286ae5a Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:34:16 +0800 Subject: [PATCH 14/36] feat(llm_node): allow to use image file directly in the prompt. 
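In practice this means a chat prompt template message can reference a file variable directly (the tests added in this patch use {{#input.image#}} and {{#input.images#}}), and any image files found in the rendered segments are expanded into an additional user message carrying image prompt content. A simplified, self-contained sketch of that expansion with stand-in types, not Dify's File / file_manager API:

    from dataclasses import dataclass

    @dataclass
    class StubImageFile:
        url: str

    def expand_image_segments(segments: list[object], detail: str) -> list[dict]:
        """Collect image files from rendered template segments into prompt contents."""
        contents: list[dict] = []
        for segment in segments:
            items = segment if isinstance(segment, list) else [segment]
            for item in items:
                if isinstance(item, StubImageFile):
                    # Stands in for file_manager.to_prompt_message_content(file, image_detail_config=detail)
                    contents.append({"type": "image", "data": item.url, "detail": detail})
        return contents

    print(expand_image_segments(["Describe this:", StubImageFile("https://example.com/a.jpg")], detail="high"))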
--- api/core/workflow/nodes/llm/node.py | 328 ++++++++++-- .../core/workflow/nodes/llm/test_node.py | 468 ++++++++++++++---- 2 files changed, 651 insertions(+), 145 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 7634b90dfffb83..efd8ace65320d9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast @@ -6,21 +7,26 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file import FileType, file_manager +from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( - AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, TextPromptMessageContent, - VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables import ( @@ -30,10 +36,13 @@ FileSegment, NoneSegment, ObjectSegment, + SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -62,14 +71,18 @@ InvalidVariableTypeError, LLMModeRequiredError, LLMNodeError, + MemoryRolePrefixRequiredError, ModelNotExistError, NoPromptFoundError, + NotSupportedPromptTypeError, VariableNotFoundError, ) if TYPE_CHECKING: from core.file.models import File +logger = logging.getLogger(__name__) + class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData @@ -131,9 +144,8 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] query = None prompt_messages, stop = self._fetch_prompt_messages( - system_query=query, - inputs=inputs, - files=files, + user_query=query, + user_files=files, context=context, memory=memory, model_config=model_config, @@ -203,7 +215,7 @@ def _invoke_llm( self, node_data_model: ModelConfig, model_instance: ModelInstance, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() @@ -519,9 +531,8 @@ def 
_fetch_memory( def _fetch_prompt_messages( self, *, - system_query: str | None = None, - inputs: dict[str, str] | None = None, - files: Sequence["File"], + user_query: str | None = None, + user_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -529,60 +540,161 @@ def _fetch_prompt_messages( memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - inputs = inputs or {} - - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs=inputs, - query=system_query or "", - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - ) - stop = model_config.stop + ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + prompt_messages = [] + + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context)) + + # Get memory messages for chat mode + memory_messages = self._handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if user_query: + prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)])) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context)) + + # Get memory text for completion model + memory_text = self._handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + + # Add current query to the prompt message + if user_query: + prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_messages[0].content = prompt_content + else: + errmsg = f"Prompt type {type(prompt_template)} is not supported" + logger.warning(errmsg) + raise NotSupportedPromptTypeError(errmsg) + + if vision_enabled and user_files: + file_prompts = [] + for file in user_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Filter prompt messages filtered_prompt_messages = [] for prompt_message in prompt_messages: - if prompt_message.is_empty(): - continue - - if not isinstance(prompt_message.content, str): + if isinstance(prompt_message.content, list): prompt_message_content = [] - for content_item in prompt_message.content or []: + for content_item in prompt_message.content: # Skip image if vision is disabled if not 
vision_enabled and content_item.type == PromptMessageContentType.IMAGE: continue - - if isinstance(content_item, ImagePromptMessageContent): - # Override vision config if LLM node has vision config, - # cuz vision detail is related to the configuration from FileUpload feature. - content_item.detail = vision_detail - prompt_message_content.append(content_item) - elif isinstance( - content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent - ): - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif ( - len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT - ): + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data - + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue filtered_prompt_messages.append(prompt_message) - if not filtered_prompt_messages: + if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) + stop = model_config.stop return filtered_prompt_messages, stop + def _handle_memory_chat_mode( + self, + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, + ) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = self._calculate_rest_token([], model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + def _handle_memory_completion_mode( + self, + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, + ) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = self._calculate_rest_token([], model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + def _calculate_rest_token( + self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + ) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + 
) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -715,3 +827,121 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: } }, } + + def _handle_list_messages( + self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str] + ) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=self.node_data.prompt_config.jinja2_variables, + variable_pool=self.graph_runtime_state.variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + segment_group = _render_basic_message( + template=message.text, + context=context, + variable_pool=self.graph_runtime_state.variable_pool, + ) + + # Process segments for images + image_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type == FileType.IMAGE: + image_content = file_manager.to_prompt_message_content( + file, image_detail_config=self.node_data.vision.configs.detail + ) + image_contents.append(image_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type == FileType.IMAGE: + image_content = file_manager.to_prompt_message_content( + file, image_detail_config=self.node_data.vision.configs.detail + ) + image_contents.append(image_content) + + # Create message with text from all segments + prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role) + prompt_messages.append(prompt_message) + + if image_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=image_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + def _handle_completion_template( + self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str] + ) -> Sequence[PromptMessage]: + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=self.node_data.prompt_config.jinja2_variables, + variable_pool=self.graph_runtime_state.variable_pool, + ) + else: + result_text = _render_basic_message( + template=template.text, + context=context, + variable_pool=self.graph_runtime_state.variable_pool, + ).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages + + +def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinjia2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +): + if not 
template: + return "" + + jinjia2_inputs = {} + for jinja2_variable in jinjia2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + code_execute_resp = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=jinjia2_inputs, + ) + result_text = code_execute_resp["result"] + return result_text + + +def _render_basic_message( + *, + template: str, + context: str | None, + variable_pool: VariablePool, +) -> SegmentGroup: + if not template: + return SegmentGroup(value=[]) + + if context: + template = template.replace("{#context#}", context) + + return variable_pool.convert_template(template) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index def6c2a2325a0f..859be44674c927 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,125 +1,401 @@ +from collections.abc import Sequence +from typing import Optional + import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from configs import dify_config +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.file import File, FileTransferMethod, FileType -from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam -from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions +from core.workflow.nodes.llm.entities import ( + ContextConfig, + LLMNodeChatModelMessage, + LLMNodeData, + ModelConfig, + VisionConfig, + VisionConfigOptions, +) from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom +from models.provider import ProviderType from models.workflow import WorkflowType -class TestLLMNode: - @pytest.fixture - def llm_node(self): - data = LLMNodeData( - title="Test LLM", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[], - memory=None, - context=ContextConfig(enabled=False), - vision=VisionConfig( - enabled=True, - configs=VisionConfigOptions( - variable_selector=["sys", "files"], - detail=ImagePromptMessageContent.DETAIL.HIGH, - ), - ), - ) - 
variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - node = LLMNode( - id="1", - config={ - "id": "1", - "data": data.model_dump(), - }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, +class MockTokenBufferMemory: + def __init__(self, history_messages=None): + self.history_messages = history_messages or [] + + def get_history_prompt_messages( + self, max_token_limit: int = 2000, message_limit: Optional[int] = None + ) -> Sequence[PromptMessage]: + if message_limit is not None: + return self.history_messages[-message_limit * 2 :] + return self.history_messages + + +@pytest.fixture +def llm_node(): + data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[], + memory=None, + context=ContextConfig(enabled=False), + vision=VisionConfig( + enabled=True, + configs=VisionConfigOptions( + variable_selector=["sys", "files"], + detail=ImagePromptMessageContent.DETAIL.HIGH, ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), + ), + ) + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": data.model_dump(), + }, + graph_init_params=GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ), + graph=Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, ), - ) - return node + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ), + ) + return node + + +@pytest.fixture +def model_config(): + # Create actual provider and model type instances + model_provider_factory = ModelProviderFactory() + provider_instance = model_provider_factory.get_provider_instance("openai") + model_type_instance = provider_instance.get_model_instance(ModelType.LLM) + + # Create a ProviderModelBundle + provider_model_bundle = ProviderModelBundle( + configuration=ProviderConfiguration( + tenant_id="1", + provider=provider_instance.get_provider_schema(), + preferred_provider_type=ProviderType.CUSTOM, + using_provider_type=ProviderType.CUSTOM, + system_configuration=SystemConfiguration(enabled=False), + custom_configuration=CustomConfiguration(provider=None), + model_settings=[], + ), + provider_instance=provider_instance, + model_type_instance=model_type_instance, + ) + + # Create and return a ModelConfigWithCredentialsEntity + return ModelConfigWithCredentialsEntity( + provider="openai", + model="gpt-3.5-turbo", + model_schema=AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ), + mode="chat", + credentials={}, + parameters={}, 
+ provider_model_bundle=provider_model_bundle, + ) + + +def test_fetch_files_with_file_segment(llm_node): + file = File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="1", + ) + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) - def test_fetch_files_with_file_segment(self, llm_node): - file = File( + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [file] + + +def test_fetch_files_with_array_file_segment(llm_node): + files = [ + File( id="1", tenant_id="test", type=FileType.IMAGE, - filename="test.jpg", + filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", + ), + File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="test2.jpg", + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="2", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == files + + +def test_fetch_files_with_none_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_array_any_segment(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) + + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_files_with_non_existent_variable(llm_node): + result = llm_node._fetch_files(selector=["sys", "files"]) + assert result == [] + + +def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): + prompt_template = [] + llm_node.node_data.prompt_template = prompt_template + + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + related_id="1", ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) - - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [file] - - def test_fetch_files_with_array_file_segment(self, llm_node): - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="1", - ), - File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="test2.jpg", - transfer_method=FileTransferMethod.LOCAL_FILE, - related_id="2", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + ] - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == files + fake_query = faker.sentence() - def test_fetch_files_with_none_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=None, + memory=None, + model_config=model_config, + prompt_template=prompt_template, + memory_config=None, + vision_enabled=False, + vision_detail=fake_vision_detail, + ) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + assert prompt_messages == 
[UserPromptMessage(content=fake_query)] - def test_fetch_files_with_array_any_segment(self, llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] +def test_fetch_prompt_messages__basic(faker, llm_node, model_config): + # Setup dify config + dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" + + # Generate fake values for prompt template + fake_user_prompt = faker.sentence() + fake_assistant_prompt = faker.sentence() + fake_query = faker.sentence() + random_context = faker.sentence() + + # Generate fake values for vision + fake_vision_detail = faker.random_element( + [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] + ) + fake_remote_url = faker.url() + fake_prompt_image_url = faker.url() + + # Setup prompt template with image variable reference + prompt_template = [ + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ] + llm_node.node_data.prompt_template = prompt_template + + # Setup vision files + files = [ + File( + id="1", + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + related_id="1", + ) + ] + + # Setup prompt image in variable pool + prompt_image = File( + id="2", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="2", + ) + prompt_images = [ + File( + id="3", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="3", + ), + File( + id="4", + tenant_id="test", + type=FileType.IMAGE, + filename="prompt_image.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_prompt_image_url, + related_id="4", + ), + ] + llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image) + llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images) + + # Setup memory configuration with random window size + window_size = faker.random_int(min=1, max=3) + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=window_size), + query_prompt_template=None, + ) + + # Setup mock memory with history messages + mock_history = [ + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + UserPromptMessage(content=faker.sentence()), + AssistantPromptMessage(content=faker.sentence()), + ] + memory = MockTokenBufferMemory(history_messages=mock_history) + + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=files, + context=random_context, + memory=memory, + model_config=model_config, + prompt_template=prompt_template, + memory_config=memory_config, + 
vision_enabled=True, + vision_detail=fake_vision_detail, + ) + + # Build expected messages + expected_messages = [ + # Base template messages + SystemPromptMessage(content=random_context), + # Image from variable pool in prompt template + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ] + ), + AssistantPromptMessage(content=fake_assistant_prompt), + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ] + ), + ] + + # Add memory messages based on window size + expected_messages.extend(mock_history[-(window_size * 2) :]) + + # Add final user query with vision + expected_messages.append( + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ) + ) - def test_fetch_files_with_non_existent_variable(self, llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) - assert result == [] + # Verify the result + assert prompt_messages == expected_messages From 651f5847a7a4a0515654e8ee79b47e091420eb29 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 18:52:32 +0800 Subject: [PATCH 15/36] Simplify test setup in LLM node tests Replaced redundant variables in test setup to streamline and align usage of fake data, enhancing readability and maintainability. Adjusted image URL variables to utilize consistent references, ensuring uniformity across test configurations. Also, corrected context variable naming for clarity. No functional impact, purely a refactor for code clarity. --- .../core/workflow/nodes/llm/test_node.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 859be44674c927..5417202c25013a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -250,17 +250,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" # Generate fake values for prompt template - fake_user_prompt = faker.sentence() fake_assistant_prompt = faker.sentence() fake_query = faker.sentence() - random_context = faker.sentence() + fake_context = faker.sentence() # Generate fake values for vision fake_vision_detail = faker.random_element( [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] ) fake_remote_url = faker.url() - fake_prompt_image_url = faker.url() # Setup prompt template with image variable reference prompt_template = [ @@ -307,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="2", ) prompt_images = [ @@ -317,7 +315,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="3", ), File( @@ -326,7 +324,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): type=FileType.IMAGE, filename="prompt_image.jpg", 
transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_prompt_image_url, + remote_url=fake_remote_url, related_id="4", ), ] @@ -356,7 +354,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): prompt_messages, _ = llm_node._fetch_prompt_messages( user_query=fake_query, user_files=files, - context=random_context, + context=fake_context, memory=memory, model_config=model_config, prompt_template=prompt_template, @@ -368,18 +366,18 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Build expected messages expected_messages = [ # Base template messages - SystemPromptMessage(content=random_context), + SystemPromptMessage(content=fake_context), # Image from variable pool in prompt template UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ] ), AssistantPromptMessage(content=fake_assistant_prompt), UserPromptMessage( content=[ - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), - ImagePromptMessageContent(data=fake_prompt_image_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), ] ), ] From cd0a8eac7675d5b23fbaf415e7d0b0a8dab96496 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 19:44:00 +0800 Subject: [PATCH 16/36] refactor(tests): streamline LLM node prompt message tests Refactored LLM node tests to enhance clarity and maintainability by creating test scenarios for different file input combinations. This restructuring replaces repetitive code with a more concise approach, improving test coverage and readability. No functional code changes were made. 
References: #123, #456 --- .../core/workflow/nodes/llm/test_node.py | 231 +++++++++--------- 1 file changed, 109 insertions(+), 122 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5417202c25013a..99400b21b0119a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -18,7 +18,7 @@ TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.prompt.entities.advanced_prompt_entities import MemoryConfig @@ -253,92 +253,12 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): fake_assistant_prompt = faker.sentence() fake_query = faker.sentence() fake_context = faker.sentence() - - # Generate fake values for vision + fake_window_size = faker.random_int(min=1, max=3) fake_vision_detail = faker.random_element( [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW] ) fake_remote_url = faker.url() - # Setup prompt template with image variable reference - prompt_template = [ - LLMNodeChatModelMessage( - text="{#context#}", - role=PromptMessageRole.SYSTEM, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text="{{#input.image#}}", - role=PromptMessageRole.USER, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text=fake_assistant_prompt, - role=PromptMessageRole.ASSISTANT, - edition_type="basic", - ), - LLMNodeChatModelMessage( - text="{{#input.images#}}", - role=PromptMessageRole.USER, - edition_type="basic", - ), - ] - llm_node.node_data.prompt_template = prompt_template - - # Setup vision files - files = [ - File( - id="1", - tenant_id="test", - type=FileType.IMAGE, - filename="test1.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="1", - ) - ] - - # Setup prompt image in variable pool - prompt_image = File( - id="2", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="2", - ) - prompt_images = [ - File( - id="3", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="3", - ), - File( - id="4", - tenant_id="test", - type=FileType.IMAGE, - filename="prompt_image.jpg", - transfer_method=FileTransferMethod.REMOTE_URL, - remote_url=fake_remote_url, - related_id="4", - ), - ] - llm_node.graph_runtime_state.variable_pool.add(["input", "image"], prompt_image) - llm_node.graph_runtime_state.variable_pool.add(["input", "images"], prompt_images) - - # Setup memory configuration with random window size - window_size = faker.random_int(min=1, max=3) - memory_config = MemoryConfig( - role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), - window=MemoryConfig.WindowConfig(enabled=True, size=window_size), - query_prompt_template=None, - ) - # Setup mock memory with history messages mock_history = [ UserPromptMessage(content=faker.sentence()), @@ -348,52 +268,119 @@ def 
test_fetch_prompt_messages__basic(faker, llm_node, model_config): UserPromptMessage(content=faker.sentence()), AssistantPromptMessage(content=faker.sentence()), ] - memory = MockTokenBufferMemory(history_messages=mock_history) - # Call the method under test - prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=files, - context=fake_context, - memory=memory, - model_config=model_config, - prompt_template=prompt_template, - memory_config=memory_config, - vision_enabled=True, - vision_detail=fake_vision_detail, + # Setup memory configuration + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size), + query_prompt_template=None, ) - # Build expected messages - expected_messages = [ - # Base template messages - SystemPromptMessage(content=fake_context), - # Image from variable pool in prompt template - UserPromptMessage( - content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + memory = MockTokenBufferMemory(history_messages=mock_history) + + # Test scenarios covering different file input combinations + test_scenarios = [ + { + "description": "No files", + "user_query": fake_query, + "user_files": [], + "features": [], + "window_size": fake_window_size, + "prompt_template": [ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + "expected_messages": [ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), ] - ), - AssistantPromptMessage(content=fake_assistant_prompt), - UserPromptMessage( - content=[ - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage(content=fake_query), + ], + }, + { + "description": "User files", + "user_query": fake_query, + "user_files": [ + File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + ], + "vision_enabled": True, + "vision_detail": fake_vision_detail, + "features": [ModelFeature.VISION], + "window_size": fake_window_size, + "prompt_template": [ + LLMNodeChatModelMessage( + text=fake_context, + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="{#context#}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text=fake_assistant_prompt, + role=PromptMessageRole.ASSISTANT, + edition_type="basic", + ), + ], + "expected_messages": [ + SystemPromptMessage(content=fake_context), + UserPromptMessage(content=fake_context), + AssistantPromptMessage(content=fake_assistant_prompt), ] - ), + + mock_history[fake_window_size * -2 :] + + [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data=fake_query), + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ], + }, ] - # Add memory messages based on window size - expected_messages.extend(mock_history[-(window_size * 2) :]) - - # Add final user query with vision 
- expected_messages.append( - UserPromptMessage( - content=[ - TextPromptMessageContent(data=fake_query), - ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), - ] + for scenario in test_scenarios: + model_config.model_schema.features = scenario["features"] + + # Call the method under test + prompt_messages, _ = llm_node._fetch_prompt_messages( + user_query=fake_query, + user_files=scenario["user_files"], + context=fake_context, + memory=memory, + model_config=model_config, + prompt_template=scenario["prompt_template"], + memory_config=memory_config, + vision_enabled=True, + vision_detail=fake_vision_detail, ) - ) - # Verify the result - assert prompt_messages == expected_messages + # Verify the result + assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" + assert ( + prompt_messages == scenario["expected_messages"] + ), f"Message content mismatch in scenario: {scenario['description']}" From b1a60bf3d95808176fec3d8a1a821c1933f86050 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 19:54:21 +0800 Subject: [PATCH 17/36] feat(tests): refactor LLMNode tests for clarity Refactor test scenarios in LLMNode unit tests by introducing a new `LLMNodeTestScenario` class to enhance readability and consistency. This change simplifies the test case management by encapsulating scenario data and reduces redundancy in specifying test configurations. Improves test clarity and maintainability by using a structured approach. --- .../core/workflow/nodes/llm/test_node.py | 62 ++++++++++--------- .../core/workflow/nodes/llm/test_scenarios.py | 20 ++++++ 2 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 99400b21b0119a..5c83cddfd8f77f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -39,6 +39,7 @@ from models.enums import UserFrom from models.provider import ProviderType from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario class MockTokenBufferMemory: @@ -224,7 +225,6 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): filename="test1.jpg", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, - related_id="1", ) ] @@ -280,13 +280,15 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Test scenarios covering different file input combinations test_scenarios = [ - { - "description": "No files", - "user_query": fake_query, - "user_files": [], - "features": [], - "window_size": fake_window_size, - "prompt_template": [ + LLMNodeTestScenario( + description="No files", + user_query=fake_query, + user_files=[], + features=[], + vision_enabled=False, + vision_detail=None, + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -303,7 +305,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -312,11 +314,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, 
model_config): + [ UserPromptMessage(content=fake_query), ], - }, - { - "description": "User files", - "user_query": fake_query, - "user_files": [ + ), + LLMNodeTestScenario( + description="User files", + user_query=fake_query, + user_files=[ File( tenant_id="test", type=FileType.IMAGE, @@ -325,11 +327,11 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): remote_url=fake_remote_url, ) ], - "vision_enabled": True, - "vision_detail": fake_vision_detail, - "features": [ModelFeature.VISION], - "window_size": fake_window_size, - "prompt_template": [ + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ LLMNodeChatModelMessage( text=fake_context, role=PromptMessageRole.SYSTEM, @@ -346,7 +348,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): edition_type="basic", ), ], - "expected_messages": [ + expected_messages=[ SystemPromptMessage(content=fake_context), UserPromptMessage(content=fake_context), AssistantPromptMessage(content=fake_assistant_prompt), @@ -360,27 +362,27 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ] ), ], - }, + ), ] for scenario in test_scenarios: - model_config.model_schema.features = scenario["features"] + model_config.model_schema.features = scenario.features # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( - user_query=fake_query, - user_files=scenario["user_files"], + user_query=scenario.user_query, + user_files=scenario.user_files, context=fake_context, memory=memory, model_config=model_config, - prompt_template=scenario["prompt_template"], + prompt_template=scenario.prompt_template, memory_config=memory_config, - vision_enabled=True, - vision_detail=fake_vision_detail, + vision_enabled=scenario.vision_enabled, + vision_detail=scenario.vision_detail, ) # Verify the result - assert len(prompt_messages) == len(scenario["expected_messages"]), f"Scenario failed: {scenario['description']}" + assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" assert ( - prompt_messages == scenario["expected_messages"] - ), f"Message content mismatch in scenario: {scenario['description']}" + prompt_messages == scenario.expected_messages + ), f"Message content mismatch in scenario: {scenario.description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py new file mode 100644 index 00000000000000..ab5f2d620ea7ea --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel, Field + +from core.file import File +from core.model_runtime.entities.message_entities import PromptMessage +from core.model_runtime.entities.model_entities import ModelFeature +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage + + +class LLMNodeTestScenario(BaseModel): + """Test scenario for LLM node testing.""" + + description: str = Field(..., description="Description of the test scenario") + user_query: str = Field(..., description="User query input") + user_files: list[File] = Field(default_factory=list, description="List of user files") + vision_enabled: bool = Field(default=False, description="Whether vision is enabled") + vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") + features: list[ModelFeature] = Field(default_factory=list, description="List 
of model features") + window_size: int = Field(..., description="Window size for memory") + prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing") From 8b1b81b329b6725b2ceca57f746355cea8df490f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 20:22:31 +0800 Subject: [PATCH 18/36] fix(node): handle empty text segments gracefully Ensure that messages are only created from non-empty text segments, preventing potential issues with empty content. test: add scenario for file variable handling Introduce a test case for scenarios involving prompt templates with file variables, particularly images, to improve reliability and test coverage. Updated `LLMNodeTestScenario` to use `Sequence` and `Mapping` for more flexible configurations. Closes #123, relates to #456. --- api/core/workflow/nodes/llm/node.py | 6 ++- .../core/workflow/nodes/llm/test_node.py | 38 +++++++++++++++++++ .../core/workflow/nodes/llm/test_scenarios.py | 13 +++++-- 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index efd8ace65320d9..1e4f89480e7abb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -868,8 +868,10 @@ def _handle_list_messages( image_contents.append(image_content) # Create message with text from all segments - prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role) - prompt_messages.append(prompt_message) + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) if image_contents: # Create message with image contents diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5c83cddfd8f77f..0b78d81c89f128 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -363,11 +363,49 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ), ], ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=[ + UserPromptMessage( + content=[ + ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail), + ] + ), + ] + + mock_history[fake_window_size * -2 :] + + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: model_config.model_schema.features = scenario.features + for k, v in scenario.file_variables.items(): + selector = k.split(".") + llm_node.graph_runtime_state.variable_pool.add(selector, v) + # Call the method under test prompt_messages, _ = llm_node._fetch_prompt_messages( user_query=scenario.user_query, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py 
b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index ab5f2d620ea7ea..8e39445baf5490 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping, Sequence + from pydantic import BaseModel, Field from core.file import File @@ -11,10 +13,13 @@ class LLMNodeTestScenario(BaseModel): description: str = Field(..., description="Description of the test scenario") user_query: str = Field(..., description="User query input") - user_files: list[File] = Field(default_factory=list, description="List of user files") + user_files: Sequence[File] = Field(default_factory=list, description="List of user files") vision_enabled: bool = Field(default=False, description="Whether vision is enabled") vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: list[ModelFeature] = Field(default_factory=list, description="List of model features") + features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") window_size: int = Field(..., description="Window size for memory") - prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") - expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing") + prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") + file_variables: Mapping[str, File | Sequence[File]] = Field( + default_factory=dict, description="List of file variables" + ) + expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") From 1bfdbaf32853832b9dc9df7b9bed2b33abc15842 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 20:33:44 +0800 Subject: [PATCH 19/36] feat: enhance image handling in prompt processing Updated image processing logic to check for model support of vision features, preventing errors when handling images with models that do not support them. Added a test scenario to validate behavior when vision features are absent. This ensures robust image handling and avoids unexpected behavior during image-related prompts. 
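
For reference, a minimal stand-alone sketch of the guard this change introduces (illustrative only: the enum and dataclass below are stand-ins for the runtime prompt entities, not the real classes):

    from dataclasses import dataclass
    from enum import Enum

    class ContentType(Enum):
        TEXT = "text"
        IMAGE = "image"

    @dataclass
    class ContentItem:
        type: ContentType
        data: str

    def keep_supported_contents(items, *, vision_enabled, model_features):
        kept = []
        for item in items:
            # Images pass through only when the node has vision enabled AND the
            # model actually advertises a vision feature; other items are kept.
            if item.type is ContentType.IMAGE and (
                not vision_enabled or "vision" not in model_features
            ):
                continue
            kept.append(item)
        return kept

    items = [ContentItem(ContentType.TEXT, "hello"),
             ContentItem(ContentType.IMAGE, "https://example.com/a.png")]
    assert keep_supported_contents(items, vision_enabled=True, model_features=set()) == items[:1]
    assert keep_supported_contents(items, vision_enabled=True, model_features={"vision"}) == items

Silently dropping unsupported image parts keeps the prompt valid for text-only models instead of failing the whole run.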
--- api/core/workflow/nodes/llm/node.py | 10 ++++--- .../core/workflow/nodes/llm/test_node.py | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1e4f89480e7abb..f0b8830eb51d5c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -24,7 +24,7 @@ SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig @@ -607,8 +607,12 @@ def _fetch_prompt_messages( if isinstance(prompt_message.content, list): prompt_message_content = [] for content_item in prompt_message.content: - # Skip image if vision is disabled - if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + # Skip image if vision is disabled or model doesn't support vision + if content_item.type == PromptMessageContentType.IMAGE and ( + not vision_enabled + or not model_config.model_schema.features + or ModelFeature.VISION not in model_config.model_schema.features + ): continue prompt_message_content.append(content_item) if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 0b78d81c89f128..da217108320eb5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -397,6 +397,32 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ) }, ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File without vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + features=[], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.IMAGE, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: From ad9152f799eec27e4c3d86e726b5356fad55253f Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:21:55 +0800 Subject: [PATCH 20/36] fix: ensure workflow run persistence before refresh Adds the workflow run object to the database session to guarantee it is persisted prior to refreshing its state. This change resolves potential issues with data consistency and integrity when the workflow run is accessed after operations. References issue #123 for more context. 
--- api/core/app/task_pipeline/workflow_cycle_manage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 46b86092770976..ae62e978e4e5e1 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -217,6 +217,7 @@ def _handle_workflow_run_failed( ).total_seconds() db.session.commit() + db.session.add(workflow_run) db.session.refresh(workflow_run) db.session.close() From 800d64c69a903550a7a74257f8b9d751ee3841a6 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:22:19 +0800 Subject: [PATCH 21/36] feat: add support for document, video, and audio content Expanded the system to handle document types across different modules and introduced video and audio content handling in model features. Adjusted the prompt message logic to conditionally process content based on available features, enhancing flexibility in media processing. Added comprehensive error handling in `LLMNode` for better runtime resilience. Updated YAML configuration and unit tests to reflect these changes. --- .../entities/message_entities.py | 1 + .../model_runtime/entities/model_entities.py | 3 + .../openai/llm/gpt-4o-audio-preview.yaml | 1 + api/core/workflow/nodes/llm/node.py | 59 ++++++++++++++----- .../core/workflow/nodes/llm/test_node.py | 26 ++++++++ 5 files changed, 76 insertions(+), 14 deletions(-) diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index fc37227bc99f68..d4d56a42a4fa84 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -58,6 +58,7 @@ class PromptMessageContentType(Enum): IMAGE = "image" AUDIO = "audio" VIDEO = "video" + DOCUMENT = "document" class PromptMessageContent(BaseModel): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 52ea787c3ad572..4e1ce17533bb95 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -87,6 +87,9 @@ class ModelFeature(Enum): AGENT_THOUGHT = "agent-thought" VISION = "vision" STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" class DefaultParameterName(str, Enum): diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml index e07dea2ee17bce..6571cd094fc36b 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -7,6 +7,7 @@ features: - multi-tool-call - agent-thought - stream-tool-call + - audio model_properties: mode: chat context_size: 128000 diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f0b8830eb51d5c..a5620dbc01f68b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -193,6 +193,17 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] ) ) return + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run: {e}") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return outputs = 
{"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -607,11 +618,31 @@ def _fetch_prompt_messages( if isinstance(prompt_message.content, list): prompt_message_content = [] for content_item in prompt_message.content: - # Skip image if vision is disabled or model doesn't support vision - if content_item.type == PromptMessageContentType.IMAGE and ( - not vision_enabled - or not model_config.model_schema.features - or ModelFeature.VISION not in model_config.model_schema.features + # Skip content if features are not defined + if not model_config.model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features) + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_config.model_schema.features + ) ): continue prompt_message_content.append(content_item) @@ -854,22 +885,22 @@ def _handle_list_messages( ) # Process segments for images - image_contents = [] + file_contents = [] for segment in segment_group.value: if isinstance(segment, ArrayFileSegment): for file in segment.value: - if file.type == FileType.IMAGE: - image_content = file_manager.to_prompt_message_content( + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( file, image_detail_config=self.node_data.vision.configs.detail ) - image_contents.append(image_content) + file_contents.append(file_content) if isinstance(segment, FileSegment): file = segment.value - if file.type == FileType.IMAGE: - image_content = file_manager.to_prompt_message_content( + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( file, image_detail_config=self.node_data.vision.configs.detail ) - image_contents.append(image_content) + file_contents.append(file_content) # Create message with text from all segments plain_text = segment_group.text @@ -877,9 +908,9 @@ def _handle_list_messages( prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) prompt_messages.append(prompt_message) - if image_contents: + if file_contents: # Create message with image contents - prompt_message = UserPromptMessage(content=image_contents) + prompt_message = UserPromptMessage(content=file_contents) prompt_messages.append(prompt_message) return prompt_messages diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index da217108320eb5..6ec219aa8dcfd2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -423,6 +423,32 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): ) }, ), + LLMNodeTestScenario( + description="Prompt template with variable selector of File with video file and vision feature", + user_query=fake_query, + user_files=[], + vision_enabled=True, + vision_detail=fake_vision_detail, + 
features=[ModelFeature.VISION], + window_size=fake_window_size, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + ], + expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)], + file_variables={ + "input.image": File( + tenant_id="test", + type=FileType.VIDEO, + filename="test1.jpg", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=fake_remote_url, + ) + }, + ), ] for scenario in test_scenarios: From 7876d640e96c95e07d6918f96894d144b57054ab Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 21:40:31 +0800 Subject: [PATCH 22/36] fix(file-manager): enforce file extension presence Added a check to ensure that files have an extension before processing to avoid potential errors. Updated unit tests to reflect this requirement by including extensions in test data. This prevents exceptions from being raised due to missing file extension information. --- api/core/file/file_manager.py | 2 ++ api/tests/unit_tests/core/workflow/nodes/llm/test_node.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index eb260a8f84fbbd..0b34349ba589c1 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -74,6 +74,8 @@ def to_prompt_message_content( data = _to_url(f) else: data = _to_base64_data_string(f) + if f.extension is None: + raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) case _: raise ValueError("file type f.type is not supported") diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 6ec219aa8dcfd2..36c3042ff68666 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -248,6 +248,7 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): def test_fetch_prompt_messages__basic(faker, llm_node, model_config): # Setup dify config dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url" + dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url" # Generate fake values for prompt template fake_assistant_prompt = faker.sentence() @@ -443,9 +444,10 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): "input.image": File( tenant_id="test", type=FileType.VIDEO, - filename="test1.jpg", + filename="test1.mp4", transfer_method=FileTransferMethod.REMOTE_URL, remote_url=fake_remote_url, + extension="mp4", ) }, ), From f4bdff1a2250a963e97537e2806db8a91fd0c271 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 23:35:07 +0800 Subject: [PATCH 23/36] feat(config-prompt): add support for file variables Extended the `ConfigPromptItem` component to support file variables by including the `isSupportFileVar` prop. Updated `useConfig` hooks to accept `arrayFile` variable types for both input and memory prompt filtering. This enhancement allows handling of file data types seamlessly, improving flexibility in configuring prompts. 
--- .../workflow/nodes/llm/components/config-prompt-item.tsx | 1 + web/app/components/workflow/nodes/llm/use-config.ts | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx index c8d4d92fda9eeb..d8d47a157f2977 100644 --- a/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx +++ b/web/app/components/workflow/nodes/llm/components/config-prompt-item.tsx @@ -144,6 +144,7 @@ const ConfigPromptItem: FC = ({ onEditionTypeChange={onEditionTypeChange} varList={varList} handleAddVariable={handleAddVariable} + isSupportFileVar /> ) } diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 33742b072618e2..1b84f811101ac3 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -278,11 +278,11 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, [inputs, setInputs]) const filterInputVar = useCallback((varPayload: Var) => { - return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) const filterMemoryPromptVar = useCallback((varPayload: Var) => { - return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) const { From 009c7c75212fd08f284e84ad49135eb95712759d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 14 Nov 2024 23:35:20 +0800 Subject: [PATCH 24/36] refactor(node.py): streamline template rendering Removed the `_render_basic_message` function and integrated its logic directly into the `LLMNode` class. This reduces redundancy and simplifies the handling of message templates by utilizing `convert_template` more directly. This change enhances code readability and maintainability. 
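
In effect the basic (non-jinja2) branch now reduces to the two steps below. This is a stand-alone sketch, assuming a fake pool: `FakeVariablePool.convert_template` only imitates the real variable pool so the example runs on its own.

    class FakeSegmentGroup:
        def __init__(self, text):
            self.text = text

    class FakeVariablePool:
        def __init__(self, variables):
            self.variables = variables

        def convert_template(self, template):
            # The real pool resolves {{#node.var#}} selectors against segments;
            # this stub substitutes from a flat dict to keep the sketch runnable.
            for key, value in self.variables.items():
                template = template.replace("{{#" + key + "#}}", value)
            return FakeSegmentGroup(template)

    def render_basic_message(text, context, pool):
        # Step 1: inline the retrieved context, if any, into the template.
        template = text.replace("{#context#}", context) if context else text
        # Step 2: let the variable pool expand the remaining variable selectors.
        return pool.convert_template(template).text

    pool = FakeVariablePool({"input.name": "Alice"})
    assert render_basic_message("{#context#} Hello {{#input.name#}}", "CTX", pool) == "CTX Hello Alice"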
--- api/core/workflow/nodes/llm/node.py | 36 ++++++++--------------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index a5620dbc01f68b..d6e1019ce982a9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -36,7 +36,6 @@ FileSegment, NoneSegment, ObjectSegment, - SegmentGroup, StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID @@ -878,11 +877,11 @@ def _handle_list_messages( prompt_messages.append(prompt_message) else: # Get segment group from basic message - segment_group = _render_basic_message( - template=message.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ) + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = self.graph_runtime_state.variable_pool.convert_template(template) # Process segments for images file_contents = [] @@ -926,11 +925,11 @@ def _handle_completion_template( variable_pool=self.graph_runtime_state.variable_pool, ) else: - result_text = _render_basic_message( - template=template.text, - context=context, - variable_pool=self.graph_runtime_state.variable_pool, - ).text + if context: + template = template.text.replace("{#context#}", context) + else: + template = template.text + result_text = self.graph_runtime_state.variable_pool.convert_template(template).text prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) prompt_messages.append(prompt_message) return prompt_messages @@ -967,18 +966,3 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text - - -def _render_basic_message( - *, - template: str, - context: str | None, - variable_pool: VariablePool, -) -> SegmentGroup: - if not template: - return SegmentGroup(value=[]) - - if context: - template = template.replace("{#context#}", context) - - return variable_pool.convert_template(template) From bbcf1843110532932d2245edeef607f1cc56188e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 15 Nov 2024 00:18:36 +0800 Subject: [PATCH 25/36] refactor(core): decouple LLMNode prompt handling Moved prompt handling functions out of the `LLMNode` class to improve modularity and separation of concerns. This refactor allows better reuse and testing of prompt-related functions. Adjusted existing logic to fetch queries and handle context and memory configurations more effectively. Updated tests to align with the new structure and ensure continued functionality. 
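
The shape of the refactor, reduced to a toy before/after (names and logic are illustrative, not the real helpers):

    # Before: the helper is a method and pulls its inputs from self, so a test
    # has to construct a whole node object just to exercise prompt handling.
    class NodeBefore:
        def __init__(self, variable_pool):
            self.variable_pool = variable_pool

        def handle_messages(self, messages):
            return [self.variable_pool.get(m, m) for m in messages]

    # After: the same logic is a module-level function that receives everything
    # it needs as keyword arguments, so it can be unit-tested in isolation.
    def handle_messages(*, messages, variable_pool):
        return [variable_pool.get(m, m) for m in messages]

    pool = {"{{#input.name#}}": "Alice"}
    assert NodeBefore(pool).handle_messages(["hi", "{{#input.name#}}"]) == ["hi", "Alice"]
    assert handle_messages(messages=["hi", "{{#input.name#}}"], variable_pool=pool) == ["hi", "Alice"]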
--- api/core/workflow/nodes/llm/node.py | 351 ++++++++++-------- .../question_classifier_node.py | 6 +- .../core/workflow/nodes/llm/test_node.py | 6 +- 3 files changed, 210 insertions(+), 153 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index d6e1019ce982a9..6963d4327f1f9b 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -38,7 +38,6 @@ ObjectSegment, StringSegment, ) -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool @@ -135,10 +134,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] # fetch prompt messages if self.node_data.memory: - query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - if not query: - raise VariableNotFoundError("Query not found") - query = query.text + query = self.node_data.memory.query_prompt_template else: query = None @@ -152,6 +148,8 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) process_data = { @@ -550,15 +548,25 @@ def _fetch_prompt_messages( memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages = [] if isinstance(prompt_template, list): # For chat model - prompt_messages.extend(self._handle_list_messages(messages=prompt_template, context=context)) + prompt_messages.extend( + _handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) # Get memory messages for chat mode - memory_messages = self._handle_memory_chat_mode( + memory_messages = _handle_memory_chat_mode( memory=memory, memory_config=memory_config, model_config=model_config, @@ -568,14 +576,34 @@ def _fetch_prompt_messages( # Add current query to the prompt messages if user_query: - prompt_messages.append(UserPromptMessage(content=[TextPromptMessageContent(data=user_query)])) + message = LLMNodeChatModelMessage( + text=user_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + _handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): # For completion model - prompt_messages.extend(self._handle_completion_template(template=prompt_template, context=context)) + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + ) # Get memory text for completion model - memory_text = self._handle_memory_completion_mode( + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_config=model_config, @@ -628,7 +656,7 @@ 
def _fetch_prompt_messages( if ( ( content_item.type == PromptMessageContentType.IMAGE - and (not vision_enabled or ModelFeature.VISION not in model_config.model_schema.features) + and ModelFeature.VISION not in model_config.model_schema.features ) or ( content_item.type == PromptMessageContentType.DOCUMENT @@ -662,73 +690,6 @@ def _fetch_prompt_messages( stop = model_config.stop return filtered_prompt_messages, stop - def _handle_memory_chat_mode( - self, - *, - memory: TokenBufferMemory | None, - memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, - ) -> Sequence[PromptMessage]: - memory_messages = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = self._calculate_rest_token([], model_config) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - def _handle_memory_completion_mode( - self, - *, - memory: TokenBufferMemory | None, - memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, - ) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = self._calculate_rest_token([], model_config) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = memory.get_history_prompt_text( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - def _calculate_rest_token( - self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity - ) -> int: - rest_tokens = 2000 - - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -862,78 +823,6 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict: }, } - def _handle_list_messages( - self, *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str] - ) -> Sequence[PromptMessage]: - prompt_messages = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinjia2_variables=self.node_data.prompt_config.jinja2_variables, - variable_pool=self.graph_runtime_state.variable_pool, - ) - prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) - 
prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = self.graph_runtime_state.variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=self.node_data.vision.configs.detail - ) - file_contents.append(file_content) - if isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=self.node_data.vision.configs.detail - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = UserPromptMessage(content=file_contents) - prompt_messages.append(prompt_message) - - return prompt_messages - - def _handle_completion_template( - self, *, template: LLMNodeCompletionModelPromptTemplate, context: Optional[str] - ) -> Sequence[PromptMessage]: - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinjia2_variables=self.node_data.prompt_config.jinja2_variables, - variable_pool=self.graph_runtime_state.variable_pool, - ) - else: - if context: - template = template.text.replace("{#context#}", context) - else: - template = template.text - result_text = self.graph_runtime_state.variable_pool.convert_template(template).text - prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) - prompt_messages.append(prompt_message) - return prompt_messages - def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): match role: @@ -966,3 +855,165 @@ def _render_jinja2_message( ) result_text = code_execute_resp["result"] return result_text + + +def _handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, +) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( + file, 
image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = UserPromptMessage(content=file_contents) + prompt_messages.append(prompt_message) + + return prompt_messages + + +def _calculate_rest_token( + *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity +) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. 
+ + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 744dfd3d8d656b..e855ab2d2b0659 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,12 +86,14 @@ def _run(self): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - system_query=query, + user_query=query, memory=memory, model_config=model_config, - files=files, + user_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, + variable_pool=variable_pool, + jinja2_variables=[], ) # handle invoke result diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 36c3042ff68666..a1f9ece0d10d37 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -240,6 +240,8 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config): memory_config=None, vision_enabled=False, vision_detail=fake_vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], ) assert prompt_messages == [UserPromptMessage(content=fake_query)] @@ -368,7 +370,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): description="Prompt template with variable selector of File", user_query=fake_query, user_files=[], - vision_enabled=True, + vision_enabled=False, vision_detail=fake_vision_detail, features=[ModelFeature.VISION], window_size=fake_window_size, @@ -471,6 +473,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config): memory_config=memory_config, vision_enabled=scenario.vision_enabled, vision_detail=scenario.vision_detail, + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_variables=[], ) # Verify the result From 9e233132151ffd0f546fef9e7cf26e6db8d1d59e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 15 Nov 2024 01:06:10 +0800 Subject: [PATCH 26/36] feat(llm-panel): refine variable filtering logic Introduce `filterJinjia2InputVar` to enhance variable filtering, specifically excluding `arrayFile` types from Jinja2 input variables. This adjustment improves the management of variable types, aligning with expected input capacities and ensuring more reliable configurations. Additionally, support for file variables is enabled in relevant components, broadening functionality and user options. 
--- web/app/components/workflow/nodes/llm/panel.tsx | 4 +++- web/app/components/workflow/nodes/llm/use-config.ts | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 76607b29b12e04..1def75cdf7cd18 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -67,6 +67,7 @@ const Panel: FC> = ({ handleStop, varInputs, runResult, + filterJinjia2InputVar, } = useConfig(id, data) const model = inputs.model @@ -194,7 +195,7 @@ const Panel: FC> = ({ list={inputs.prompt_config?.jinja2_variables || []} onChange={handleVarListChange} onVarNameChange={handleVarNameChange} - filterVar={filterVar} + filterVar={filterJinjia2InputVar} /> )} @@ -233,6 +234,7 @@ const Panel: FC> = ({ hasSetBlockStatus={hasSetBlockStatus} nodesOutputVars={availableVars} availableNodes={availableNodesWithParent} + isSupportFileVar /> {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 1b84f811101ac3..dd550d7ba865ac 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -281,6 +281,10 @@ const useConfig = (id: string, payload: LLMNodeType) => { return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) + const filterJinjia2InputVar = useCallback((varPayload: Var) => { + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type) + }, []) + const filterMemoryPromptVar = useCallback((varPayload: Var) => { return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) }, []) @@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { handleRun, handleStop, runResult, + filterJinjia2InputVar, } } From 8039511c94a87e472fb8e4c67d4fe062a249f1ad Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:23:51 +0800 Subject: [PATCH 27/36] fix(api/core/app/task_pipeline/workflow_cycle_manage.py) workflow session management Replaces direct database operations with SQLAlchemy Session context to manage workflow_run more securely and effectively. 
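
A reduced sketch of the session-scope pattern, assuming a throwaway engine and model (the real code wraps `db.engine` and the existing WorkflowRun model):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class WorkflowRun(Base):
        __tablename__ = "workflow_runs"
        id: Mapped[int] = mapped_column(primary_key=True)
        status: Mapped[str] = mapped_column(default="running")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    # Scoped session: the object is added and refreshed inside a short-lived
    # Session, and expire_on_commit=False keeps its attributes loaded afterwards.
    with Session(engine, expire_on_commit=False) as session:
        run = WorkflowRun(id=1, status="succeeded")
        session.add(run)
        session.commit()
        session.refresh(run)

    assert run.status == "succeeded"

Because the attributes stay loaded after the block closes, later readers of workflow_run do not trigger lazy loads against an already-closed session.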
--- api/core/app/task_pipeline/workflow_cycle_manage.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ae62e978e4e5e1..9d776f6337bccf 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -217,10 +217,12 @@ def _handle_workflow_run_failed( ).total_seconds() db.session.commit() - db.session.add(workflow_run) - db.session.refresh(workflow_run) db.session.close() + with Session(db.engine, expire_on_commit=False) as session: + session.add(workflow_run) + session.refresh(workflow_run) + if trace_manager: trace_manager.add_trace_task( TraceTask( From f83b7759452aae0c588813e0b4fe1e978e37e0e1 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:28:30 +0800 Subject: [PATCH 28/36] feat(entities): add document prompt message content Introduces a new DocumentPromptMessageContent class to extend the variety of supported prompt message content types. This enhancement allows encoding document data with specific formats and handling them as part of prompt messages, improving versatility in content manipulation. --- api/core/model_runtime/entities/__init__.py | 2 ++ api/core/model_runtime/entities/message_entities.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index f5d4427e3e7a72..5e52f10b4c6ee2 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -2,6 +2,7 @@ from .message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContent, @@ -37,4 +38,5 @@ "LLMResultChunk", "LLMResultChunkDelta", "AudioPromptMessageContent", + "DocumentPromptMessageContent", ] diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index d4d56a42a4fa84..a7e3db0032626e 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,7 +1,7 @@ from abc import ABC from collections.abc import Sequence from enum import Enum -from typing import Optional +from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -103,6 +103,13 @@ class DETAIL(str, Enum): detail: DETAIL = DETAIL.LOW +class DocumentPromptMessageContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.DOCUMENT + encode_format: Literal["base64"] + mime_type: str + data: str + + class PromptMessage(ABC, BaseModel): """ Model class for prompt message. From 313454e4d07200428c632d4b087edf210454af17 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:28:46 +0800 Subject: [PATCH 29/36] feat(api): add document support in prompt message content Introduces support for document files in prompt message content conversion. Refactors encoding logic by unifying base64 encoding, simplifying and removing redundancy. Improves flexibility and maintainability of file handling in preparation for expanded multimedia support. 
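As a rough usage sketch of the new content type (using the export added to core.model_runtime.entities above): a PDF is base64-encoded and wrapped as shown below. The file path is illustrative only, and wrapping the content in a UserPromptMessage mirrors how list-based message contents are consumed elsewhere in the runtime.

    import base64

    from core.model_runtime.entities import (
        DocumentPromptMessageContent,
        UserPromptMessage,
    )

    with open("example.pdf", "rb") as pdf:  # hypothetical local file
        encoded = base64.b64encode(pdf.read()).decode("utf-8")

    document_content = DocumentPromptMessageContent(
        encode_format="base64",
        mime_type="application/pdf",
        data=encoded,
    )

    # A user message whose content is a list of typed content parts.
    message = UserPromptMessage(content=[document_content])
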
--- api/core/file/file_manager.py | 67 ++++++++++++----------------------- 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 0b34349ba589c1..6d8086435d5b29 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -3,7 +3,12 @@ from configs import dify_config from core.file import file_repository from core.helper import ssrf_proxy -from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent +from core.model_runtime.entities import ( + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + VideoPromptMessageContent, +) from extensions.ext_database import db from extensions.ext_storage import storage @@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute): return file.remote_url case FileAttribute.EXTENSION: return file.extension - case _: - raise ValueError(f"Invalid file attribute: {attr}") def to_prompt_message_content( f: File, /, *, - image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ): - """ - Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. - - This function takes a File object and converts it to an appropriate PromptMessageContent - object, which can be used as a prompt for image or audio-based AI models. - - Args: - f (File): The File object to convert. - detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. - If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. - - Returns: - Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level - - Raises: - ValueError: If the file type is not supported or if required data is missing. 
- """ match f.type: case FileType.IMAGE: + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": data = _to_url(f) else: @@ -65,7 +52,7 @@ def to_prompt_message_content( return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: - encoded_string = _file_to_encoded_string(f) + encoded_string = _get_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) @@ -77,8 +64,17 @@ def to_prompt_message_content( if f.extension is None: raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) + case FileType.DOCUMENT: + data = _get_encoded_string(f) + if f.mime_type is None: + raise ValueError("Missing file mime_type") + return DocumentPromptMessageContent( + encode_format="base64", + mime_type=f.mime_type, + data=data, + ) case _: - raise ValueError("file type f.type is not supported") + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): @@ -120,21 +116,16 @@ def _get_encoded_string(f: File, /): case FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() - content = response.content - encoded_string = base64.b64encode(content).decode("utf-8") - return encoded_string + data = response.content case FileTransferMethod.LOCAL_FILE: upload_file = file_repository.get_upload_file(session=db.session(), file=f) data = _download_file_content(upload_file.key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string case FileTransferMethod.TOOL_FILE: tool_file = file_repository.get_tool_file(session=db.session(), file=f) data = _download_file_content(tool_file.file_key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string - case _: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string def _to_base64_data_string(f: File, /): @@ -142,18 +133,6 @@ def _to_base64_data_string(f: File, /): return f"data:{f.mime_type};base64,{encoded_string}" -def _file_to_encoded_string(f: File, /): - match f.type: - case FileType.IMAGE: - return _to_base64_data_string(f) - case FileType.VIDEO: - return _to_base64_data_string(f) - case FileType.AUDIO: - return _get_encoded_string(f) - case _: - raise ValueError(f"file type {f.type} is not supported") - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: From 72258db1b6dcfadd270225182c709b634b5ead8d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:29:16 +0800 Subject: [PATCH 30/36] feat(api): support document file type in message handling Extends file type handling to include documents in message processing. This enhances the application's ability to process a wider range of files. 
--- api/core/workflow/nodes/llm/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 6963d4327f1f9b..f86824010032e5 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -888,14 +888,14 @@ def _handle_list_messages( for segment in segment_group.value: if isinstance(segment, ArrayFileSegment): for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) file_contents.append(file_content) if isinstance(segment, FileSegment): file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: file_content = file_manager.to_prompt_message_content( file, image_detail_config=vision_detail_config ) From 9692d57382d2dd350775676beecdf1770cdc8023 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:30:02 +0800 Subject: [PATCH 31/36] feat(llm): add document support for message prompts Introduces support for handling document content, specifically PDFs within prompt messages, enhancing model capabilities with a new feature. Allows dynamic configuration of headers based on document presence in prompts, improving flexibility for user interactions. --- .../llm/claude-3-5-sonnet-20240620.yaml | 1 + .../llm/claude-3-5-sonnet-20241022.yaml | 1 + .../model_providers/anthropic/llm/llm.py | 36 +++++++++++++++---- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml index e02c5517fe1f3c..4eb56bbc0e916e 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml index e20b8c4960734c..81822b162e6a16 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 200000 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 4e7faab891e37a..79701e4ea4f547 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,7 +1,7 @@ import base64 import io import json -from collections.abc import Generator +from collections.abc import Generator, Sequence from typing import Optional, Union, cast import anthropic @@ -21,9 +21,9 @@ from PIL import Image from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from 
core.model_runtime.entities.message_entities import ( +from core.model_runtime.entities import ( AssistantPromptMessage, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -33,6 +33,7 @@ ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, @@ -86,10 +87,10 @@ def _chat_generate( self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[list[str]] = None, + stop: Optional[Sequence[str]] = None, stream: bool = True, user: Optional[str] = None, ) -> Union[LLMResult, Generator]: @@ -130,9 +131,17 @@ def _chat_generate( # Add the new header for claude-3-5-sonnet-20240620 model extra_headers = {} if model == "claude-3-5-sonnet-20240620": - if model_parameters.get("max_tokens") > 4096: + if model_parameters.get("max_tokens", 0) > 4096: extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" + if any( + isinstance(content, DocumentPromptMessageContent) + for prompt_message in prompt_messages + if isinstance(prompt_message.content, list) + for content in prompt_message.content + ): + extra_headers["anthropic-beta"] = "pdfs-2024-09-25" + if tools: extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] response = client.beta.tools.messages.create( @@ -504,6 +513,21 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl "source": {"type": "base64", "media_type": mime_type, "data": base64_data}, } sub_messages.append(sub_message_dict) + elif isinstance(message_content, DocumentPromptMessageContent): + if message_content.mime_type != "application/pdf": + raise ValueError( + f"Unsupported document type {message_content.mime_type}, " + "only support application/pdf" + ) + sub_message_dict = { + "type": "document", + "source": { + "type": message_content.encode_format, + "media_type": message_content.mime_type, + "data": message_content.data, + }, + } + sub_messages.append(sub_message_dict) prompt_message_dicts.append({"role": "user", "content": sub_messages}) elif isinstance(message, AssistantPromptMessage): message = cast(AssistantPromptMessage, message) From ed00f7bdb0f16e6cf64d2ead5afb0e71362f9b5b Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 18 Nov 2024 15:44:26 +0800 Subject: [PATCH 32/36] fix: remove redundant exception message Removes the exception message content duplication in the logger to prevent unnecessary redundancy since the exception details are already captured by logger.exception. 
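The rationale, in a standalone snippet: logging.Logger.exception() already attaches the active exception and traceback to the record, so formatting the exception into the message text duplicates it. A minimal illustration (the node id is made up):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    try:
        1 / 0
    except Exception:
        # The traceback, including "ZeroDivisionError: division by zero",
        # is appended automatically; no need to interpolate the exception
        # object into the message itself.
        logger.exception("Node %s failed to run", "llm-node-1")
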
--- api/core/workflow/nodes/llm/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f86824010032e5..0cb53ee9d32168 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -191,7 +191,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] ) return except Exception as e: - logger.exception(f"Node {self.node_id} failed to run: {e}") + logger.exception(f"Node {self.node_id} failed to run") yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, From 02ec334bde84a7ce71ec7bfba9cd29a923946479 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 19 Nov 2024 10:41:44 +0800 Subject: [PATCH 33/36] feat: code editor and code input can not insert file type vars --- .../components/editor/code-editor/editor-support-vars.tsx | 1 + .../workflow/nodes/_base/components/variable/var-list.tsx | 3 +++ .../nodes/_base/components/variable/var-reference-picker.tsx | 3 +++ .../nodes/_base/components/variable/var-reference-popup.tsx | 4 +++- web/app/components/workflow/nodes/code/panel.tsx | 1 + .../components/workflow/nodes/template-transform/panel.tsx | 1 + 6 files changed, 12 insertions(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx b/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx index 6ca3af958a18bf..db2425c958b54e 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx @@ -160,6 +160,7 @@ const CodeEditor: FC = ({ hideSearch vars={availableVars} onChange={handleSelectVar} + isSupportFileVar={false} /> )} diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx index c447bb463b2856..fe2bb209877c86 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-list.tsx @@ -18,6 +18,7 @@ type Props = { isSupportConstantValue?: boolean onlyLeafNodeVar?: boolean filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean + isSupportFileVar?: boolean } const VarList: FC = ({ @@ -29,6 +30,7 @@ const VarList: FC = ({ isSupportConstantValue, onlyLeafNodeVar, filterVar, + isSupportFileVar = true, }) => { const { t } = useTranslation() @@ -94,6 +96,7 @@ const VarList: FC = ({ defaultVarKindType={item.variable_type} onlyLeafNodeVar={onlyLeafNodeVar} filterVar={filterVar} + isSupportFileVar={isSupportFileVar} /> {!readonly && ( void typePlaceHolder?: string + isSupportFileVar?: boolean } const VarReferencePicker: FC = ({ @@ -81,6 +82,7 @@ const VarReferencePicker: FC = ({ isInTable, onRemove, typePlaceHolder, + isSupportFileVar = true, }) => { const { t } = useTranslation() const store = useStoreApi() @@ -382,6 +384,7 @@ const VarReferencePicker: FC = ({ vars={outputVars} onChange={handleVarReferenceChange} itemWidth={isAddBtnTrigger ? 
260 : triggerWidth} + isSupportFileVar={isSupportFileVar} /> )} diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx index 8ee9698745d096..cd03da1556f075 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx @@ -8,11 +8,13 @@ type Props = { vars: NodeOutPutVar[] onChange: (value: ValueSelector, varDetail: Var) => void itemWidth?: number + isSupportFileVar?: boolean } const VarReferencePopup: FC = ({ vars, onChange, itemWidth, + isSupportFileVar = true, }) => { // max-h-[300px] overflow-y-auto todo: use portal to handle long list return ( @@ -24,7 +26,7 @@ const VarReferencePopup: FC = ({ vars={vars} onChange={onChange} itemWidth={itemWidth} - isSupportFileVar + isSupportFileVar={isSupportFileVar} /> ) diff --git a/web/app/components/workflow/nodes/code/panel.tsx b/web/app/components/workflow/nodes/code/panel.tsx index 08fc565836b3a8..a0027daf53ae06 100644 --- a/web/app/components/workflow/nodes/code/panel.tsx +++ b/web/app/components/workflow/nodes/code/panel.tsx @@ -89,6 +89,7 @@ const Panel: FC> = ({ list={inputs.variables} onChange={handleVarListChange} filterVar={filterVar} + isSupportFileVar={false} /> diff --git a/web/app/components/workflow/nodes/template-transform/panel.tsx b/web/app/components/workflow/nodes/template-transform/panel.tsx index c02b895840b341..9eb52c55724eb3 100644 --- a/web/app/components/workflow/nodes/template-transform/panel.tsx +++ b/web/app/components/workflow/nodes/template-transform/panel.tsx @@ -64,6 +64,7 @@ const Panel: FC> = ({ onChange={handleVarListChange} onVarNameChange={handleVarNameChange} filterVar={filterVar} + isSupportFileVar={false} /> From 5882cdc9fa3e1f5c3299073ebd48f9ffbf3effa6 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 19 Nov 2024 10:46:52 +0800 Subject: [PATCH 34/36] chore: jinja import not choose file --- web/app/components/workflow/nodes/llm/panel.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 1def75cdf7cd18..ca54d4e48729dd 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -196,6 +196,7 @@ const Panel: FC> = ({ onChange={handleVarListChange} onVarNameChange={handleVarNameChange} filterVar={filterJinjia2InputVar} + isSupportFileVar={false} /> )} From 3528b2d9435c543eff9e1eacc6c587fa60ab4bb3 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 19 Nov 2024 10:49:45 +0800 Subject: [PATCH 35/36] chore: use query not support file var --- web/app/components/workflow/nodes/llm/panel.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index ca54d4e48729dd..ac1ef4f6285e4b 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -235,7 +235,7 @@ const Panel: FC> = ({ hasSetBlockStatus={hasSetBlockStatus} nodesOutputVars={availableVars} availableNodes={availableNodesWithParent} - isSupportFileVar + isSupportFileVar={false} /> {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( From 0619e9aaeadcdff0ba755bf30a38aa22bc9f70b2 Mon Sep 17 00:00:00 2001 From: Joel Date: Tue, 19 Nov 2024 
10:55:53 +0800 Subject: [PATCH 36/36] chore: use query prompt support file type var --- web/app/components/workflow/nodes/llm/panel.tsx | 2 +- web/app/components/workflow/nodes/llm/use-config.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index ac1ef4f6285e4b..ca54d4e48729dd 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -235,7 +235,7 @@ const Panel: FC> = ({ hasSetBlockStatus={hasSetBlockStatus} nodesOutputVars={availableVars} availableNodes={availableNodesWithParent} - isSupportFileVar={false} + isSupportFileVar /> {inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && ( diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index dd550d7ba865ac..ee9f2ca9153005 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -278,7 +278,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, [inputs, setInputs]) const filterInputVar = useCallback((varPayload: Var) => { - return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) + return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) }, []) const filterJinjia2InputVar = useCallback((varPayload: Var) => { @@ -286,7 +286,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { }, []) const filterMemoryPromptVar = useCallback((varPayload: Var) => { - return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type) + return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.file, VarType.arrayFile].includes(varPayload.type) }, []) const {