Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions api/core/workflow/nodes/node_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict

from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
Expand All @@ -37,6 +42,7 @@ def __init__(
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
Expand All @@ -54,6 +60,7 @@ def __init__(
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()

@override
def create_node(self, node_config: dict[str, object]) -> Node:
Expand Down Expand Up @@ -107,6 +114,15 @@ def create_node(self, node_config: dict[str, object]) -> Node:
code_limits=self._code_limits,
)

if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)

return node_class(
id=node_id,
config=node_config,
Expand Down
40 changes: 40 additions & 0 deletions api/core/workflow/nodes/template_transform/template_renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Protocol

from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage


class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""


class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""

def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError


class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""

_code_executor: type[CodeExecutor]

def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor

def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc

rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)

if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState

MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH


class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer

def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()

@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
Expand All @@ -39,21 +65,19 @@ def _run(self) -> NodeRunResult:
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))

if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
)

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

from core.helper.code_executor.code_executor import CodeExecutionError
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus
from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowType

Expand Down Expand Up @@ -127,7 +127,9 @@ def test_version(self):
"""Test version class method."""
assert TemplateTransformNode.version() == "1"

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
Expand All @@ -145,7 +147,7 @@ def test_run_simple_template(
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))

# Setup mock executor
mock_execute.return_value = {"result": "Hello Alice, you are 30 years old!"}
mock_execute.return_value = "Hello Alice, you are 30 years old!"

node = TemplateTransformNode(
id="test_node",
Expand All @@ -162,7 +164,9 @@ def test_run_simple_template(
assert result.inputs["name"] == "Alice"
assert result.inputs["age"] == 30

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
node_data = {
Expand All @@ -172,7 +176,7 @@ def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime
}

mock_graph_runtime_state.variable_pool.get.return_value = None
mock_execute.return_value = {"result": "Value: "}
mock_execute.return_value = "Value: "

node = TemplateTransformNode(
id="test_node",
Expand All @@ -187,13 +191,15 @@ def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.inputs["value"] is None

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when code execution fails."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.side_effect = CodeExecutionError("Template syntax error")
mock_execute.side_effect = TemplateRenderError("Template syntax error")

node = TemplateTransformNode(
id="test_node",
Expand All @@ -208,14 +214,16 @@ def test_run_with_code_execution_error(
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Template syntax error" in result.error

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
@patch("core.workflow.nodes.template_transform.template_transform_node.MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH", 10)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
):
"""Test _run when output exceeds maximum length."""
mock_graph_runtime_state.variable_pool.get.return_value = MagicMock()
mock_execute.return_value = {"result": "This is a very long output that exceeds the limit"}
mock_execute.return_value = "This is a very long output that exceeds the limit"

node = TemplateTransformNode(
id="test_node",
Expand All @@ -230,7 +238,9 @@ def test_run_output_length_exceeds_limit(
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Output length exceeds" in result.error

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
):
Expand All @@ -257,7 +267,7 @@ def test_run_with_complex_jinja2_template(
("sys", "show_total"): mock_show_total,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "apple, banana, orange (Total: 3)"}
mock_execute.return_value = "apple, banana, orange (Total: 3)"

node = TemplateTransformNode(
id="test_node",
Expand Down Expand Up @@ -292,7 +302,9 @@ def test_extract_variable_selector_to_variable_mapping(self):
assert mapping["node_123.var1"] == ["sys", "input1"]
assert mapping["node_123.var2"] == ["sys", "input2"]

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
node_data = {
Expand All @@ -301,7 +313,7 @@ def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_run
"template": "This is a static message.",
}

mock_execute.return_value = {"result": "This is a static message."}
mock_execute.return_value = "This is a static message."

node = TemplateTransformNode(
id="test_node",
Expand All @@ -317,7 +329,9 @@ def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_run
assert result.outputs["output"] == "This is a static message."
assert result.inputs == {}

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
node_data = {
Expand All @@ -339,7 +353,7 @@ def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runt
("sys", "quantity"): mock_quantity,
}
mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector))
mock_execute.return_value = {"result": "Total: $31.5"}
mock_execute.return_value = "Total: $31.5"

node = TemplateTransformNode(
id="test_node",
Expand All @@ -354,7 +368,9 @@ def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runt
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["output"] == "Total: $31.5"

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
node_data = {
Expand All @@ -367,7 +383,7 @@ def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime
mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"}

mock_graph_runtime_state.variable_pool.get.return_value = mock_user
mock_execute.return_value = {"result": "Name: John Doe, Email: john@example.com"}
mock_execute.return_value = "Name: John Doe, Email: john@example.com"

node = TemplateTransformNode(
id="test_node",
Expand All @@ -383,7 +399,9 @@ def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime
assert "John Doe" in result.outputs["output"]
assert "john@example.com" in result.outputs["output"]

@patch("core.workflow.nodes.template_transform.template_transform_node.CodeExecutor.execute_workflow_code_template")
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""
node_data = {
Expand All @@ -396,7 +414,7 @@ def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime
mock_tags.to_object.return_value = ["python", "ai", "workflow"]

mock_graph_runtime_state.variable_pool.get.return_value = mock_tags
mock_execute.return_value = {"result": "Tags: #python #ai #workflow "}
mock_execute.return_value = "Tags: #python #ai #workflow "

node = TemplateTransformNode(
id="test_node",
Expand Down
Loading