diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 7fdc53aa2e51e..bd78b5e42147f 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -17,6 +17,7 @@ from __future__ import annotations +import collections from typing import Any, Protocol import structlog @@ -30,6 +31,9 @@ XComSequenceSliceResult, ) +# Lightweight wrapper for XCom values +_XComValueWrapper = collections.namedtuple("_XComValueWrapper", "value") + log = structlog.get_logger(logger_name="task") @@ -296,7 +300,6 @@ def get_all( :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - from airflow.serialization.serde import deserialize msg = SUPERVISOR_COMMS.send( msg=GetXComSequenceSlice( @@ -314,10 +317,10 @@ def get_all( if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - result = deserialize(msg.root) - if not result: + if not msg.root: return None - return result + + return [cls.deserialize_value(_XComValueWrapper(value)) for value in msg.root] @staticmethod def serialize_value( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 2ba524090a361..056fd4c4b9b03 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2094,6 +2094,30 @@ def execute(self, context): for x in mock_supervisor_comms.send.call_args_list ) + def test_get_all_uses_custom_deserialize_value(self, mock_supervisor_comms): + """ + Tests that XCom.get_all() calls the custom deserialize_value method. + """ + + class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, result): + """Custom deserialization that adds a prefix to show it was called.""" + original_value = super().deserialize_value(result) + return f"from custom xcom deserialize:{original_value}" + + serialized_values = ["value1", "value2", "value3"] + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=serialized_values) + + result = CustomXCom.get_all(key="test_key", dag_id="test_dag", task_id="test_task", run_id="test_run") + + expected = [ + "from custom xcom deserialize:value1", + "from custom xcom deserialize:value2", + "from custom xcom deserialize:value3", + ] + assert result == expected + @pytest.mark.parametrize( ("include_prior_dates", "expected_value"), [