Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from collections.abc import Iterable, Iterator, Mapping
from datetime import datetime, timezone
from io import FileIO
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar

Expand Down Expand Up @@ -277,12 +278,17 @@ def xcom_pull(
If *None* (default), the run_id of the calling task is used.

When pulling one single task (``task_id`` is *None* or a str) without
specifying ``map_indexes``, the return value is inferred from whether
the specified task is mapped. If not, value from the one single task
instance is returned. If the task to pull is mapped, an iterator (not a
list) yielding XComs from mapped task instances is returned. In either
case, ``default`` (*None* if not specified) is returned if no matching
XComs are found.
specifying ``map_indexes``, the return value is a single XCom entry
(``map_indexes`` defaults to the ``map_index`` of the calling task instance).

When the task being pulled from is mapped, the specified ``map_indexes`` is
used. By default this is the ``map_index`` of the calling task instance, so
pulling from a mapped task yields no matching XComs if the calling task
instance is not itself mapped. Setting ``map_indexes`` to *None* pulls the
XCom as it would from a non-mapped task.

In either case, ``default`` (*None* if not specified) is returned if no
matching XComs are found.

When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is
a non-str iterable), a list of matching XComs is returned. Elements in
Expand All @@ -298,19 +304,32 @@ def xcom_pull(
task_ids = [self.task_id]
elif isinstance(task_ids, str):
task_ids = [task_ids]

map_indexes_iterable: Iterable[int | None] = []
# If map_indexes is not provided, default to use the map_index of the calling task
if isinstance(map_indexes, ArgNotSet):
map_indexes = self.map_index
map_indexes_iterable = [self.map_index]
elif isinstance(map_indexes, int) or map_indexes is None:
map_indexes_iterable = [map_indexes]
elif isinstance(map_indexes, Iterable):
# TODO: Handle multiple map_indexes or remove support
raise NotImplementedError("Multiple map_indexes are not supported yet")
map_indexes_iterable = map_indexes
else:
raise TypeError(
f"Invalid type for map_indexes: expected int, iterable of ints, or None, got {type(map_indexes)}"
)

xcoms = []
for t in task_ids:
# TODO: The AIP-72 Execution API only allows working with a single map_index at a time.
# This is inefficient: it leads to one API request per (task_id, map_index) pair.
# It also means we can't directly reproduce the original behavior of XCom pull
# with multiple tasks for now.
for t_id, m_idx in product(task_ids, map_indexes_iterable):
value = XCom.get_one(
run_id=run_id,
key=key,
task_id=t,
task_id=t_id,
dag_id=dag_id,
map_index=map_indexes,
map_index=m_idx,
)
xcoms.append(value if value else default)

Expand Down
72 changes: 50 additions & 22 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,31 +1183,53 @@ def test_get_variable_from_context(

assert var_from_context == Variable(key="test_key", value=expected_value)

@pytest.mark.parametrize(
"map_indexes",
[
pytest.param(-1, id="not_mapped_index"),
pytest.param(1, id="single_map_index"),
pytest.param([0, 1], id="multiple_map_indexes"),
pytest.param((0, 1), id="any_iterable_multi_indexes"),
pytest.param(None, id="index_none"),
pytest.param(NOTSET, id="index_not_set"),
],
)
@pytest.mark.parametrize(
"task_ids",
[
"push_task",
["push_task1", "push_task2"],
{"push_task1", "push_task2"},
None,
NOTSET,
pytest.param("push_task", id="single_task"),
pytest.param(["push_task1", "push_task2"], id="tid_multiple_tasks"),
pytest.param({"push_task1", "push_task2"}, id="tid_any_iterable"),
pytest.param(None, id="tid_none"),
pytest.param(NOTSET, id="tid_not_set"),
],
)
def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency, task_ids):
"""Test that a task pulls the expected XCom value if it exists."""
def test_xcom_pull(
self,
create_runtime_ti,
mock_supervisor_comms,
spy_agency,
task_ids,
map_indexes,
):
"""
Test that a task makes the expected calls to the Supervisor to pull XCom values
for various task_ids and map_indexes configurations.
"""
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes}
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids}

class CustomOperator(BaseOperator):
def execute(self, context):
if isinstance(task_ids, ArgNotSet):
value = context["ti"].xcom_pull(key="key")
else:
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
value = context["ti"].xcom_pull(key="key", **task_ids_kwarg, **map_indexes_kwarg)
print(f"Pulled XCom Value: {value}")

test_task_id = "pull_task"
task = CustomOperator(task_id=test_task_id)

runtime_ti = create_runtime_ti(task=task)
# In case of the specific map_index or None we should check it is passed to TI
extra_for_ti = {"map_index": map_indexes} if map_indexes in (1, None) else {}
runtime_ti = create_runtime_ti(task=task, **extra_for_ti)

mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"')

Expand All @@ -1216,20 +1238,26 @@ def execute(self, context):
if not isinstance(task_ids, Iterable) or isinstance(task_ids, str):
task_ids = [task_ids]

if not isinstance(map_indexes, Iterable):
map_indexes = [map_indexes]

for task_id in task_ids:
# Without task_ids (or None) expected behavior is to pull with calling task_id
if task_id is None or isinstance(task_id, ArgNotSet):
task_id = test_task_id
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=-1,
),
)
for map_index in map_indexes:
if map_index == NOTSET:
map_index = -1
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=map_index,
),
)

def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
Expand Down