Skip to content

Conversation

@Dev-iL
Copy link
Collaborator

@Dev-iL Dev-iL commented Nov 6, 2025

related: #56212, #58049

This PR fixes the below mypy violations that appear in task-sdk after upgrading to SQLA2:

task-sdk/src/airflow/sdk/definitions/dag.py:1392: error: Argument "id" to "TaskInstance" has incompatible type "str"; expected "UUID"  [arg-type]
                    id=ti.id,
                       ^~~~~

task-sdk/src/airflow/sdk/definitions/dag.py:1398: error: Argument "dag_version_id" to "TaskInstance" has incompatible type "str | UUID | None"; expected "UUID"  [arg-type]
                    dag_version_id=ti.dag_version_id,
                                   ^~~~~~~~~~~~~~~~~

task-sdk/src/airflow/sdk/definitions/dag.py:1417: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None")  [assignment]
                    ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in airflow-core/newsfragments.

@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch 2 times, most recently from 9ab5dca to 826b7cc Compare November 6, 2025 10:17
@Dev-iL
Copy link
Collaborator Author

Dev-iL commented Nov 6, 2025

CC: @vincbeck @uranusjr

@ashb
Copy link
Member

ashb commented Nov 10, 2025

@Dev-iL This PR title doesn't make anysense to me though --- the TaskSDK has nothing to do with SQLa2. It (definitionally) can't load things from the DB...

@Dev-iL
Copy link
Collaborator Author

Dev-iL commented Nov 10, 2025

@ashb The context is removing the SQLA 1 limit in the FAB provider. These are fixes for the mypy-task-sdk job observed in #56212. What other title you think would be more suitable?

Copy link
Member

@ashb ashb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many changes for just "mypy violations" to my mind. Introducing serialization when there was none before is well beyond that.

@ashb
Copy link
Member

ashb commented Nov 10, 2025

@ashb The context is removing the SQLA 1 limit in the FAB provider. These are fixes for the mypy-task-sdk job observed in #56212. What other title you think would be more suitable?

What does SQLA2 have to do with task sdk at all? We don't use it in the task sdk. That is why I'm confused

@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch from 826b7cc to 23d2ed4 Compare November 10, 2025 15:53
@Dev-iL
Copy link
Collaborator Author

Dev-iL commented Nov 10, 2025

What does SQLA2 have to do with task sdk at all? We don't use it in the task sdk. That is why I'm confused

To my understanding, mypy behaves differently in the presence of SQLA2 (https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html), and this happens to affect some type hints in task-sdk.

Regardless, I simplified this PR greatly - please see if it's acceptable now.

@Dev-iL Dev-iL requested a review from ashb November 10, 2025 16:01
@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch 2 times, most recently from 8b8f1db to 083ef8f Compare November 10, 2025 18:02
@potiuk
Copy link
Member

potiuk commented Nov 10, 2025

I guess MyPy will assume (with bad heuristics) that UUID is coming somewhere from SQLAlchemy. MyPy sometimes makes wrong guesses.

@potiuk potiuk force-pushed the 2511/mypy-tasksdk-dag branch from 083ef8f to 1703f84 Compare November 10, 2025 22:01
@potiuk
Copy link
Member

potiuk commented Nov 10, 2025

Rebased - there were problems caused by pytest 9.

