Fix mypy violations in task-sdk/.../dag.py
#57952
Conversation
@Dev-iL This PR title doesn't make any sense to me though --- the TaskSDK has nothing to do with SQLA2. It (definitionally) can't load things from the DB...
ashb left a comment
Too many changes for just "mypy violations" to my mind. Introducing serialization when there was none before is well beyond that.
To my understanding, mypy behaves differently in the presence of SQLA2 (https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html), and this happens to affect some type hints in task-sdk. Regardless, I simplified this PR greatly - please see if it's acceptable now.
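For illustration of the mechanism (a hedged sketch with a hypothetical model, not code from Airflow or this PR): under SQLAlchemy 2.0, typed `Mapped[]` declarations make mypy route ORM attribute assignment through `SQLCoreOperations`, so assigning a wider union than the mapped type is reported as an `[assignment]` error.

```python
# Minimal sketch (hypothetical model, not the Airflow TaskInstance) of how
# SQLAlchemy 2.0 typing surfaces the error quoted later in this thread.
from typing import Any

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class TISketch(Base):
    __tablename__ = "ti_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    # mypy derives the assignable type from the Mapped annotation.
    next_kwargs: Mapped[dict[Any, Any] | None] = mapped_column(JSON, nullable=True)


def set_kwargs(ti: TISketch, value: dict[str, Any] | str | None) -> None:
    # error: Incompatible types in assignment (expression has type
    # "dict[str, Any] | str | None", variable has type
    # "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None")
    ti.next_kwargs = value
```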
I guess MyPy will assume (with bad heuristics) that UUID is coming somewhere from SQLAlchemy. MyPy sometimes makes wrong guesses.
Rebased - there were problems caused by pytest 9.
```diff
- ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
+ next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
+ ti.next_kwargs = (
+     next_kwargs_value
+     if isinstance(next_kwargs_value, dict) or next_kwargs_value is None
+     else None
+ )
```
@uranusjr Where would be the appropriate place to fix the root cause of this issue, if not in this file?
task-sdk/src/airflow/sdk/definitions/dag.py:1418: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None") [assignment]
ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@uranusjr are we good to merge this now?
Please change the PR title -- the Task SDK can't have SQLA2-related changes since it doesn't use SQLA2.
```python
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if isinstance(next_kwargs_value, str):
    from airflow.serialization.serialized_objects import BaseSerialization

    ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
else:
    ti.next_kwargs = next_kwargs_value
```
Can you explain why this is needed? it feels like a behaviour change, not just a typing change
The reason for the change is the following mypy violation (in the presence of SQLA2):
task-sdk/src/airflow/sdk/definitions/dag.py:1417: error: Incompatible types in assignment (expression has type
"dict[str, Any] | str | None", variable has type "SQLCoreOperations[dict[Any, Any] | None] | dict[Any, Any] | None") [assignment]
ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The explanation of the change is here (see "The Issue" and "The Correct Solution" sections).
What error was this showing without this change? This code appears to have been working fine as triggers are functional, (and I'm reasonably sure they are currently being encrypted as strings), so I'm not clear why this is needed.
That's because the type on the model is wrong then: SQLCoreOperations[dict[Any, Any] | None] should be dict[str, Any], not dict[Any, Any].
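If that's the direction, a hedged sketch of the narrowing (hypothetical class only; the real TaskInstance column in airflow-core uses Airflow's own JSON column type and lives elsewhere) would look like:

```python
# Hedged sketch only - not the actual Airflow model definition.
from typing import Any

from sqlalchemy import JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class TaskInstanceSketch(Base):
    __tablename__ = "task_instance_sketch"

    id: Mapped[int] = mapped_column(primary_key=True)
    # dict[str, Any] instead of dict[Any, Any]; mypy would then report the target as
    # "SQLCoreOperations[dict[str, Any] | None] | dict[str, Any] | None", matching the
    # key type of the values actually assigned in dag.py.
    next_kwargs: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
```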
mypy has an issue with the `| str |` case of the following, not the dict:

```python
# airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py#331
next_kwargs: dict[str, Any] | str | None = None
"""
Args to pass to ``next_method``.
Can either be a "decorated" dict, or a string encrypted with the shared Fernet key.
"""
```

TP remarked that perhaps this should never be a dict at all.
I let the LLM analyze it some more and it came up with this:
## The Issue
The original mypy error occurred because `ti.next_kwargs` expects `dict | None`, but `msg.next_kwargs` from `DeferTask` can be `dict | str | None` (where the `str` represents an encrypted value).
## The Solution
Update the deserialization logic in `dag.py` to properly handle all cases:
- None → assign `None` directly
- Serialized dict (with `__type` and `__var` keys) → call `BaseSerialization.deserialize()`
- Regular dict → use as-is
- Encrypted string → log warning and assign `None` (encryption not supported in `dag.test()` context)
## Key Discovery
`BaseSerialization.serialize()` returns a dict with special keys (not a JSON string), and `deserialize()` expects this dict format. The "encrypted string" case mentioned in comments refers to when Fernet encryption is enabled, which doesn't apply to `dag.test()`.

So now it thinks the code should instead look like this:
```python
# Deserialize next_kwargs if needed, matching what the API server does
next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs
if next_kwargs_value is None:
    ti.next_kwargs = None
elif isinstance(next_kwargs_value, dict):
    if set(next_kwargs_value.keys()) == {"__type", "__var"}:
        # Serialized format - deserialize it
        from airflow.serialization.serialized_objects import BaseSerialization

        ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value)
    else:
        # Regular dict - use as-is
        ti.next_kwargs = next_kwargs_value
else:
    # String (encrypted) - in dag.test() context, encryption is not used,
    # but we need to handle this for type checking.
    # The API server would decrypt this before calling deserialize.
    if not isinstance(next_kwargs_value, str):
        raise TypeError(f"Unexpected type for next_kwargs: {type(next_kwargs_value)}")
    # For now, we can't decrypt without the Fernet key, so log a warning
    log.warning(
        "[DAG TEST] Received encrypted next_kwargs string, cannot decrypt in dag.test() context"
    )
    # Type-ignore needed because we can't decrypt in dag.test()
    ti.next_kwargs = None  # type: ignore[assignment]
```

...with a couple of accompanying unit tests that illustrate the different behavior of the existing and proposed implementations:
```python
class TestDeferredTaskNextKwargs:
    """Test that next_kwargs is properly deserialized when it's an encrypted string."""

    @pytest.fixture
    def mock_task_instance(self, mocker):
        """Create a mock scheduler TaskInstance."""
        from airflow.sdk import TaskInstanceState

        ti = mocker.MagicMock()
        ti.task_id = "test_task"
        ti.dag_id = "test_dag"
        ti.run_id = "test_run"
        ti.map_index = -1
        ti.id = "123e4567-e89b-12d3-a456-426614174000"
        ti.dag_version_id = "223e4567-e89b-12d3-a456-426614174000"
        ti.try_number = 1
        ti.state = TaskInstanceState.DEFERRED
        return ti

    @pytest.fixture
    def mock_task(self, mocker):
        """Create a mock task."""
        task = mocker.MagicMock()
        task.task_id = "test_task"
        return task

    def test_next_kwargs_deserialized_when_encrypted_string(
        self, mock_task_instance, mock_task, monkeypatch, mocker,
    ):
        """
        Test that when msg.next_kwargs is a serialized string, it gets properly deserialized to a dict.

        This simulates the case where DeferTask contains an encrypted/serialized next_kwargs.
        The bug we're testing: without the fix, string next_kwargs would be assigned directly to
        ti.next_kwargs, causing a type error. With the fix, it should be deserialized to a dict.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask
        from airflow.serialization.serialized_objects import BaseSerialization

        # Setup: Create a dict that will be serialized (simulating real behavior)
        original_kwargs = {"custom_param": "value", "count": 42}
        # Serialize it to get the format with __type and __var keys
        serialized_kwargs = BaseSerialization.serialize(original_kwargs)
        # serialized_kwargs should now be a dict like:
        # {'__type': 'dict', '__var': {'custom_param': 'value', 'count': 42}}

        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs=serialized_kwargs,
        )

        # Create a mock TaskRunResult with the DeferTask message
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock the run_task_in_process to return our prepared result
        mock_run_task = mocker.MagicMock(return_value=mock_task_run_result)
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mock_run_task,
            raising=False,
        )

        # Mock _run_inline_trigger to return no event (so msg.next_kwargs is used)
        mock_inline_trigger = mocker.MagicMock(return_value=None)
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mock_inline_trigger,
            raising=False,
        )

        # Mock create_scheduler_operator
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )

        # Mock import_string to return a mock trigger
        mock_trigger = mocker.MagicMock()
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mock_trigger),
            raising=False,
        )

        # Mock create_session
        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Track assignments to next_kwargs using a property descriptor
        assigned_values = []
        original_next_kwargs = mock_task_instance.next_kwargs

        def track_next_kwargs_setter(value):
            assigned_values.append(value)
            # Update the mock's return value for the property
            type(mock_task_instance).next_kwargs = mocker.PropertyMock(return_value=value)

        # Replace next_kwargs with a property that tracks assignments
        type(mock_task_instance).next_kwargs = mocker.PropertyMock(
            return_value=original_next_kwargs,
            side_effect=lambda: assigned_values[-1] if assigned_values else original_next_kwargs,
        )
        mocker.patch.object(
            type(mock_task_instance),
            'next_kwargs',
            new_callable=mocker.PropertyMock,
            return_value=original_next_kwargs,
        )

        # We need to use __setattr__ to capture the assignment
        original_setattr = type(mock_task_instance).__setattr__

        def capturing_setattr(obj, name, value):
            if name == 'next_kwargs':
                assigned_values.append(value)
                # Store it in the mock's internal state
                object.__setattr__(obj, '_next_kwargs_value', value)
            else:
                original_setattr(obj, name, value)

        type(mock_task_instance).__setattr__ = capturing_setattr

        # Execute _run_task with run_triggerer=True to trigger the deferred path
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Restore original behavior
        type(mock_task_instance).__setattr__ = original_setattr

        # Verify: The key behavior is that next_kwargs is now a dict, not a string.
        # This is what the fix ensures - regardless of HOW it's done.
        assert len(assigned_values) > 0, "next_kwargs should have been assigned at least once"
        final_value = assigned_values[-1]  # Get the last assigned value
        assert isinstance(final_value, dict), (
            f"next_kwargs should be a dict after deserialization, not {type(final_value).__name__}"
        )
        # Verify the dict contains the expected keys (proves it was properly deserialized)
        assert "custom_param" in final_value
        assert "count" in final_value

    def test_next_kwargs_with_trigger_event(self, mock_task_instance, mock_task, monkeypatch, mocker):
        """
        Test that when a trigger returns an event, event.payload is used for next_kwargs.

        This verifies that the event path takes precedence over msg.next_kwargs.
        """
        from airflow.sdk.definitions.dag import _run_task
        from airflow.sdk.execution_time.comms import DeferTask

        # Create a mock event with payload
        mock_event = mocker.MagicMock()
        mock_event.payload = {"event_data": "from_trigger", "timestamp": "2024-01-01"}

        # Create mock DeferTask with a string next_kwargs that should be ignored
        mock_defer_task = DeferTask(
            classpath="airflow.triggers.base.BaseTrigger",
            trigger_kwargs={},
            next_method="execute_complete",
            next_kwargs='{"should": "be_ignored"}',  # Should be ignored since we have an event
        )

        # Create mock TaskRunResult
        mock_task_run_result = mocker.MagicMock()
        mock_task_run_result.msg = mock_defer_task
        mock_task_run_result.ti.state = "deferred"
        mock_task_run_result.ti.task = mock_task

        # Mock dependencies
        monkeypatch.setattr(
            "airflow.sdk.execution_time.supervisor.run_task_in_process",
            mocker.MagicMock(return_value=mock_task_run_result),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.definitions.dag._run_inline_trigger",
            mocker.MagicMock(return_value=mock_event),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.serialization.serialized_objects.create_scheduler_operator",
            mocker.MagicMock(return_value=mock_task),
            raising=False,
        )
        monkeypatch.setattr(
            "airflow.sdk.module_loading.import_string",
            mocker.MagicMock(return_value=lambda **kwargs: mocker.MagicMock()),
            raising=False,
        )
        mock_session = mocker.MagicMock()
        mock_create_session = mocker.MagicMock()
        mock_create_session.__enter__ = mocker.MagicMock(return_value=mock_session)
        mock_create_session.__exit__ = mocker.MagicMock(return_value=False)
        monkeypatch.setattr(
            "airflow.utils.session.create_session",
            mocker.MagicMock(return_value=mock_create_session),
            raising=False,
        )

        # Execute
        _run_task(ti=mock_task_instance, task=mock_task, run_triggerer=True)

        # Verify: next_kwargs should be set to {"event": event.payload}, not msg.next_kwargs
        assert mock_task_instance.next_kwargs == {"event": mock_event.payload}
        assert isinstance(mock_task_instance.next_kwargs, dict)
        # Ensure msg.next_kwargs was not used
        assert mock_task_instance.next_kwargs.get("should") != "be_ignored"
```

I'll revert my change to this bit since I don't see how to solve this in an acceptable way.
task-sdk/.../dag.py
@vincbeck I am unable to get these changes approved - giving up.

related: #56212, #58049
This PR fixes the below mypy violations that appear in task-sdk after upgrading to SQLA2: