Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]:
# if not reverse:
# rows.reverse()
# return rows

from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
from airflow.sdk.execution_time.xcom import XCom

with SUPERVISOR_COMMS.lock:
source = (xcom_arg := self._xcom_arg).operator
Expand All @@ -161,7 +160,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]:

if not isinstance(msg, XComResult):
raise IndexError(key)
return BaseXCom.deserialize_value(msg)
return XCom.deserialize_value(msg)


def _coerce_index(value: Any) -> int | None:
Expand Down
34 changes: 34 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import pytest

import airflow
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
ErrorResponse,
Expand All @@ -30,6 +32,9 @@
XComResult,
)
from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
from airflow.sdk.execution_time.xcom import resolve_xcom_backend

from tests_common.test_utils.config import conf_vars


@pytest.fixture
Expand All @@ -52,6 +57,12 @@ def lazy_sequence(mock_xcom_arg, mock_ti):
return LazyXComSequence(mock_xcom_arg, mock_ti)


class CustomXCom(BaseXCom):
@classmethod
def deserialize_value(cls, xcom):
return f"Made with CustomXCom: {xcom.value}"


def test_len(mock_supervisor_comms, lazy_sequence):
mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3)
assert len(lazy_sequence) == 3
Expand Down Expand Up @@ -109,6 +120,29 @@ def test_getitem_index(mock_supervisor_comms, lazy_sequence):
]


@conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"})
def test_getitem_calls_correct_deserialise(mock_supervisor_comms, lazy_sequence):
mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="some-value")

xcom = resolve_xcom_backend()
assert xcom.__name__ == "CustomXCom"
airflow.sdk.execution_time.xcom.XCom = xcom

assert lazy_sequence[4] == "Made with CustomXCom: some-value"
assert mock_supervisor_comms.send_request.mock_calls == [
call(
log=ANY,
msg=GetXComSequenceItem(
key="return_value",
dag_id="dag",
task_id="task",
run_id="run",
offset=4,
),
),
]


def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
mock_supervisor_comms.get_message.return_value = ErrorResponse(
error=ErrorType.XCOM_NOT_FOUND,
Expand Down