Comment on lines 1417 to 1423
ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
ti.next_kwargs = (
next_kwargs_value
if isinstance(next_kwargs_value, dict) or next_kwargs_value is None
else None
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, if this is a str it’s an encrypted dict. I don’t think it’s supposed to be dropped. Is there a bug somewhere, or should ti.next_kwargs simply expect str?

Original commit fbbe59a

cc @ashb

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image Fixed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uranusjr Where would be the appropriate place to fix the root cause of this issue, if not in this file?

task-sdk/src/airflow/sdk/definitions/dag.py:1418: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None")  [assignment]
                    ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch 3 times, most recently from 1a06efe to 3f24a02 Compare November 13, 2025 11:36
@Dev-iL Dev-iL requested a review from uranusjr November 13, 2025 11:41
@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch 2 times, most recently from 2e61997 to 7c0b7ce Compare November 13, 2025 15:03
@Dev-iL Dev-iL force-pushed the 2511/mypy-tasksdk-dag branch from 7c0b7ce to 9477864 Compare November 13, 2025 18:39
@Dev-iL
Copy link
Collaborator Author

Dev-iL commented Nov 14, 2025

@uranusjr are we good to merge this now?

@ashb
Copy link
Member

ashb commented Nov 14, 2025

Please change the pr title -- task sdk cant have sqla2 related changes since it doesnt use sqla2

Comment on lines +1420 to +1426
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if isinstance(next_kwargs_value, str):
from airflow.serialization.serialized_objects import BaseSerialization

ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
else:
ti.next_kwargs = next_kwargs_value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why this is needed? it feels like a behaviour change, not just a typing change

Copy link
Collaborator Author

@Dev-iL Dev-iL Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for the change is the following mypy violation (in the presence of SQLA2):

task-sdk/src/airflow/sdk/definitions/dag.py:1417: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None")  [assignment]
                    ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The explanation of the change is here (see "The Issue" and "The Correct Solution" sections).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error was this showing without this change? This code appears to have been working fine as triggers are functional, (and I'm reasonably sure they are currently being encrypted as strings), so I'm not clear why this is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's because the type on the model is wrong then: SQLCoreOperations[dict[Any, Any] | None] should be dict[str,Any], not dict[Any, Any]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy has an issue with the | str | case of

# airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#331
    next_kwargs: dict[str, Any] | str | None = None
    """
    Args to pass to ``next_method``.

    Can either be a "decorated" dict, or a string encrypted with the shared Fernet key.
    """

not the dict. TP remarked that perhaps this should never be a dict at all.


I let the LLM analyze it some more and it came up with this:

## The Issue
The original mypy error occurred because `ti.next_kwargs` expects `dict | None`, but `msg.next_kwargs` from `DeferTask` can be `dict | str | None` (where the `str` represents an encrypted value).

## The Solution
Update the deserialization logic in `dag.py` to properly handle all cases:

- None → assign `None` directly
- Serialized dict (with `__type` and `__var` keys) → call `BaseSerialization.deserialize()`
- Regular dict → use as-is
- Encrypted string → log warning and assign `None` (encryption not supported in `dag.test()` context)

## Key Discovery
`BaseSerialization.serialize()` returns a dict with special keys (not a JSON string), and `deserialize()` expects this dict format. The "encrypted string" case mentioned in comments refers to when Fernet encryption is enabled, which doesn't apply to `dag.test()`.
So now it thinks the code should instead look like this:

# Deserialize next_kwargs if needed, matching what the API server does
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if next_kwargs_value is None:
    ti.next_kwargs = None
elif isinstance(next_kwargs_value, dict):
    if set(next_kwargs_value.keys()) == {"__type", "__var"}:
        # Serialized format - deserialize it
        from airflow.serialization.serialized_objects import BaseSerialization

        ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
    else:
        # Regular dict - use as-is
        ti.next_kwargs = next_kwargs_value
else:
    # String (encrypted) - in dag.test() context, encryption is not used,
    # but we need to handle this for type checking
    # The API server would decrypt this before calling deserialize
    if not isinstance(next_kwargs_value, str):
        raise TypeError(f"Unexpected type for next_kwargs: {type(next_kwargs_value)}")
    # For now, we can't decrypt without the Fernet key, so log a warning
    log.warning(
        "[DAG TEST] Received encrypted next_kwargs string, cannot decrypt in dag.test() context"
    )
    # Type-ignore needed because we can't decrypt in dag.test()
    ti.next_kwargs = None  # type: ignore[assignment]

...with a couple of accompanying unit tests that illustrate the different behavior of the existing and proposed implementations:

class TestDeferredTaskNextKwargs:
    """Test that next_kwargs is properly deserialized when it's an encrypted string."""

    @pytest.fixture
    def mock_task_instance(self, mocker):
        """Create a mock scheduler TaskInstance."""
        from airflow.sdk import TaskInstanceState

        ti = mocker.MagicMock()
        ti.task_id = "test_task"
        ti.dag_id = "test_dag"
        ti.run_id = "test_run"
        ti.map_index = -1
        ti.id = "123e4567-e89b-12d3-a456-426614174000"
        ti.dag_version_id = "223e4567-e89b-12d3-a456-426614174000"
        ti.try_number = 1
        ti.state = TaskInstanceState.DEFERRED
        return ti

    @pytest.fixture
    def mock_task(self, mocker):
        """Create a mock task."""

        task = mocker.MagicMock()
        task.task_id = "test_task"
        return task

    def test_next_kwargs_deserialized_when_encrypted_string(
        self, mock_task_instance, mock_task, monkeypatch, mocker,
    ):
        """
        Test that when msg.next_kwargs is a serialized string, it gets properly deserialized to a dict.

        This simulates the case where DeferTask contains an encrypted/serialized next_kwargs.
        The bug we're testing: without the fix, string next_kwargs would be assigned directly to ti.next_kwargs,
        causing a type error. With the fix, it should be deserialized to a dict.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask
        from airflow.serialization.serialized_objects import BaseSerialization

        # Setup: Create a dict that will be serialized (simulating real behavior)
        original_kwargs = {"custom_param": "value", "count": 42}
        # Serialize it to get the format with __type and __var keys
        serialized_kwargs = BaseSerialization.serialize(original_kwargs)

        # serialized_kwargs should now be a dict like:
        # {'__type': 'dict', '__var': {'custom_param': 'value', 'count': 42}}

        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs=serialized_kwargs,
        )

        # Create a mock TaskRunResult with the DeferTask message
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock the run_task_in_process to return our prepared result
        mock_run_task = mocker.MagicMock(return_value=mock_task_run_result)
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mock_run_task,
            raising=False,
        )

        # Mock _run_inline_trigger to return no event (so msg.next_kwargs is used)
        mock_inline_trigger = mocker.MagicMock(return_value=None)
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mock_inline_trigger,
            raising=False,
        )

        # Mock create_scheduler_operator
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )

        # Mock import_string to return a mock trigger
        mock_trigger = mocker.MagicMock()
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mock_trigger),
            raising=False,
        )

        # Mock create_session
        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Track assignments to next_kwargs using a property descriptor
        assigned_values = []
        original_next_kwargs = mock_task_instance.next_kwargs

        def track_next_kwargs_setter(value):
            assigned_values.append(value)
            # Update the mock's return value for the property
            type(mock_task_instance).next_kwargs = mocker.PropertyMock(return_value=value)

        # Replace next_kwargs with a property that tracks assignments
        type(mock_task_instance).next_kwargs = mocker.PropertyMock(
            return_value=original_next_kwargs,
            side_effect=lambda: assigned_values[-1] if assigned_values else original_next_kwargs
        )
        mocker.patch.object(
            type(mock_task_instance),
            'next_kwargs',
            new_callable=mocker.PropertyMock,
            return_value=original_next_kwargs
        )

        # We need to use __setattr__ to capture the assignment
        original_setattr = type(mock_task_instance).__setattr__

        def capturing_setattr(obj, name, value):
            if name == 'next_kwargs':
                assigned_values.append(value)
                # Store it in the mock's internal state
                object.__setattr__(obj, '_next_kwargs_value', value)
            else:
                original_setattr(obj, name, value)

        type(mock_task_instance).__setattr__ = capturing_setattr

        # Execute _run_task with run_triggerer=True to trigger the deferred path
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Restore original behavior
        type(mock_task_instance).__setattr__ = original_setattr

        # Verify: The key behavior is that next_kwargs is now a dict, not a string
        # This is what the fix ensures - regardless of HOW it's done
        assert len(assigned_values) > 0, "next_kwargs should have been assigned at least once"
        final_value = assigned_values[-1]  # Get the last assigned value

        assert isinstance(final_value, dict), (
            f"next_kwargs should be a dict after deserialization, not {type(final_value).__name__}"
        )
        # Verify the dict contains the expected keys (proves it was properly deserialized)
        assert "custom_param" in final_value
        assert "count" in final_value

    def test_next_kwargs_with_trigger_event(self, mock_task_instance, mock_task, monkeypatch, mocker):
        """
        Test that when a trigger returns an event, event.payload is used for next_kwargs.

        This verifies that the event path takes precedence over msg.next_kwargs.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask

        # Create a mock event with payload
        mock_event = mocker.MagicMock()
        mock_event.payload = {"event_data": "from_trigger", "timestamp": "2024-01-01"}

        # Create mock DeferTask with a string next_kwargs that should be ignored
        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs='{"should": "be_ignored"}',  # Should be ignored since we have an event
        )

        # Create mock TaskRunResult
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock dependencies
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mocker.MagicMock(return_value=mock_task_run_result),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mocker.MagicMock(return_value=mock_event),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mocker.MagicMock()),
            raising=False,
        )

        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Execute
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Verify: next_kwargs should be set to {"event": event.payload}, not msg.next_kwargs
        assert mock_task_instance.next_kwargs == {"event": mock_event.payload}
        assert isinstance(mock_task_instance.next_kwargs, dict)
        # Ensure msg.next_kwargs was not used
        assert mock_task_instance.next_kwargs.get("should") != "be_ignored"


I'll revert my change to this bit since I don't see how to solve this in an acceptable way.

@Dev-iL Dev-iL changed the title SQLA2: fix mypy violations in task-sdk/.../dag.py Fix mypy violations in task-sdk/.../dag.py Nov 14, 2025
@Dev-iL
Copy link
Collaborator Author

Dev-iL commented Nov 16, 2025

@vincbeck I am unable to get these changes approved - giving up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants