Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
a0cc92f
refactor: Implemented filter operation on XCom
davidblain-infrabel Apr 7, 2025
5c56f4e
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
031c28e
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
c858c9c
refactor: Fixed some static checks
davidblain-infrabel Apr 7, 2025
cc5e7be
refactor: Fixed some mypy issues
davidblain-infrabel Apr 7, 2025
48310f3
refactor: Added filter to PlainXComArg
davidblain-infrabel Apr 7, 2025
567c84b
refactor: Reverted SchedulerZipXComArg back to orginal
davidblain-infrabel Apr 7, 2025
22ede6a
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
49f2479
refactor: Fixed signature of filter method in PlainXComArg
davidblain-infrabel Apr 7, 2025
cfeb4a1
refactor: Fixed method signature filter
davidblain-infrabel Apr 7, 2025
4c4ee72
refactor: Fixed callables type in _FilterResult
davidblain-infrabel Apr 7, 2025
e5358aa
refactor: raise TypeError if getitem is called on iterable
davidblain-infrabel Apr 7, 2025
25640d8
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
deb60a7
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
b54b4ae
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
0b60a0f
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
faccd4b
refactor: Refactored _MapResult to support iterables
davidblain-infrabel Apr 7, 2025
bd204e0
refactor: Register task_map_length on SchedulerFilterXComArg
davidblain-infrabel Apr 7, 2025
9e9b50a
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 7, 2025
bd88a31
refactor: Fixed signature of __getitem__ in _MapResult
davidblain-infrabel Apr 8, 2025
b171639
refactor: Print filter callable result
davidblain-infrabel Apr 8, 2025
fd8a9f8
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 8, 2025
ce8959e
refactor: Changed calculation of length in FilterResult
davidblain-infrabel Apr 8, 2025
500df61
refactor: Fixed __getitem__ method of FilterResult
davidblain-infrabel Apr 8, 2025
d4adc17
refactor: No need to store result of apply_callables in __getitem__ m…
davidblain-infrabel Apr 8, 2025
800c087
refactor: Cache result of filtered values in FilterResult
davidblain-infrabel Apr 8, 2025
d57c288
refactor: Refactored the _FilterResult with cache and lazy evaluation…
davidblain-infrabel Apr 8, 2025
d33dd08
refactor: Harmonize the callables typing for filter and map methods
davidblain-infrabel Apr 8, 2025
ad3c128
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 8, 2025
4c8bcad
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 8, 2025
0ae23e0
refactor: Fixed some typings for Map and FilterXComArgs
davidblain-infrabel Apr 8, 2025
28aeb78
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 8, 2025
ab1e9b6
refactor: Fixed docstrings
davidblain-infrabel Apr 8, 2025
5d92e70
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 8, 2025
c26068a
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 10, 2025
34d9176
refactor: Introduced CallableResultMixin and splitted _MapResult into…
davidblain-infrabel Apr 15, 2025
e708941
refactor: Refactored CallableResultMixin and made it abstract class
davidblain-infrabel Apr 15, 2025
3a5a199
refactor: Added test case when value is dict instead of list
davidblain-infrabel Apr 15, 2025
f83d2ac
refactor: Fixed conversion of dict to list in CallableResultMixin
davidblain-infrabel Apr 16, 2025
f828b7f
refactor: Added mapping as bounded type
davidblain-infrabel Apr 16, 2025
7fac452
refactor: Ignore type check on XComResult
davidblain-infrabel Apr 16, 2025
04fb807
refactor: Changed elif to if in resolved method of MapXComArg
davidblain-infrabel Apr 16, 2025
3d23e11
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 16, 2025
ec273bf
refactor: Explicitly convert value to Sequence in CallableResultMixin
davidblain-infrabel Apr 16, 2025
800e8ea
refactor: Simplified _LazyMapResult and _FilterResult
davidblain-infrabel Apr 16, 2025
eb28ffe
refactor: Re-used comon values variable where possible in test xcom args
davidblain-infrabel Apr 16, 2025
0c415db
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 16, 2025
844b785
refactor: Changed types of results classes
davidblain-infrabel Apr 16, 2025
4e2bda9
refactor: Raise an ValueError if value isn't list, set or dict
davidblain-infrabel Apr 16, 2025
7bd7b93
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 16, 2025
68a8329
refactor: Sets musts be converted to lists also otherwise it's not in…
davidblain-infrabel Apr 16, 2025
d0835a6
refactor: Try except the StopIteration when yielding instead of suppress
davidblain-infrabel Apr 16, 2025
91adef4
refactor: Fixed __getitem__ magic method of _FilterResult
davidblain-infrabel Apr 16, 2025
40191f8
refactor: Renamed CallableResultMixin to _MappableResult and refactor…
davidblain-infrabel Apr 17, 2025
dec8eb8
refactor: Check if result in PlainXComArg needs runtime resolution
davidblain-infrabel Apr 17, 2025
e586ed0
refactor: Refactored _LazyMapResult and _FilterResult with __next__ m…
davidblain-infrabel Apr 17, 2025
3c37423
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 17, 2025
7a9c487
refactor: Changed non_filter method of FilterXComArg to staticmethod
davidblain-infrabel Apr 17, 2025
cedd8ec
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Apr 28, 2025
2e93ec4
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla May 19, 2025
ff703ac
Merge branch 'main' into feature/added-filter-operation-to-xcom
dabla Jun 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions airflow-core/src/airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self:
)


