From a95c03aa75fd50da296eae52fd2957440b2af4ef Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 16 May 2025 16:01:01 +0530 Subject: [PATCH] Fix XCom deserialization for mapped tasks with custom backend --- .../sdk/execution_time/lazy_sequence.py | 5 ++- .../execution_time/test_lazy_sequence.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 79822787f3881..0fbfcf39498e4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -140,10 +140,9 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: # if not reverse: # rows.reverse() # return rows - - from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.xcom import XCom with SUPERVISOR_COMMS.lock: source = (xcom_arg := self._xcom_arg).operator @@ -161,7 +160,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: if not isinstance(msg, XComResult): raise IndexError(key) - return BaseXCom.deserialize_value(msg) + return XCom.deserialize_value(msg) def _coerce_index(value: Any) -> int | None: diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py index a42572e5df1fa..2430f85f35e5e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py +++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py @@ -21,6 +21,8 @@ import pytest +import airflow +from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ( ErrorResponse, @@ -30,6 +32,9 @@ XComResult, ) from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence +from airflow.sdk.execution_time.xcom import resolve_xcom_backend + +from tests_common.test_utils.config import conf_vars @pytest.fixture @@ -52,6 +57,12 @@ def lazy_sequence(mock_xcom_arg, mock_ti): return LazyXComSequence(mock_xcom_arg, mock_ti) +class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, xcom): + return f"Made with CustomXCom: {xcom.value}" + + def test_len(mock_supervisor_comms, lazy_sequence): mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3) assert len(lazy_sequence) == 3 @@ -109,6 +120,29 @@ def test_getitem_index(mock_supervisor_comms, lazy_sequence): ] +@conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"}) +def test_getitem_calls_correct_deserialise(mock_supervisor_comms, lazy_sequence): + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="some-value") + + xcom = resolve_xcom_backend() + assert xcom.__name__ == "CustomXCom" + airflow.sdk.execution_time.xcom.XCom = xcom + + assert lazy_sequence[4] == "Made with CustomXCom: some-value" + assert mock_supervisor_comms.send_request.mock_calls == [ + call( + log=ANY, + msg=GetXComSequenceItem( + key="return_value", + dag_id="dag", + task_id="task", + run_id="run", + offset=4, + ), + ), + ] + + def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence): mock_supervisor_comms.get_message.return_value = ErrorResponse( error=ErrorType.XCOM_NOT_FOUND,