15 changes: 12 additions & 3 deletions task-sdk/src/airflow/sdk/definitions/dag.py
@@ -31,6 +31,7 @@
from inspect import signature
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast, overload
from urllib.parse import urlsplit
from uuid import UUID

import attrs
import jinja2
@@ -1389,13 +1390,13 @@ def _run_task(
# it is run.
ti.set_state(TaskInstanceState.QUEUED)
task_sdk_ti = TaskInstanceSDK(
id=ti.id,
id=UUID(str(ti.id)),
task_id=ti.task_id,
dag_id=ti.dag_id,
run_id=ti.run_id,
try_number=ti.try_number,
map_index=ti.map_index,
dag_version_id=ti.dag_version_id,
dag_version_id=UUID(str(ti.dag_version_id)),
)

taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task)
@@ -1414,7 +1415,15 @@ def _run_task(
trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
event = _run_inline_trigger(trigger, task_sdk_ti)
ti.next_method = msg.next_method
ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs

# Deserialize next_kwargs if it's a string (encrypted dict), similar to what the API server does
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if isinstance(next_kwargs_value, str):
from airflow.serialization.serialized_objects import BaseSerialization

ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
else:
ti.next_kwargs = next_kwargs_value
Comment on lines +1420 to +1426
Member

Can you explain why this is needed? It feels like a behaviour change, not just a typing change.

Collaborator Author
@Dev-iL Nov 14, 2025

The reason for the change is the following mypy violation (in the presence of SQLA2):

task-sdk/src/airflow/sdk/definitions/dag.py:1417: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None")  [assignment]
                    ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The explanation of the change is here (see "The Issue" and "The Correct Solution" sections).
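For illustration, a minimal sketch of the same complaint using stand-in types (the real attribute is a SQLAlchemy-mapped column, which is why SQLCoreOperations appears in the message above; TI here is hypothetical, not the actual model):

# Hypothetical stand-in, not the actual Airflow TaskInstance model
from typing import Any


class TI:
    next_kwargs: dict[Any, Any] | None = None


def apply(ti: TI, incoming: dict[str, Any] | str | None) -> None:
    # mypy: Incompatible types in assignment -- the `str` member of the
    # incoming union is not assignable to `dict[Any, Any] | None`.
    ti.next_kwargs = incoming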

Member

What error was this showing without this change? This code appears to have been working fine, as triggers are functional (and I'm reasonably sure they are currently being encrypted as strings), so I'm not clear why this is needed.

Member

That's because the type on the model is wrong then: SQLCoreOperations[dict[Any, Any] | None] should be dict[str, Any], not dict[Any, Any].
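For reference, a minimal sketch of what a tightened annotation could look like under SQLAlchemy 2.0 Mapped typing (class, table, and column type below are placeholders, not Airflow's actual model definition):

from typing import Any

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class TaskInstanceModel(Base):  # placeholder name, not the real Airflow model
    __tablename__ = "task_instance_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Declaring the Python-side type as dict[str, Any] | None (rather than
    # dict[Any, Any] | None) lets mypy accept assignments of dict[str, Any].
    next_kwargs: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)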

Collaborator Author

mypy has an issue with the | str | case of

# airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#331
    next_kwargs: dict[str, Any] | str | None = None
    """
    Args to pass to ``next_method``.

    Can either be a "decorated" dict, or a string encrypted with the shared Fernet key.
    """

not the dict. TP remarked that perhaps this should never be a dict at all.


I let the LLM analyze it some more and it came up with this:

## The Issue
The original mypy error occurred because `ti.next_kwargs` expects `dict | None`, but `msg.next_kwargs` from `DeferTask` can be `dict | str | None` (where the `str` represents an encrypted value).

## The Solution
Update the deserialization logic in `dag.py` to properly handle all cases:

- None → assign `None` directly
- Serialized dict (with `__type` and `__var` keys) → call `BaseSerialization.deserialize()`
- Regular dict → use as-is
- Encrypted string → log warning and assign `None` (encryption not supported in `dag.test()` context)

## Key Discovery
`BaseSerialization.serialize()` returns a dict with special keys (not a JSON string), and `deserialize()` expects this dict format. The "encrypted string" case mentioned in comments refers to when Fernet encryption is enabled, which doesn't apply to `dag.test()`.
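For reference, a quick round-trip sketch of the format described above (assuming only that BaseSerialization is importable as shown; the payload values are arbitrary):

# serialize() yields a plain dict in the __type/__var envelope, not a JSON string;
# deserialize() restores the original value from that envelope.
from airflow.serialization.serialized_objects import BaseSerialization

payload = {"custom_param": "value", "count": 42}

serialized = BaseSerialization.serialize(payload)
# e.g. {"__type": "dict", "__var": {"custom_param": "value", "count": 42}}
assert isinstance(serialized, dict)

restored = BaseSerialization.deserialize(serialized)
assert restored == payload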
So now it thinks the code should instead look like this:

# Deserialize next_kwargs if needed, matching what the API server does
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if next_kwargs_value is None:
    ti.next_kwargs = None
elif isinstance(next_kwargs_value, dict):
    if set(next_kwargs_value.keys()) == {"__type", "__var"}:
        # Serialized format - deserialize it
        from airflow.serialization.serialized_objects import BaseSerialization

        ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
    else:
        # Regular dict - use as-is
        ti.next_kwargs = next_kwargs_value
else:
    # String (encrypted) - in dag.test() context, encryption is not used,
    # but we need to handle this for type checking
    # The API server would decrypt this before calling deserialize
    if not isinstance(next_kwargs_value, str):
        raise TypeError(f"Unexpected type for next_kwargs: {type(next_kwargs_value)}")
    # For now, we can't decrypt without the Fernet key, so log a warning
    log.warning(
        "[DAG TEST] Received encrypted next_kwargs string, cannot decrypt in dag.test() context"
    )
    # Type-ignore needed because we can't decrypt in dag.test()
    ti.next_kwargs = None  # type: ignore[assignment]

...with a couple of accompanying unit tests that illustrate the different behavior of the existing and proposed implementations:

class TestDeferredTaskNextKwargs:
    """Test that next_kwargs is properly deserialized when it's an encrypted string."""

    @pytest.fixture
    def mock_task_instance(self, mocker):
        """Create a mock scheduler TaskInstance."""
        from airflow.sdk import TaskInstanceState

        ti = mocker.MagicMock()
        ti.task_id = "test_task"
        ti.dag_id = "test_dag"
        ti.run_id = "test_run"
        ti.map_index = -1
        ti.id = "123e4567-e89b-12d3-a456-426614174000"
        ti.dag_version_id = "223e4567-e89b-12d3-a456-426614174000"
        ti.try_number = 1
        ti.state = TaskInstanceState.DEFERRED
        return ti

    @pytest.fixture
    def mock_task(self, mocker):
        """Create a mock task."""

        task = mocker.MagicMock()
        task.task_id = "test_task"
        return task

    def test_next_kwargs_deserialized_when_encrypted_string(
        self, mock_task_instance, mock_task, monkeypatch, mocker,
    ):
        """
        Test that when msg.next_kwargs is a serialized string, it gets properly deserialized to a dict.

        This simulates the case where DeferTask contains an encrypted/serialized next_kwargs.
        The bug we're testing: without the fix, string next_kwargs would be assigned directly to ti.next_kwargs,
        causing a type error. With the fix, it should be deserialized to a dict.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask
        from airflow.serialization.serialized_objects import BaseSerialization

        # Setup: Create a dict that will be serialized (simulating real behavior)
        original_kwargs = {"custom_param": "value", "count": 42}
        # Serialize it to get the format with __type and __var keys
        serialized_kwargs = BaseSerialization.serialize(original_kwargs)

        # serialized_kwargs should now be a dict like:
        # {'__type': 'dict', '__var': {'custom_param': 'value', 'count': 42}}

        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs=serialized_kwargs,
        )

        # Create a mock TaskRunResult with the DeferTask message
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock the run_task_in_process to return our prepared result
        mock_run_task = mocker.MagicMock(return_value=mock_task_run_result)
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mock_run_task,
            raising=False,
        )

        # Mock _run_inline_trigger to return no event (so msg.next_kwargs is used)
        mock_inline_trigger = mocker.MagicMock(return_value=None)
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mock_inline_trigger,
            raising=False,
        )

        # Mock create_scheduler_operator
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )

        # Mock import_string to return a mock trigger
        mock_trigger = mocker.MagicMock()
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mock_trigger),
            raising=False,
        )

        # Mock create_session
        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Track assignments to next_kwargs by capturing __setattr__ on the mock's class
        assigned_values = []
        original_setattr = type(mock_task_instance).__setattr__

        def capturing_setattr(obj, name, value):
            if name == 'next_kwargs':
                assigned_values.append(value)
                # Store it in the mock's internal state
                object.__setattr__(obj, '_next_kwargs_value', value)
            else:
                original_setattr(obj, name, value)

        type(mock_task_instance).__setattr__ = capturing_setattr

        # Execute _run_task with run_triggerer=True to trigger the deferred path
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Restore original behavior
        type(mock_task_instance).__setattr__ = original_setattr

        # Verify: The key behavior is that next_kwargs is now a dict, not a string
        # This is what the fix ensures - regardless of HOW it's done
        assert len(assigned_values) > 0, "next_kwargs should have been assigned at least once"
        final_value = assigned_values[-1]  # Get the last assigned value

        assert isinstance(final_value, dict), (
            f"next_kwargs should be a dict after deserialization, not {type(final_value).__name__}"
        )
        # Verify the dict contains the expected keys (proves it was properly deserialized)
        assert "custom_param" in final_value
        assert "count" in final_value

    def test_next_kwargs_with_trigger_event(self, mock_task_instance, mock_task, monkeypatch, mocker):
        """
        Test that when a trigger returns an event, event.payload is used for next_kwargs.

        This verifies that the event path takes precedence over msg.next_kwargs.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask

        # Create a mock event with payload
        mock_event = mocker.MagicMock()
        mock_event.payload = {"event_data": "from_trigger", "timestamp": "2024-01-01"}

        # Create mock DeferTask with a string next_kwargs that should be ignored
        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs='{"should": "be_ignored"}',  # Should be ignored since we have an event
        )

        # Create mock TaskRunResult
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock dependencies
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mocker.MagicMock(return_value=mock_task_run_result),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mocker.MagicMock(return_value=mock_event),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mocker.MagicMock()),
            raising=False,
        )

        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Execute
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Verify: next_kwargs should be set to {"event": event.payload}, not msg.next_kwargs
        assert mock_task_instance.next_kwargs == {"event": mock_event.payload}
        assert isinstance(mock_task_instance.next_kwargs, dict)
        # Ensure msg.next_kwargs was not used
        assert mock_task_instance.next_kwargs.get("should") != "be_ignored"


I'll revert my change to this bit since I don't see how to solve this in an acceptable way.

log.info("[DAG TEST] Trigger completed")

# Set the state to SCHEDULED so that the task can be resumed.