@attrs.define
class SchedulerFilterXComArg(SchedulerXComArg):
arg: SchedulerXComArg
callables: Sequence[str]

@classmethod
def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self:
# We are deliberately NOT deserializing the callables. These are shown
# in the UI, and displaying a function object is useless.
return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])


@singledispatch
def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Session) -> int | None:
# The base implementation -- specific XComArg subclasses have specialised implementations
Expand Down Expand Up @@ -178,6 +190,11 @@ def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session):
return sum(ready_lengths)


@get_task_map_length.register
def _(xcom_arg: SchedulerFilterXComArg, run_id: str, *, session: Session):
return get_task_map_length(xcom_arg.arg, run_id, session=session)


def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG):
"""DAG serialization interface."""
klass = _XCOM_ARG_TYPES[data.get("type", "")]
Expand All @@ -187,6 +204,7 @@ def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG):
_XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = {
"": SchedulerPlainXComArg,
"concat": SchedulerConcatXComArg,
"filter": SchedulerFilterXComArg,
"map": SchedulerMapXComArg,
"zip": SchedulerZipXComArg,
}
202 changes: 190 additions & 12 deletions task-sdk/src/airflow/sdk/definitions/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@

from __future__ import annotations

import contextlib
import inspect
import itertools
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized
from contextlib import suppress
from functools import singledispatch
from typing import TYPE_CHECKING, Any, Callable, overload
from typing import TYPE_CHECKING, Any, Callable, overload, _T_co

from airflow.exceptions import AirflowException, XComNotFound
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
Expand All @@ -42,6 +43,7 @@
# the user, but deserialize them into strings in a serialized XComArg for
# safety (those callables are arbitrary user code).
MapCallables = Sequence[Callable[[Any], Any]]
FilterCallables = Sequence[Callable[[Any], bool]]


class XComArg(ResolveMixin, DependencyMixin):
Expand Down Expand Up @@ -175,6 +177,9 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
def concat(self, *others: XComArg) -> ConcatXComArg:
return ConcatXComArg([self, *others])

def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg:
return FilterXComArg(self, [f] if f else [])

def resolve(self, context: Mapping[str, Any]) -> Any:
raise NotImplementedError()

Expand Down Expand Up @@ -332,6 +337,11 @@ def concat(self, *others: XComArg) -> ConcatXComArg:
raise ValueError("cannot concatenate non-return XCom")
return super().concat(*others)

def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg:
if self.key != XCOM_RETURN_KEY:
raise ValueError("cannot filter non-return XCom")
return super().filter(f)

def resolve(self, context: Mapping[str, Any]) -> Any:
ti = context["ti"]
task_id = self.operator.task_id
Expand All @@ -351,6 +361,8 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
map_indexes=map_indexes,
)
if not isinstance(result, ArgNotSet):
if isinstance(result, ResolveMixin):
result = result.resolve(context)
return result
if self.key == XCOM_RETURN_KEY:
return None
Expand All @@ -370,29 +382,87 @@ def _get_callable_name(f: Callable | str) -> str:
return f.__name__
# Parse the source to find whatever is behind "def". For safety, we don't
# want to evaluate the code in any meaningful way!
with contextlib.suppress(Exception):
with suppress(Exception):
kw, name, _ = f.lstrip().split(None, 2)
if kw == "def":
return name
return "<function>"


class _MapResult(Sequence):
def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
self.value = value
class _MappableResult(Sequence):
def __init__(self, value: Sequence | dict, callables: FilterCallables | MapCallables) -> None:
self.value = self._convert(value)
self.callables = callables

def __getitem__(self, index: Any) -> Any:
value = self.value[index]
raise NotImplementedError

def __len__(self) -> int:
raise NotImplementedError

