diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 770dbf53df769..82df8d151ab13 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -290,6 +290,7 @@ def get_all( :return: List of all XCom values if found. """ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.serialization.serde import deserialize msg = SUPERVISOR_COMMS.send( msg=GetXComSequenceSlice( @@ -306,7 +307,10 @@ def get_all( if not isinstance(msg, XComSequenceSliceResult): raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}") - return msg.root + result = deserialize(msg.root) + if not result: + return None + return result @staticmethod def serialize_value( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 5fa868b9b88bd..5abfcea226b20 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -349,17 +349,19 @@ def xcom_pull( # If map_indexes is not specified, pull xcoms from all map indexes for each task if isinstance(map_indexes, ArgNotSet): - xcoms = [ - value - for t_id in task_ids - for value in XCom.get_all( + xcoms: list[Any] = [] + for t_id in task_ids: + values = XCom.get_all( run_id=run_id, key=key, task_id=t_id, dag_id=dag_id, ) - ] + if values is None: + xcoms.append(None) + else: + xcoms.extend(values) # For single task pulling from unmapped task, return single value if single_task_requested and len(xcoms) == 1: return xcoms[0] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 180068f801b8f..416fee39af9a3 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1354,6 +1354,7 @@ def test_get_variable_from_context( pytest.param("hello", id="string_value"), pytest.param("'hello'", id="quoted_string_value"), pytest.param({"key": "value"}, id="json_value"), + pytest.param([], id="empty_list_no_xcoms_found"), pytest.param((1, 2, 3), id="tuple_int_value"), pytest.param([1, 2, 3], id="list_int_value"), pytest.param(42, id="int_value"), @@ -1376,6 +1377,9 @@ def test_xcom_pull( """ map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes} task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids} + from airflow.serialization.serde import deserialize + + spy_agency.spy_on(deserialize) class CustomOperator(BaseOperator): def execute(self, context): @@ -1401,6 +1405,7 @@ def mock_send_side_effect(*args, **kwargs): mock_supervisor_comms.send.side_effect = mock_send_side_effect run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + spy_agency.assert_spy_called_with(deserialize, ser_value) if not isinstance(task_ids, Iterable) or isinstance(task_ids, str): task_ids = [task_ids] @@ -1506,6 +1511,54 @@ def mock_get_all_side_effect(task_id, **kwargs): assert mock_get_one.called assert not mock_get_all.called + @pytest.mark.parametrize( + "api_return_value", + [ + pytest.param(("data", "test_value"), id="api returns tuple"), + pytest.param({"data": "test_value"}, id="api returns dict"), + pytest.param(None, id="api returns None, no xcom found"), + ], + ) + def test_xcom_pull_with_no_map_index( + self, + api_return_value, + create_runtime_ti, + mock_supervisor_comms, + ): + """ + Test xcom_pull when map_indexes is not specified, so that XCom.get_all is called. + The test also tests if the response is deserialized and returned. + """ + test_task_id = "pull_task" + task = BaseOperator(task_id=test_task_id) + runtime_ti = create_runtime_ti(task=task) + + ser_value = BaseXCom.serialize_value(api_return_value) + + def mock_send_side_effect(*args, **kwargs): + msg = kwargs.get("msg") or args[0] + if isinstance(msg, GetXComSequenceSlice): + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="test_key", value=None) + + mock_supervisor_comms.send.side_effect = mock_send_side_effect + result = runtime_ti.xcom_pull(key="test_key", task_ids="task_a") + + # if the API returns a tuple or dict, the below assertion assures that the value is deserialized correctly by XCom.get_all + assert result == api_return_value + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetXComSequenceSlice( + key="test_key", + dag_id=runtime_ti.dag_id, + run_id=runtime_ti.run_id, + task_id="task_a", + start=None, + stop=None, + step=None, + ), + ) + def test_get_param_from_context( self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti ):