Skip to content

Commit

Permalink
Implement XComArg concat() (apache#40172)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored and romsharon98 committed Jul 26, 2024
1 parent 2e39b19 commit 81e7593
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 6 deletions.
88 changes: 86 additions & 2 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import contextlib
import inspect
import itertools
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Mapping, Sequence, Union, overload

from sqlalchemy import func, or_, select
Expand Down Expand Up @@ -159,7 +160,7 @@ def set_downstream(

def _serialize(self) -> dict[str, Any]:
"""
Serialize a DAG.
Serialize an XComArg.
The implementation should be the inverse function to ``deserialize``,
returning a data dict converted from this XComArg derivative. DAG
Expand All @@ -172,7 +173,7 @@ def _serialize(self) -> dict[str, Any]:
@classmethod
def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
"""
Deserialize a DAG.
Deserialize an XComArg.
The implementation should be the inverse function to ``serialize``,
implementing given a data dict converted from this XComArg derivative,
Expand All @@ -189,6 +190,9 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
    """Combine this reference with others into one zipped reference."""
    members = [self, *others]
    return ZipXComArg(members, fillvalue=fillvalue)

def concat(self, *others: XComArg) -> ConcatXComArg:
    """Chain this reference and *others* into one concatenated reference."""
    return ConcatXComArg([self] + list(others))

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
"""Inspect length of pushed value for task-mapping.
Expand Down Expand Up @@ -411,6 +415,11 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
raise ValueError("cannot map against non-return XCom")
return super().zip(*others, fillvalue=fillvalue)

def concat(self, *others: XComArg) -> ConcatXComArg:
    """Chain this reference with others; only valid on return-value XComs."""
    if self.key == XCOM_RETURN_KEY:
        return super().concat(*others)
    raise ValueError("cannot concatenate non-return XCom")

def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return _get_task_map_length(
dag_id=self.operator.dag_id,
Expand Down Expand Up @@ -622,8 +631,83 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
return _ZipResult(values, fillvalue=self.fillvalue)


class _ConcatResult(Sequence):
def __init__(self, values: Sequence[Sequence | dict]) -> None:
self.values = values

def __getitem__(self, index: Any) -> Any:
if index >= 0:
i = index
else:
i = len(self) + index
for value in self.values:
if i < 0:
break
elif i >= (curlen := len(value)):
i -= curlen
elif isinstance(value, Sequence):
return value[i]
else:
return next(itertools.islice(iter(value), i, None))
raise IndexError("list index out of range")

def __len__(self) -> int:
return sum(len(v) for v in self.values)


class ConcatXComArg(XComArg):
    """Reference combining several XComs end to end.

    Produced by calling ``concat()`` on an XComArg. Comparable to Python's
    :func:`itertools.chain`, except the combined result additionally
    supports index access.
    """

    def __init__(self, args: Sequence[XComArg]) -> None:
        if not args:
            raise ValueError("At least one input is required")
        self.args = args

    def __repr__(self) -> str:
        head, *tail = self.args
        joined_tail = ", ".join(repr(arg) for arg in tail)
        return f"{head!r}.concat({joined_tail})"

    def _serialize(self) -> dict[str, Any]:
        return {"args": list(map(serialize_xcom_arg, self.args))}

    @classmethod
    def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
        return cls([deserialize_xcom_arg(arg, dag) for arg in data["args"]])

    def iter_references(self) -> Iterator[tuple[Operator, str]]:
        for arg in self.args:
            yield from arg.iter_references()

    def concat(self, *others: XComArg) -> ConcatXComArg:
        # Collapse chained concat() calls into a single flat reference.
        return ConcatXComArg([*self.args, *others])

    def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
        lengths = [arg.get_task_map_length(run_id, session=session) for arg in self.args]
        if any(length is None for length in lengths):
            # If any of the referenced XComs is not ready, we are not ready either.
            return None
        return sum(lengths)

    @provide_session
    def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
        resolved = [arg.resolve(context, session=session) for arg in self.args]
        for value in resolved:
            if not isinstance(value, (Sequence, dict)):
                raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")
        return _ConcatResult(resolved)


# Registry mapping a serialized type discriminator to the XComArg subclass
# used to deserialize it; the empty string denotes a plain XCom reference.
_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
    "": PlainXComArg,
    "concat": ConcatXComArg,
    "map": MapXComArg,
    "zip": ZipXComArg,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,14 +521,14 @@ Since it is common to want to transform the output data format for task mapping,
There are a couple of things to note:

#. The callable argument of ``map()`` (``create_copy_kwargs`` in the example) **must not** be a task, but a plain Python function. The transformation is as a part of the "pre-processing" of the downstream task (i.e. ``copy_files``), not a standalone task in the DAG.
#. The callable always take exactly one positional argument. This function is called for each item in the iterable used for task-mapping, similar to how Python's built-in ``map()`` works.
#. The callable argument of :func:`map()` (``create_copy_kwargs`` in the example) **must not** be a task, but a plain Python function. The transformation is as a part of the "pre-processing" of the downstream task (i.e. ``copy_files``), not a standalone task in the DAG.
#. The callable always takes exactly one positional argument. This function is called for each item in the iterable used for task-mapping, similar to how Python's built-in :func:`map()` works.
#. Since the callable is executed as a part of the downstream task, you can use any existing techniques to write the task function. To mark a component as skipped, for example, you should raise ``AirflowSkipException``. Note that returning ``None`` **does not** work here.

Combining upstream data (aka "zipping")
=======================================

It is also common to want to combine multiple input sources into one task mapping iterable. This is generally known as "zipping" (like Python's built-in ``zip()`` function), and is also performed as pre-processing of the downstream task.
It is also common to want to combine multiple input sources into one task mapping iterable. This is generally known as "zipping" (like Python's built-in :func:`zip()` function), and is also performed as pre-processing of the downstream task.

This is especially useful for conditional logic in task mapping. For example, if you want to download files from S3, but rename those files, something like this would be possible:

Expand All @@ -552,7 +552,46 @@ This is especially useful for conditional logic in task mapping. For example, if
download_filea_from_a_rename.expand(filenames_a_b=filenames_a_b)
The ``zip`` function takes arbitrary positional arguments, and return an iterable of tuples of the positional arguments' count. By default, the zipped iterable's length is the same as the shortest of the zipped iterables, with superfluous items dropped. An optional keyword argument ``default`` can be passed to switch the behavior to match Python's ``itertools.zip_longest``—the zipped iterable will have the same length as the *longest* of the zipped iterables, with missing items filled with the value provided by ``default``.
Similar to the built-in :func:`zip`, you can zip an arbitrary number of iterables together to get an iterable of tuples of the positional arguments' count. By default, the zipped iterable's length is the same as the shortest of the zipped iterables, with superfluous items dropped. An optional keyword argument ``default`` can be passed to switch the behavior to match Python's :func:`itertools.zip_longest`—the zipped iterable will have the same length as the *longest* of the zipped iterables, with missing items filled with the value provided by ``default``.

Concatenating multiple upstreams
================================

Another common pattern to combine input sources is to run the same task against multiple iterables. It is of course totally valid to simply run the same code separately for each iterable, for example:

.. code-block:: python
list_filenames_a = S3ListOperator(
task_id="list_files_in_a",
bucket="bucket",
prefix="incoming/provider_a/{{ data_interval_start|ds }}",
)
list_filenames_b = S3ListOperator(
task_id="list_files_in_b",
bucket="bucket",
prefix="incoming/provider_b/{{ data_interval_start|ds }}",
)
@task
def download_file(filename):
S3Hook().download_file(filename)
# process file...
download_file.override(task_id="download_file_a").expand(filename=list_filenames_a.output)
download_file.override(task_id="download_file_b").expand(filename=list_filenames_b.output)
The DAG, however, would be both more scalable and easier to inspect if the tasks can be combined into one. This can be done with ``concat``:

.. code-block:: python
# Tasks list_filenames_a and list_filenames_b, and download_file stay unchanged.
list_filenames_concat = list_filenames_a.concat(list_filenames_b)
download_file.expand(filename=list_filenames_concat)
This creates one single task to expand against both lists instead. You can ``concat`` an arbitrary number of iterables together (e.g. ``foo.concat(bar, rex)``); alternatively, since the return value is also an XCom reference, the ``concat`` calls can be chained (e.g. ``foo.concat(bar).concat(rex)``) to achieve the same result: one single iterable that concatenates all of them in order, similar to Python's :func:`itertools.chain`.

What data types can be expanded?
================================
Expand Down
69 changes: 69 additions & 0 deletions tests/models/test_xcom_arg_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,72 @@ def convert_zipped(zipped):
ti.run(session=session)

assert results == {"aa", "bbbb", "cccccc", "dddddddd"}


def test_xcom_concat(dag_maker, session):
    """Exercise XComArg.concat() end to end: expansion over, and whole access to, the result."""
    from airflow.models.xcom_arg import _ConcatResult

    agg_results = set()
    all_results = None

    with dag_maker(session=session) as dag:

        @dag.task
        def push_letters():
            return ["a", "b", "c"]

        @dag.task
        def push_numbers():
            return [1, 2]

        @dag.task
        def pull_one(value):
            # Collects one element per expanded task instance.
            agg_results.add(value)

        @dag.task
        def pull_all(value):
            # The unexpanded argument arrives as the lazy concat view,
            # which must support positive and negative indexing with
            # IndexError raised outside the combined range.
            assert isinstance(value, _ConcatResult)
            assert value[0] == "a"
            assert value[1] == "b"
            assert value[2] == "c"
            assert value[3] == 1
            assert value[4] == 2
            with pytest.raises(IndexError):
                value[5]
            assert value[-5] == "a"
            assert value[-4] == "b"
            assert value[-3] == "c"
            assert value[-2] == 1
            assert value[-1] == 2
            with pytest.raises(IndexError):
                value[-6]
            nonlocal all_results
            all_results = list(value)

        pushed_values = push_letters().concat(push_numbers())

        pull_one.expand(value=pushed_values)
        pull_all(pushed_values)

    dr = dag_maker.create_dagrun(session=session)

    # Run "push_letters" and "push_numbers".
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert len(decision.schedulable_tis) == 2
    assert all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
    for ti in decision.schedulable_tis:
        ti.run(session=session)
    session.commit()

    # Run "pull_one" (expanded to 5 instances, one per concatenated
    # element) and "pull_all" (one instance over the whole view).
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert len(decision.schedulable_tis) == 6
    assert all(ti.task_id.startswith("pull_") for ti in decision.schedulable_tis)
    for ti in decision.schedulable_tis:
        ti.run(session=session)

    assert agg_results == {"a", "b", "c", 1, 2}
    assert all_results == ["a", "b", "c", 1, 2]

    # Nothing should be left to schedule after both pull tasks ran.
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert not decision.schedulable_tis

0 comments on commit 81e7593

Please sign in to comment.