Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion task-sdk/src/airflow/sdk/bases/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def get_all(
:return: List of all XCom values if found.
"""
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
from airflow.serialization.serde import deserialize

msg = SUPERVISOR_COMMS.send(
msg=GetXComSequenceSlice(
Expand All @@ -306,7 +307,10 @@ def get_all(
if not isinstance(msg, XComSequenceSliceResult):
raise TypeError(f"Expected XComSequenceSliceResult, received: {type(msg)} {msg}")

return msg.root
result = deserialize(msg.root)
if not result:
return None
return result

@staticmethod
def serialize_value(
Expand Down
12 changes: 7 additions & 5 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,17 +349,19 @@ def xcom_pull(

# If map_indexes is not specified, pull xcoms from all map indexes for each task
if isinstance(map_indexes, ArgNotSet):
xcoms = [
value
for t_id in task_ids
for value in XCom.get_all(
xcoms: list[Any] = []
for t_id in task_ids:
values = XCom.get_all(
run_id=run_id,
key=key,
task_id=t_id,
dag_id=dag_id,
)
]

if values is None:
xcoms.append(None)
else:
xcoms.extend(values)
# For single task pulling from unmapped task, return single value
if single_task_requested and len(xcoms) == 1:
return xcoms[0]
Expand Down
53 changes: 53 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,7 @@ def test_get_variable_from_context(
pytest.param("hello", id="string_value"),
pytest.param("'hello'", id="quoted_string_value"),
pytest.param({"key": "value"}, id="json_value"),
pytest.param([], id="empty_list_no_xcoms_found"),
pytest.param((1, 2, 3), id="tuple_int_value"),
pytest.param([1, 2, 3], id="list_int_value"),
pytest.param(42, id="int_value"),
Expand All @@ -1376,6 +1377,9 @@ def test_xcom_pull(
"""
map_indexes_kwarg = {} if map_indexes is NOTSET else {"map_indexes": map_indexes}
task_ids_kwarg = {} if task_ids is NOTSET else {"task_ids": task_ids}
from airflow.serialization.serde import deserialize

spy_agency.spy_on(deserialize)

class CustomOperator(BaseOperator):
def execute(self, context):
Expand All @@ -1401,6 +1405,7 @@ def mock_send_side_effect(*args, **kwargs):
mock_supervisor_comms.send.side_effect = mock_send_side_effect

run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
spy_agency.assert_spy_called_with(deserialize, ser_value)

if not isinstance(task_ids, Iterable) or isinstance(task_ids, str):
task_ids = [task_ids]
Expand Down Expand Up @@ -1506,6 +1511,54 @@ def mock_get_all_side_effect(task_id, **kwargs):
assert mock_get_one.called
assert not mock_get_all.called

@pytest.mark.parametrize(
"api_return_value",
[
pytest.param(("data", "test_value"), id="api returns tuple"),
pytest.param({"data": "test_value"}, id="api returns dict"),
pytest.param(None, id="api returns None, no xcom found"),
],
)
def test_xcom_pull_with_no_map_index(
self,
api_return_value,
create_runtime_ti,
mock_supervisor_comms,
):
"""
Test xcom_pull when map_indexes is not specified, so that XCom.get_all is called.
The test also tests if the response is deserialized and returned.
"""
test_task_id = "pull_task"
task = BaseOperator(task_id=test_task_id)
runtime_ti = create_runtime_ti(task=task)

ser_value = BaseXCom.serialize_value(api_return_value)

def mock_send_side_effect(*args, **kwargs):
msg = kwargs.get("msg") or args[0]
if isinstance(msg, GetXComSequenceSlice):
return XComSequenceSliceResult(root=[ser_value])
return XComResult(key="test_key", value=None)

mock_supervisor_comms.send.side_effect = mock_send_side_effect
result = runtime_ti.xcom_pull(key="test_key", task_ids="task_a")

# if the API returns a tuple or dict, the below assertion assures that the value is deserialized correctly by XCom.get_all
assert result == api_return_value

mock_supervisor_comms.send.assert_called_once_with(
msg=GetXComSequenceSlice(
key="test_key",
dag_id=runtime_ti.dag_id,
run_id=runtime_ti.run_id,
task_id="task_a",
start=None,
stop=None,
step=None,
),
)

def test_get_param_from_context(
self, mocked_parse, make_ti_context, mock_supervisor_comms, create_runtime_ti
):
Expand Down