for f in self.callables:
value = f(value)
@staticmethod
def _convert(value: Sequence | dict) -> list:
if isinstance(value, (dict, set)):
return list(value)
if isinstance(value, list):
return value
raise ValueError(
f"XCom filter expects sequence or dict, not {type(value).__name__}"
)

def _apply_callables(self, value) -> Any:
for func in self.callables:
value = func(value)
return value


class _MapResult(_MappableResult):
def __getitem__(self, index: Any) -> Any:
value = self._apply_callables(self.value[index])
return value

def __len__(self) -> int:
return len(self.value)


class _LazyMapResult(_MappableResult):
def __init__(self, value: Iterable, callables: MapCallables) -> None:
super().__init__([], callables)
self._iterator = iter(value)

def __next__(self) -> Any:
value = self._apply_callables(next(self._iterator))
self.value.append(value)
return value

def __getitem__(self, index: Any) -> Any:
if index < 0:
raise IndexError

while len(self.value) <= index:
try:
next(self)
except StopIteration:
raise IndexError
return self.value[index]

def __len__(self) -> int:
while True:
try:
next(self)
except StopIteration:
break
return len(self.value)

def __iter__(self) -> Iterator:
yield from self.value
while True:
try:
yield next(self)
except StopIteration:
break


class MapXComArg(XComArg):
"""
An XCom reference with ``map()`` call(s) applied.
Expand Down Expand Up @@ -429,9 +499,11 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg:

def resolve(self, context: Mapping[str, Any]) -> Any:
value = self.arg.resolve(context)
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
return _MapResult(value, self.callables)
if isinstance(value, (Sequence, dict)):
return _MapResult(value, self.callables)
if isinstance(value, Iterable):
return _LazyMapResult(value, self.callables)
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")


class _ZipResult(Sequence):
Expand Down Expand Up @@ -562,9 +634,110 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
return _ConcatResult(values)


class _FilterResult(_MappableResult):
def __init__(self, value: Iterable, callables: FilterCallables) -> None:
super().__init__([], callables)
self._iterator = iter(value)

def __next__(self) -> Any:
while True:
value = next(self._iterator)
if self._apply_callables(value):
self.value.append(value)
return value

def __getitem__(self, index: int) -> Any:
if index < 0:
raise IndexError

while len(self.value) <= index:
try:
next(self)
except StopIteration:
break

return self.value[index]

def __len__(self) -> int:
while True:
try:
next(self)
except StopIteration:
break
return len(self.value)

def __iter__(self) -> Iterator:
yield from self.value
while True:
try:
yield next(self)
except StopIteration:
break

def _apply_callables(self, value) -> bool:
for func in self.callables:
if not func(value):
return False
return True


class FilterXComArg(XComArg):
"""
An XCom reference with ``filter()`` call(s) applied.

This is based on an XComArg, but also applies a series of "filters" that
filters the pulled XCom value.

:meta private:
"""

def __init__(
self,
arg: XComArg,
callables: FilterCallables | None,
) -> None:
self.arg = arg

if not callables:
callables = [self.none_filter]
else:
for c in callables:
if getattr(c, "_airflow_is_task_decorator", False):
raise ValueError("filter() argument must be a plain function, not a @task operator")
self.callables = callables

@staticmethod
def none_filter(value) -> bool:
return value if True else False

def __repr__(self) -> str:
map_calls = "".join(f".filter({_get_callable_name(f)})" for f in self.callables)
return f"{self.arg!r}{map_calls}"

def _serialize(self) -> dict[str, Any]:
return {
"arg": serialize_xcom_arg(self.arg),
"callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
}

def iter_references(self) -> Iterator[tuple[Operator, str]]:
yield from self.arg.iter_references()

def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg:
# Filter arg.filter(f1).filter(f2) into one FilterXComArg.
return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter])

def resolve(self, context: Mapping[str, Any]) -> Any:
value = self.arg.resolve(context)
if not isinstance(value, (Iterable, Sequence, dict)):
raise ValueError(f"XCom filter expects sequence or dict, not {type(value).__name__}")
return _FilterResult(value, self.callables)


_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
"": PlainXComArg,
"concat": ConcatXComArg,
"filter": FilterXComArg,
"map": MapXComArg,
"zip": ZipXComArg,
}
Expand Down Expand Up @@ -619,3 +792,8 @@ def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[s
if len(ready_lengths) != len(xcom_arg.args):
return None # If any of the referenced XComs is not ready, we are not ready either.
return sum(ready_lengths)


@get_task_map_length.register
def _(xcom_arg: FilterXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]):
return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes)
Loading
Loading