diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index cfda9295cec26..c7a7c75ba7140 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -104,6 +104,18 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: ) +@attrs.define +class SchedulerFilterXComArg(SchedulerXComArg): + arg: SchedulerXComArg + callables: Sequence[str] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + @singledispatch def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None: # The base implementation -- specific XComArg subclasses have specialised implementations @@ -178,6 +190,11 @@ def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): return sum(ready_lengths) +@get_task_map_length.register +def _(xcom_arg: SchedulerFilterXComArg, run_id: str, *, session: Session): + return get_task_map_length(xcom_arg.arg, run_id, session=session) + + def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): """DAG serialization interface.""" klass = _XCOM_ARG_TYPES[data.get("type", "")] @@ -187,6 +204,7 @@ def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): _XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { "": SchedulerPlainXComArg, "concat": SchedulerConcatXComArg, + "filter": SchedulerFilterXComArg, "map": SchedulerMapXComArg, "zip": SchedulerZipXComArg, } diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 2a93585304cb0..147b1736cf6dc 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -17,12 +17,13 @@ from __future__ import annotations -import contextlib import inspect import itertools +from abc import ABCMeta, abstractmethod from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from contextlib import suppress from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, overload, _T_co from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator @@ -42,6 +43,7 @@ # the user, but deserialize them into strings in a serialized XComArg for # safety (those callables are arbitrary user code). MapCallables = Sequence[Callable[[Any], Any]] +FilterCallables = Sequence[Callable[[Any], bool]] class XComArg(ResolveMixin, DependencyMixin): @@ -175,6 +177,9 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) + def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: + return FilterXComArg(self, [f] if f else []) + def resolve(self, context: Mapping[str, Any]) -> Any: raise NotImplementedError() @@ -332,6 +337,11 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) + def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot filter non-return XCom") + return super().filter(f) + def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] task_id = self.operator.task_id @@ -351,6 +361,8 @@ def resolve(self, context: Mapping[str, Any]) -> Any: map_indexes=map_indexes, ) if not isinstance(result, ArgNotSet): + if isinstance(result, ResolveMixin): + result = result.resolve(context) return result if self.key == XCOM_RETURN_KEY: return None @@ -370,29 +382,87 @@ def _get_callable_name(f: Callable | str) -> str: return f.__name__ # Parse the source to find whatever is behind "def". For safety, we don't # want to evaluate the code in any meaningful way! - with contextlib.suppress(Exception): + with suppress(Exception): kw, name, _ = f.lstrip().split(None, 2) if kw == "def": return name return "" -class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: - self.value = value +class _MappableResult(Sequence): + def __init__(self, value: Sequence | dict, callables: FilterCallables | MapCallables) -> None: + self.value = self._convert(value) self.callables = callables def __getitem__(self, index: Any) -> Any: - value = self.value[index] + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError - for f in self.callables: - value = f(value) + @staticmethod + def _convert(value: Sequence | dict) -> list: + if isinstance(value, (dict, set)): + return list(value) + if isinstance(value, list): + return value + raise ValueError( + f"XCom filter expects sequence or dict, not {type(value).__name__}" + ) + + def _apply_callables(self, value) -> Any: + for func in self.callables: + value = func(value) + return value + + +class _MapResult(_MappableResult): + def __getitem__(self, index: Any) -> Any: + value = self._apply_callables(self.value[index]) return value def __len__(self) -> int: return len(self.value) +class _LazyMapResult(_MappableResult): + def __init__(self, value: Iterable, callables: MapCallables) -> None: + super().__init__([], callables) + self._iterator = iter(value) + + def __next__(self) -> Any: + value = self._apply_callables(next(self._iterator)) + self.value.append(value) + return value + + def __getitem__(self, index: Any) -> Any: + if index < 0: + raise IndexError + + while len(self.value) <= index: + try: + next(self) + except StopIteration: + raise IndexError + return self.value[index] + + def __len__(self) -> int: + while True: + try: + next(self) + except StopIteration: + break + return len(self.value) + + def __iter__(self) -> Iterator: + yield from self.value + while True: + try: + yield next(self) + except StopIteration: + break + + class MapXComArg(XComArg): """ An XCom reference with ``map()`` call(s) applied. @@ -429,9 +499,11 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - return _MapResult(value, self.callables) + if isinstance(value, (Sequence, dict)): + return _MapResult(value, self.callables) + if isinstance(value, Iterable): + return _LazyMapResult(value, self.callables) + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") class _ZipResult(Sequence): @@ -562,9 +634,110 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) +class _FilterResult(_MappableResult): + def __init__(self, value: Iterable, callables: FilterCallables) -> None: + super().__init__([], callables) + self._iterator = iter(value) + + def __next__(self) -> Any: + while True: + value = next(self._iterator) + if self._apply_callables(value): + self.value.append(value) + return value + + def __getitem__(self, index: int) -> Any: + if index < 0: + raise IndexError + + while len(self.value) <= index: + try: + next(self) + except StopIteration: + break + + return self.value[index] + + def __len__(self) -> int: + while True: + try: + next(self) + except StopIteration: + break + return len(self.value) + + def __iter__(self) -> Iterator: + yield from self.value + while True: + try: + yield next(self) + except StopIteration: + break + + def _apply_callables(self, value) -> bool: + for func in self.callables: + if not func(value): + return False + return True + + +class FilterXComArg(XComArg): + """ + An XCom reference with ``filter()`` call(s) applied. + + This is based on an XComArg, but also applies a series of "filters" that + filters the pulled XCom value. + + :meta private: + """ + + def __init__( + self, + arg: XComArg, + callables: FilterCallables | None, + ) -> None: + self.arg = arg + + if not callables: + callables = [self.none_filter] + else: + for c in callables: + if getattr(c, "_airflow_is_task_decorator", False): + raise ValueError("filter() argument must be a plain function, not a @task operator") + self.callables = callables + + @staticmethod + def none_filter(value) -> bool: + return value if True else False + + def __repr__(self) -> str: + map_calls = "".join(f".filter({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" + + def _serialize(self) -> dict[str, Any]: + return { + "arg": serialize_xcom_arg(self.arg), + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], + } + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + + def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: + # Filter arg.filter(f1).filter(f2) into one FilterXComArg. + return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter]) + + def resolve(self, context: Mapping[str, Any]) -> Any: + value = self.arg.resolve(context) + if not isinstance(value, (Iterable, Sequence, dict)): + raise ValueError(f"XCom filter expects sequence or dict, not {type(value).__name__}") + return _FilterResult(value, self.callables) + + _XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { "": PlainXComArg, "concat": ConcatXComArg, + "filter": FilterXComArg, "map": MapXComArg, "zip": ZipXComArg, } @@ -619,3 +792,8 @@ def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[s if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. return sum(ready_lengths) + + +@get_task_map_length.register +def _(xcom_arg: FilterXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes) diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index d313b45dae265..b86037456043f 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -36,6 +36,8 @@ def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): results = set() + values = ["a", "b", "c"] + with DAG("test") as dag: @dag.task @@ -52,7 +54,7 @@ def pull(value): assert set(dag.task_dict) == {"push", "pull"} # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) for map_index in range(3): assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS @@ -60,44 +62,77 @@ def pull(value): assert results == {"aa", "bb", "cc"} -def test_xcom_map_transform_to_none(run_ti: RunTI, mock_supervisor_comms): +def test_xcom_map_transform_to_none_and_filter_on_list(run_ti: RunTI, mock_supervisor_comms): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): results.add(value) - def c_to_none(v): - if v == "c": + def c_to_none(value): + if value == "c": return None - return v + return value - pull.expand(value=push().map(c_to_none)) + pull.expand(value=push().map(c_to_none).filter(None)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Run "pull". This should automatically convert "c" to None. for map_index in range(3): assert run_ti(dag, "pull", map_index) == TaskInstanceState.SUCCESS - assert results == {"a", "b", None} + assert results == {"a", "b"} + + +def test_xcom_map_transform_to_none_and_filter_on_dict(run_ti: RunTI, mock_supervisor_comms): + results = set() + values = {"a": "alpha", "b": "beta", "c": "charly"} + + with DAG("test") as dag: + + @dag.task() + def push(): + return values + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(value): + if "c" in value: + return None + return value + + pull.expand(value=push().map(c_to_none).filter(None)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # type: ignore + + # Run "pull". This should automatically convert "c" to None. + for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert results == {"a", "b"} def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -111,7 +146,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # The first two "pull" tis should succeed. for map_index in range(2): @@ -147,11 +182,13 @@ def c_to_none(v): def test_xcom_map_error_fails_task(mock_supervisor_comms, run_ti, captured_logs): + values = ["a", "b", "c"] + with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -165,7 +202,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # The third one (for "c") will fail. assert run_ti(dag, "pull", 2) == TaskInstanceState.FAILED @@ -194,12 +231,13 @@ def does_not_work_with_c(v): def test_xcom_map_nest(mock_supervisor_comms, run_ti): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -209,7 +247,7 @@ def pull(value): pull.expand_kwargs(converted) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Now "pull" should apply the mapping functions in order. for map_index in range(3): @@ -267,12 +305,13 @@ def xcom_get(): def test_xcom_map_raise_to_skip(run_ti, mock_supervisor_comms): result = [] + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def forward(value): @@ -286,7 +325,7 @@ def skip_c(v): forward.expand_kwargs(push().map(skip_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Run "forward". This should automatically skip "c". states = [run_ti(dag, "forward", map_index) for map_index in range(3)]