@classmethod
def get_all(
    cls,
    *,
    key: str,
    dag_id: str,
    task_id: str,
    run_id: str,
) -> Any:
    """
    Fetch every XCom value stored under ``key`` for one task in one DAG run.

    A single ``GetXComSequenceSlice`` request is sent to the supervisor with
    ``start``, ``stop`` and ``step`` all left as ``None`` (no slice bounds),
    and the ``root`` payload of the reply is handed back unchanged. This is
    intended for collecting the values pushed by all map indexes of a mapped
    task in one round trip.

    NOTE(review): the payload is returned exactly as received; no
    ``deserialize_value`` call is made here -- confirm callers expect that.
    NOTE(review): when nothing matches, the reply is expected to carry an
    empty list rather than ``None`` -- confirm against the supervisor side.

    :param key: A key for the XCom. Only XComs with this key will be returned.
    :param dag_id: DAG ID to pull XComs from.
    :param task_id: Task ID to pull XComs from.
    :param run_id: DAG run ID for the task.
    :return: The ``root`` attribute of the ``XComSequenceSliceResult`` reply.
    :raises TypeError: If the supervisor answers with anything other than an
        ``XComSequenceSliceResult``.
    """
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    request = GetXComSequenceSlice(
        key=key,
        dag_id=dag_id,
        task_id=task_id,
        run_id=run_id,
        start=None,
        stop=None,
        step=None,
    )

    # Triggers may reach this code through `sync_to_async`, which runs on threads.
    # Holding the lock across both the send and the receive keeps the pair atomic,
    # so two concurrent callers cannot interleave requests and read each other's
    # response.
    with SUPERVISOR_COMMS.lock:
        SUPERVISOR_COMMS.send_request(log=log, msg=request)
        response = SUPERVISOR_COMMS.get_message()

    if isinstance(response, XComSequenceSliceResult):
        return response.root

    raise TypeError(f"Expected XComSequenceSliceResult, received: {type(response)} {response}")
is None: map_indexes_iterable = [map_indexes] elif isinstance(map_indexes, Iterable): map_indexes_iterable = map_indexes @@ -360,10 +376,6 @@ def xcom_pull( ) xcoms = [] - # TODO: AIP 72 Execution API only allows working with a single map_index at a time - # this is inefficient and leads to task_id * map_index requests to the API. - # And we can't achieve the original behavior of XCom pull with multiple tasks - # directly now. for t_id, m_idx in product(task_ids, map_indexes_iterable): value = XCom.get_one( run_id=run_id, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index a994e995bace1..d786fdaa5b2bc 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -77,6 +77,7 @@ GetTICount, GetVariable, GetXCom, + GetXComSequenceSlice, OKResponse, PrevSuccessfulDagRunResult, SetRenderedFields, @@ -91,6 +92,7 @@ TriggerDagRun, VariableResult, XComResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.context import ( ConnectionAccessor, @@ -1113,7 +1115,7 @@ def test_get_context_with_ti_context_from_server(self, create_runtime_ti, mock_s task = BaseOperator(task_id="hello") # Assume the context is sent from the API server - # `task_sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received + # `task-sdk/tests/api/test_client.py::test_task_instance_start` checks the context is received # from the API server runtime_ti = create_runtime_ti(task=task, dag_id="basic_task") @@ -1387,7 +1389,17 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task, **extra_for_ti) ser_value = BaseXCom.serialize_value(xcom_values) - mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=ser_value) + + def mock_get_message_side_effect(*args, **kwargs): + calls = mock_supervisor_comms.send_request.call_args_list + if calls: + last_call = calls[-1] + 
msg = last_call[1]["msg"] + if isinstance(msg, GetXComSequenceSlice): + return XComSequenceSliceResult(root=[ser_value]) + return XComResult(key="key", value=ser_value) + + mock_supervisor_comms.get_message.side_effect = mock_get_message_side_effect run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) @@ -1403,17 +1415,30 @@ def execute(self, context): task_id = test_task_id for map_index in map_indexes: if map_index == NOTSET: - map_index = -1 - mock_supervisor_comms.send_request.assert_any_call( - log=mock.ANY, - msg=GetXCom( - key="key", - dag_id="test_dag", - run_id="test_run", - task_id=task_id, - map_index=map_index, - ), - ) + mock_supervisor_comms.send_request.assert_any_call( + log=mock.ANY, + msg=GetXComSequenceSlice( + key="key", + dag_id="test_dag", + run_id="test_run", + task_id=task_id, + start=None, + stop=None, + step=None, + ), + ) + else: + expected_map_index = map_index if map_index is not None else None + mock_supervisor_comms.send_request.assert_any_call( + log=mock.ANY, + msg=GetXCom( + key="key", + dag_id="test_dag", + run_id="test_run", + task_id=task_id, + map_index=expected_map_index, + ), + ) @pytest.mark.parametrize( "task_ids, map_indexes, expected_value", @@ -1421,7 +1446,6 @@ def execute(self, context): pytest.param("task_a", 0, {"a": 1, "b": 2}, id="task_id is str, map_index is int"), pytest.param("task_a", [0], [{"a": 1, "b": 2}], id="task_id is str, map_index is list"), pytest.param("task_a", None, {"a": 1, "b": 2}, id="task_id is str, map_index is None"), - pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"), pytest.param(["task_a"], 0, [{"a": 1, "b": 2}], id="task_id is list, map_index is int"), pytest.param(["task_a"], [0], [{"a": 1, "b": 2}], id="task_id is list, map_index is list"), pytest.param(["task_a"], None, [{"a": 1, "b": 2}], id="task_id is list, map_index is None"), @@ -1431,6 +1455,13 @@ def execute(self, context): pytest.param(None, 0, {"a": 1, 
"b": 2}, id="task_id is None, map_index is int"), pytest.param(None, [0], [{"a": 1, "b": 2}], id="task_id is None, map_index is list"), pytest.param(None, None, {"a": 1, "b": 2}, id="task_id is None, map_index is None"), + pytest.param( + ["task_a", "task_b"], + NOTSET, + [{"a": 1, "b": 2}, {"c": 3, "d": 4}], + id="multiple task_ids, map_index is ArgNotSet", + ), + pytest.param("task_a", NOTSET, {"a": 1, "b": 2}, id="task_id is str, map_index is ArgNotSet"), pytest.param(None, NOTSET, {"a": 1, "b": 2}, id="task_id is None, map_index is ArgNotSet"), ], ) @@ -1444,7 +1475,7 @@ def test_xcom_pull_return_values( ): """ Tests return value of xcom_pull under various combinations of task_ids and map_indexes. - The above test covers the expected calls to supervisor comms. + Also verifies the correct XCom method (get_one vs get_all) is called. """ class CustomOperator(BaseOperator): @@ -1455,13 +1486,28 @@ def execute(self, context): task = CustomOperator(task_id=test_task_id) runtime_ti = create_runtime_ti(task=task) - value = {"a": 1, "b": 2} - # API server returns serialised value for xcom result, staging it in that way - xcom_value = BaseXCom.serialize_value(value) - mock_supervisor_comms.get_message.return_value = XComResult(key="key", value=xcom_value) - - returned_xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, map_indexes=map_indexes) - assert returned_xcom == expected_value + with patch.object(XCom, "get_one") as mock_get_one, patch.object(XCom, "get_all") as mock_get_all: + if map_indexes == NOTSET: + # Use side_effect to return different values for different tasks + def mock_get_all_side_effect(task_id, **kwargs): + if task_id == "task_b": + return [{"c": 3, "d": 4}] + return [{"a": 1, "b": 2}] + + mock_get_all.side_effect = mock_get_all_side_effect + mock_get_one.return_value = None + else: + mock_get_one.return_value = {"a": 1, "b": 2} + mock_get_all.return_value = None + + xcom = runtime_ti.xcom_pull(key="key", task_ids=task_ids, 
map_indexes=map_indexes) + assert xcom == expected_value + if map_indexes == NOTSET: + assert mock_get_all.called + assert not mock_get_one.called + else: + assert mock_get_one.called + assert not mock_get_all.called def test_get_param_from_context( self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti @@ -1910,13 +1956,11 @@ def execute(self, context): runtime_ti = create_runtime_ti(task=task) run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) - mock_xcom_backend.get_one.assert_called_once_with( + mock_xcom_backend.get_all.assert_called_once_with( key="key", dag_id="test_dag", task_id="pull_task", run_id="test_run", - map_index=-1, - include_prior_dates=False, ) assert not any(