diff --git a/airflow/example_dags/example_task_group.py b/airflow/example_dags/example_task_group.py new file mode 100644 index 0000000000000..17134df58f981 --- /dev/null +++ b/airflow/example_dags/example_task_group.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example DAG demonstrating the usage of the TaskGroup.""" + +from airflow.models.dag import DAG +from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.dates import days_ago +from airflow.utils.task_group import TaskGroup + +# [START howto_task_group] +with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag: + start = DummyOperator(task_id="start") + + # [START howto_task_group_section_1] + with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1: + task_1 = DummyOperator(task_id="task_1") + task_2 = DummyOperator(task_id="task_2") + task_3 = DummyOperator(task_id="task_3") + + task_1 >> [task_2, task_3] + # [END howto_task_group_section_1] + + # [START howto_task_group_section_2] + with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2: + task_1 = DummyOperator(task_id="task_1") + + # [START howto_task_group_inner_section_2] + with TaskGroup("inner_section_2", tooltip="Tasks for inner_section_2") as inner_section_2: + task_2 = DummyOperator(task_id="task_2") + task_3 = DummyOperator(task_id="task_3") + task_4 = DummyOperator(task_id="task_4") + + [task_2, task_3] >> task_4 + # [END howto_task_group_inner_section_2] + + # [END howto_task_group_section_2] + + end = DummyOperator(task_id="end") + + start >> section_1 >> section_2 >> end +# [END howto_task_group] diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 4058f055b54e7..6d48a27969a96 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -27,7 +27,8 @@ from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta from typing import ( - Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union, + TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, + Type, Union, ) import attr @@ -58,6 +59,9 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule +if TYPE_CHECKING: + from airflow.utils.task_group import TaskGroup # pylint: disable=cyclic-import + ScheduleInterval = Union[str, timedelta, relativedelta] TaskStateChangeCallback = Callable[[Context], None] @@ -360,9 +364,12 @@ def __init__( do_xcom_push: bool = True, inlets: Optional[Any] = None, outlets: Optional[Any] = None, + task_group: Optional["TaskGroup"] = None, **kwargs ): from airflow.models.dag import DagContext + from airflow.utils.task_group import TaskGroupContext +
super().__init__() if kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): @@ -382,6 +389,11 @@ def __init__( ) validate_key(task_id) self.task_id = task_id + self.label = task_id + task_group = task_group or TaskGroupContext.get_current_task_group(dag) + if task_group: + self.task_id = task_group.child_id(task_id) + task_group.add(self) self.owner = owner self.email = email self.email_on_retry = email_on_retry @@ -609,7 +621,7 @@ def dag(self, dag: Any): elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self: dag.add_task(self) - self._dag = dag # pylint: disable=attribute-defined-outside-init + self._dag = dag def has_dag(self): """ @@ -1120,21 +1132,25 @@ def roots(self) -> List["BaseOperator"]: """Required by TaskMixin""" return [self] + @property + def leaves(self) -> List["BaseOperator"]: + """Required by TaskMixin""" + return [self] + def _set_relatives( self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], upstream: bool = False, ) -> None: """Sets relatives for the task or task list.""" - - if isinstance(task_or_task_list, Sequence): - task_like_object_list = task_or_task_list - else: - task_like_object_list = [task_or_task_list] + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] task_list: List["BaseOperator"] = [] - for task_object in task_like_object_list: - task_list.extend(task_object.roots) + for task_object in task_or_task_list: + task_object.update_relative(self, not upstream) + relatives = task_object.leaves if upstream else task_object.roots + task_list.extend(relatives) for task in task_list: if not isinstance(task, BaseOperator): diff --git a/airflow/models/dag.py b/airflow/models/dag.py index eecf6b41df140..886837c429a76 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -27,7 +27,9 @@ import warnings from collections import OrderedDict from datetime import datetime, timedelta -from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast +from typing import ( + TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast, +) import jinja2 import pendulum @@ -59,6 +61,9 @@ from airflow.utils.state import State from airflow.utils.types import DagRunType +if TYPE_CHECKING: + from airflow.utils.task_group import TaskGroup + log = logging.getLogger(__name__) ScheduleInterval = Union[str, timedelta, relativedelta] @@ -238,6 +243,8 @@ def __init__( jinja_environment_kwargs: Optional[Dict] = None, tags: Optional[List[str]] = None ): + from airflow.utils.task_group import TaskGroup + self.user_defined_macros = user_defined_macros self.user_defined_filters = user_defined_filters self.default_args = copy.deepcopy(default_args or {}) @@ -329,6 +336,7 @@ def __init__( self.jinja_environment_kwargs = jinja_environment_kwargs self.tags = tags + self._task_group = TaskGroup.create_root(self) def __repr__(self): return "".format(self=self) @@ -570,6 +578,10 @@ def tasks(self, val): def task_ids(self) -> List[str]: return list(self.task_dict.keys()) + @property + def task_group(self) -> "TaskGroup": + return self._task_group + @property def filepath(self) -> str: """ @@ -1240,7 +1252,6 @@ def sub_dag(self, task_regex, include_downstream=False, based on a regex that should match one or many tasks, and includes upstream and downstream neighbours based on the flag passed. 
""" - # deep-copying self.task_dict takes a long time, and we don't want all # the tasks anyway, so we copy the tasks manually later task_dict = self.task_dict @@ -1261,9 +1272,38 @@ def sub_dag(self, task_regex, include_downstream=False, # Make sure to not recursively deepcopy the dag while copying the task dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) for t in regex_match + also_include} + + # Remove tasks not included in the subdag from task_group + def remove_excluded(group): + for child in list(group.children.values()): + if isinstance(child, BaseOperator): + if child.task_id not in dag.task_dict: + group.children.pop(child.task_id) + else: + # The tasks in the subdag are a copy of tasks in the original dag + # so update the reference in the TaskGroups too. + group.children[child.task_id] = dag.task_dict[child.task_id] + else: + remove_excluded(child) + + # Remove this TaskGroup if it doesn't contain any tasks in this subdag + if not child.children: + group.children.pop(child.group_id) + + remove_excluded(dag.task_group) + + # Removing upstream/downstream references to tasks and TaskGroups that did not make + # the cut. + subdag_task_groups = dag.task_group.get_task_group_dict() + for group in subdag_task_groups.values(): + group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys()) + group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys()) + group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys()) + group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys()) + for t in dag.tasks: # Removing upstream/downstream references to tasks that did not - # made the cut + # make the cut t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys()) t._downstream_task_ids = t.downstream_task_ids.intersection( dag.task_dict.keys()) @@ -1357,12 +1397,15 @@ def add_task(self, task): elif task.end_date and self.end_date: task.end_date = min(task.end_date, self.end_date) - if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task: + if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task) + or task.task_id in self._task_group.used_group_ids): raise DuplicateTaskIdFound( "Task id '{}' has already been added to the DAG".format(task.task_id)) else: self.task_dict[task.task_id] = task task.dag = self + # Add task_id to used_group_ids to prevent group_id and task_id collisions. + self._task_group.used_group_ids.add(task.task_id) self.task_count = len(self.task_dict) diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index a3d42242ca4f7..cfdc714f824ce 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -33,6 +33,11 @@ def roots(self): """Should return list of root operator List[BaseOperator]""" raise NotImplementedError() + @property + def leaves(self): + """Should return list of leaf operator List[BaseOperator]""" + raise NotImplementedError() + @abstractmethod def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ @@ -47,6 +52,12 @@ def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ raise NotImplementedError() + def update_relative(self, other: "TaskMixin", upstream=True) -> None: + """ + Update relationship information about another TaskMixin. Default is no-op. + Override if necessary. 
+ """ + def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]): """ Implements Task << Task diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 0f647bf4f86f2..b9faaabac5aad 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -102,6 +102,11 @@ def roots(self) -> List[BaseOperator]: """Required by TaskMixin""" return [self._operator] + @property + def leaves(self) -> List[BaseOperator]: + """Required by TaskMixin""" + return [self._operator] + @property def key(self) -> str: """Returns keys of this XComArg""" diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index b6fae5e50d45d..9a30231b046a9 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum): SET = 'set' TUPLE = 'tuple' POD = 'k8s.V1Pod' + TASK_GROUP = 'taskgroup' diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json index 49de949677e8d..9056eaab97d58 100644 --- a/airflow/serialization/schema.json +++ b/airflow/serialization/schema.json @@ -96,7 +96,11 @@ "_default_view": { "type" : "string"}, "_access_control": {"$ref": "#/definitions/dict" }, "is_paused_upon_creation": { "type": "boolean" }, - "tags": { "type": "array" } + "tags": { "type": "array" }, + "_task_group": {"anyOf": [ + { "type": "null" }, + { "$ref": "#/definitions/task_group" } + ]} }, "required": [ "_dag_id", @@ -125,6 +129,7 @@ "_task_module": { "type": "string" }, "_operator_extra_links": { "$ref": "#/definitions/extra_links" }, "task_id": { "type": "string" }, + "label": { "type": "string" }, "owner": { "type": "string" }, "start_date": { "$ref": "#/definitions/datetime" }, "end_date": { "$ref": "#/definitions/datetime" }, @@ -156,6 +161,47 @@ } }, "additionalProperties": true + }, + "task_group": { + "$comment": "A TaskGroup containing tasks", + "type": "object", + "required": [ + "_group_id", + "prefix_group_id", + "children", + "tooltip", + "ui_color", + "ui_fgcolor", + "upstream_group_ids", + "downstream_group_ids", + "upstream_task_ids", + "downstream_task_ids" + ], + "properties": { + "_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]}, + "prefix_group_id": { "type": "boolean" }, + "children": { "$ref": "#/definitions/dict" }, + "tooltip": { "type": "string" }, + "ui_color": { "type": "string" }, + "ui_fgcolor": { "type": "string" }, + "upstream_group_ids": { + "type": "array", + "items": { "type": "string" } + }, + "downstream_group_ids": { + "type": "array", + "items": { "type": "string" } + }, + "upstream_task_ids": { + "type": "array", + "items": { "type": "string" } + }, + "downstream_task_ids": { + "type": "array", + "items": { "type": "string" } + } + }, + "additionalProperties": false } }, diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 6129a9d52e4b1..41c6bc7098ebb 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -44,6 +44,7 @@ from airflow.settings import json from airflow.utils.code_utils import get_python_source from airflow.utils.module_loading import import_string +from airflow.utils.task_group import TaskGroup log = logging.getLogger(__name__) FAILED = 'serialization_failed' @@ -221,6 +222,8 @@ def _serialize(cls, var: Any) -> Any: # Unfortunately there is no support for r # FIXME: casts tuple to list in customized serialization in future. 
return cls._encode( [cls._serialize(v) for v in var], type_=DAT.TUPLE) + elif isinstance(var, TaskGroup): + return SerializedTaskGroup.serialize_task_group(var) else: log.debug('Cast type %s to str in serialization.', type(var)) return str(var) @@ -376,6 +379,10 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator: # Extra Operator Links defined in Plugins op_extra_links_from_plugin = {} + if "label" not in encoded_op: + # Handle deserialization of old data before the introduction of TaskGroup + encoded_op["label"] = encoded_op["task_id"] + for ope in plugins_manager.operator_extra_links: for operator in ope.operators: if operator.__name__ == encoded_op["_task_type"] and \ @@ -570,6 +577,7 @@ def serialize_dag(cls, dag: DAG) -> dict: serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields) serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()] + serialize_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group) return serialize_dag @classmethod @@ -598,6 +606,22 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG': setattr(dag, k, v) + # Set _task_group + # pylint: disable=protected-access + if "_task_group" in encoded_dag: + dag._task_group = SerializedTaskGroup.deserialize_task_group( # type: ignore + encoded_dag["_task_group"], + None, + dag.task_dict + ) + else: + # This must be old data that had no task_group. Create a root TaskGroup and add + # all tasks to it. + dag._task_group = TaskGroup.create_root(dag) + for task in dag.tasks: + dag.task_group.add(task) + # pylint: enable=protected-access + keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys() for k in keys_to_set_none: setattr(dag, k, None) @@ -641,3 +665,71 @@ def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG': if ver != cls.SERIALIZER_VERSION: raise ValueError("Unsure how to deserialize version {!r}".format(ver)) return cls.deserialize_dag(serialized_obj['dag']) + + +class SerializedTaskGroup(TaskGroup, BaseSerialization): + """ + A JSON serializable representation of TaskGroup. + """ + @classmethod + def serialize_task_group(cls, task_group: TaskGroup) -> Optional[Dict[str, Any]]: + """ + Serializes TaskGroup into a JSON object. + """ + if not task_group: + return None + + serialize_group = { + "_group_id": task_group._group_id, # pylint: disable=protected-access + "prefix_group_id": task_group.prefix_group_id, + "tooltip": task_group.tooltip, + "ui_color": task_group.ui_color, + "ui_fgcolor": task_group.ui_fgcolor, + "children": { + label: (DAT.OP, child.task_id) + if isinstance(child, BaseOperator) else + (DAT.TASK_GROUP, SerializedTaskGroup.serialize_task_group(child)) + for label, child in task_group.children.items() + }, + "upstream_group_ids": cls._serialize(list(task_group.upstream_group_ids)), + "downstream_group_ids": cls._serialize(list(task_group.downstream_group_ids)), + "upstream_task_ids": cls._serialize(list(task_group.upstream_task_ids)), + "downstream_task_ids": cls._serialize(list(task_group.downstream_task_ids)), + } + + return serialize_group + + @classmethod + def deserialize_task_group( + cls, + encoded_group: Dict[str, Any], + parent_group: Optional[TaskGroup], + task_dict: Dict[str, BaseOperator] + ) -> Optional[TaskGroup]: + """ + Deserializes a TaskGroup from a JSON object.
+ """ + if not encoded_group: + return None + + group_id = cls._deserialize(encoded_group["_group_id"]) + kwargs = { + key: cls._deserialize(encoded_group[key]) + for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"] + } + group = SerializedTaskGroup( + group_id=group_id, + parent_group=parent_group, + **kwargs + ) + group.children = { + label: task_dict[val] if _type == DAT.OP # type: ignore + else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val) + in encoded_group["children"].items() + } + group.upstream_group_ids = set(cls._deserialize(encoded_group["upstream_group_ids"])) + group.downstream_group_ids = set(cls._deserialize(encoded_group["downstream_group_ids"])) + group.upstream_task_ids = set(cls._deserialize(encoded_group["upstream_task_ids"])) + group.downstream_task_ids = set(cls._deserialize(encoded_group["downstream_task_ids"])) + return group diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py new file mode 100644 index 0000000000000..84cc540147232 --- /dev/null +++ b/airflow/utils/task_group.py @@ -0,0 +1,379 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped +together when the DAG is displayed graphically. +""" + +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union + +from airflow.exceptions import AirflowException, DuplicateTaskIdFound +from airflow.models.taskmixin import TaskMixin + +if TYPE_CHECKING: + from airflow.models.baseoperator import BaseOperator + from airflow.models.dag import DAG + + +class TaskGroup(TaskMixin): + """ + A collection of tasks. When set_downstream() or set_upstream() are called on the + TaskGroup, it is applied across all tasks within the group if necessary. + + :param group_id: a unique, meaningful id for the TaskGroup. group_id must not conflict + with group_id of TaskGroup or task_id of tasks in the DAG. Root TaskGroup has group_id + set to None. + :type group_id: str + :param prefix_group_id: If set to True, child task_id and group_id will be prefixed with + this TaskGroup's group_id. If set to False, child task_id and group_id are not prefixed. + Default is True. + :type prerfix_group_id: bool + :param parent_group: The parent TaskGroup of this TaskGroup. parent_group is set to None + for the root TaskGroup. + :type parent_group: TaskGroup + :param dag: The DAG that this TaskGroup belongs to. 
+ :type dag: airflow.models.DAG + :param tooltip: The tooltip of the TaskGroup node when displayed in the UI + :type tooltip: str + :param ui_color: The fill color of the TaskGroup node when displayed in the UI + :type ui_color: str + :param ui_fgcolor: The label color of the TaskGroup node when displayed in the UI + :type ui_fgcolor: str + """ + + def __init__( + self, + group_id: Optional[str], + prefix_group_id: bool = True, + parent_group: Optional["TaskGroup"] = None, + dag: Optional["DAG"] = None, + tooltip: str = "", + ui_color: str = "CornflowerBlue", + ui_fgcolor: str = "#000", + ): + from airflow.models.dag import DagContext + + self.prefix_group_id = prefix_group_id + + if group_id is None: + # This creates a root TaskGroup. + if parent_group: + raise AirflowException("Root TaskGroup cannot have parent_group") + # used_group_ids is shared across all TaskGroups in the same DAG to keep track + # of used group_ids to avoid duplication. + self.used_group_ids: Set[Optional[str]] = set() + self._parent_group = None + else: + if not isinstance(group_id, str): + raise ValueError("group_id must be str") + if not group_id: + raise ValueError("group_id must not be empty") + + dag = dag or DagContext.get_current_dag() + + if not parent_group and not dag: + raise AirflowException("TaskGroup can only be used inside a DAG") + + self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) + if not self._parent_group: + raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup") + self.used_group_ids = self._parent_group.used_group_ids + + self._group_id = group_id + if self.group_id in self.used_group_ids: + raise DuplicateTaskIdFound(f"group_id '{self.group_id}' has already been added to the DAG") + self.used_group_ids.add(self.group_id) + self.used_group_ids.add(self.downstream_join_id) + self.used_group_ids.add(self.upstream_join_id) + self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {} + if self._parent_group: + self._parent_group.add(self) + + self.tooltip = tooltip + self.ui_color = ui_color + self.ui_fgcolor = ui_fgcolor + + # Keep track of TaskGroups or tasks that depend on this entire TaskGroup separately + # so that we can optimize the number of edges when entire TaskGroups depend on each other. + self.upstream_group_ids: Set[Optional[str]] = set() + self.downstream_group_ids: Set[Optional[str]] = set() + self.upstream_task_ids: Set[Optional[str]] = set() + self.downstream_task_ids: Set[Optional[str]] = set() + + @classmethod + def create_root(cls, dag: "DAG") -> "TaskGroup": + """ + Create a root TaskGroup with no group_id or parent. + """ + return cls(group_id=None, dag=dag) + + @property + def is_root(self) -> bool: + """ + Returns True if this TaskGroup is the root TaskGroup. Otherwise False. + """ + return not self.group_id + + def __iter__(self): + for child in self.children.values(): + if isinstance(child, TaskGroup): + for inner_task in child: + yield inner_task + else: + yield child + + def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None: + """ + Add a task to this TaskGroup.
+ """ + key = task.group_id if isinstance(task, TaskGroup) else task.task_id + + if key in self.children: + raise DuplicateTaskIdFound(f"Task id '{key}' has already been added to the DAG") + + if isinstance(task, TaskGroup): + if task.children: + raise AirflowException("Cannot add a non-empty TaskGroup") + + self.children[key] = task # type: ignore + + @property + def group_id(self) -> Optional[str]: + """ + group_id of this TaskGroup. + """ + if self._parent_group and self._parent_group.prefix_group_id and self._parent_group.group_id: + return self._parent_group.child_id(self._group_id) + + return self._group_id + + @property + def label(self) -> Optional[str]: + """ + group_id excluding parent's group_id used as the node label in UI. + """ + return self._group_id + + def update_relative(self, other: "TaskMixin", upstream=True) -> None: + """ + Overrides TaskMixin.update_relative. + + Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids + accordingly so that we can reduce the number of edges when displaying Graph View. + """ + from airflow.models.baseoperator import BaseOperator + + if isinstance(other, TaskGroup): + # Handles setting relationship between a TaskGroup and another TaskGroup + if upstream: + parent, child = (self, other) + else: + parent, child = (other, self) + + parent.upstream_group_ids.add(child.group_id) + child.downstream_group_ids.add(parent.group_id) + else: + # Handles setting relationship between a TaskGroup and a task + for task in other.roots: + if not isinstance(task, BaseOperator): + raise AirflowException("Relationships can only be set between TaskGroup " + f"or operators; received {task.__class__.__name__}") + + if upstream: + self.upstream_task_ids.add(task.task_id) + else: + self.downstream_task_ids.add(task.task_id) + + def _set_relative( + self, + task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]], + upstream: bool = False + ) -> None: + """ + Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. + Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids. + """ + if upstream: + for task in self.get_roots(): + task.set_upstream(task_or_task_list) + else: + for task in self.get_leaves(): + task.set_downstream(task_or_task_list) + + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + for task_like in task_or_task_list: + self.update_relative(task_like, upstream) + + def set_downstream( + self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]] + ) -> None: + """ + Set a TaskGroup/task/list of task downstream of this TaskGroup. + """ + self._set_relative(task_or_task_list, upstream=False) + + def set_upstream( + self, task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]] + ) -> None: + """ + Set a TaskGroup/task/list of task upstream of this TaskGroup. + """ + self._set_relative(task_or_task_list, upstream=True) + + def __enter__(self): + TaskGroupContext.push_context_managed_task_group(self) + return self + + def __exit__(self, _type, _value, _tb): + TaskGroupContext.pop_context_managed_task_group() + + def has_task(self, task: "BaseOperator") -> bool: + """ + Returns True if this TaskGroup or its children TaskGroups contains the given task. 
+ """ + if task.task_id in self.children: + return True + + return any(child.has_task(task) for child in self.children.values() if isinstance(child, TaskGroup)) + + @property + def roots(self) -> List["BaseOperator"]: + """Required by TaskMixin""" + return list(self.get_roots()) + + @property + def leaves(self) -> List["BaseOperator"]: + """Required by TaskMixin""" + return list(self.get_leaves()) + + def get_roots(self) -> Generator["BaseOperator", None, None]: + """ + Returns a generator of tasks that are root tasks, i.e. those with no upstream + dependencies within the TaskGroup. + """ + for task in self: + if not any(self.has_task(parent) for parent in task.get_direct_relatives(upstream=True)): + yield task + + def get_leaves(self) -> Generator["BaseOperator", None, None]: + """ + Returns a generator of tasks that are leaf tasks, i.e. those with no downstream + dependencies within the TaskGroup + """ + for task in self: + if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)): + yield task + + def child_id(self, label): + """ + Prefix label with group_id if prefix_group_id is True. Otherwise return the label + as-is. + """ + if self.prefix_group_id and self.group_id: + return f"{self.group_id}.{label}" + + return label + + @property + def upstream_join_id(self) -> str: + """ + If this TaskGroup has immediate upstream TaskGroups or tasks, a dummy node called + upstream_join_id will be created in Graph View to join the outgoing edges from this + TaskGroup to reduce the total number of edges needed to be displayed. + """ + return f"{self.group_id}.upstream_join_id" + + @property + def downstream_join_id(self) -> str: + """ + If this TaskGroup has immediate downstream TaskGroups or tasks, a dummy node called + downstream_join_id will be created in Graph View to join the outgoing edges from this + TaskGroup to reduce the total number of edges needed to be displayed. + """ + return f"{self.group_id}.downstream_join_id" + + def get_task_group_dict(self) -> Dict[str, "TaskGroup"]: + """ + Returns a flat dictionary of group_id: TaskGroup + """ + task_group_map = {} + + def build_map(task_group): + if not isinstance(task_group, TaskGroup): + return + + task_group_map[task_group.group_id] = task_group + + for child in task_group.children.values(): + build_map(child) + + build_map(self) + return task_group_map + + def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]: + """ + Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix) + """ + return self.children[self.child_id(label)] + + +class TaskGroupContext: + """ + TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager. + """ + + _context_managed_task_group: Optional[TaskGroup] = None + _previous_context_managed_task_groups: List[TaskGroup] = [] + + @classmethod + def push_context_managed_task_group(cls, task_group: TaskGroup): + """ + Push a TaskGroup into the list of managed TaskGroups. + """ + if cls._context_managed_task_group: + cls._previous_context_managed_task_groups.append(cls._context_managed_task_group) + cls._context_managed_task_group = task_group + + @classmethod + def pop_context_managed_task_group(cls) -> Optional[TaskGroup]: + """ + Pops the last TaskGroup from the list of manged TaskGroups and update the current TaskGroup. 
+ """ + old_task_group = cls._context_managed_task_group + if cls._previous_context_managed_task_groups: + cls._context_managed_task_group = cls._previous_context_managed_task_groups.pop() + else: + cls._context_managed_task_group = None + return old_task_group + + @classmethod + def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]: + """ + Get the current TaskGroup. + """ + from airflow.models.dag import DagContext + + if not cls._context_managed_task_group: + dag = dag or DagContext.get_current_dag() + if dag: + # If there's currently a DAG but no TaskGroup, return the root TaskGroup of the dag. + return dag.task_group + + return cls._context_managed_task_group diff --git a/airflow/www/static/css/graph.css b/airflow/www/static/css/graph.css index 7f1b8189e7308..ee1d0930670e9 100644 --- a/airflow/www/static/css/graph.css +++ b/airflow/www/static/css/graph.css @@ -36,12 +36,26 @@ svg { stroke-width: 1px; } +g.cluster rect { + stroke: white; + stroke-dasharray: 5; + rx: 5; + ry: 5; + opacity: 0.5; +} + g.node rect { stroke: #fff; stroke-width: 3px; cursor: pointer; } +g.node circle { + stroke: black; + stroke-width: 3px; + cursor: pointer; +} + g.node .label { font-size: inherit; font-weight: normal; diff --git a/airflow/www/templates/airflow/graph.html b/airflow/www/templates/airflow/graph.html index f4ec0b62c35b5..6d6d566d5512d 100644 --- a/airflow/www/templates/airflow/graph.html +++ b/airflow/www/templates/airflow/graph.html @@ -101,6 +101,8 @@ + + - {% endblock %} diff --git a/airflow/www/views.py b/airflow/www/views.py index 07c8d2f8c98c5..4c5e2b193013a 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -58,6 +58,7 @@ from airflow.jobs.base_job import BaseJob from airflow.jobs.scheduler_job import SchedulerJob from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, XCom, errors +from airflow.models.baseoperator import BaseOperator from airflow.models.dagcode import DagCode from airflow.models.dagrun import DagRun, DagRunType from airflow.models.taskinstance import TaskInstance @@ -147,6 +148,163 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): } +def task_group_to_dict(task_group): + """ + Create a nested dict representation of this TaskGroup and its children used to construct + the Graph View. + """ + if isinstance(task_group, BaseOperator): + return { + 'id': task_group.task_id, + 'value': { + 'label': task_group.label, + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color};", + 'rx': 5, + 'ry': 5, + } + } + + children = [task_group_to_dict(child) for child in + sorted(task_group.children.values(), key=lambda t: t.label)] + + if task_group.upstream_group_ids or task_group.upstream_task_ids: + children.append({ + 'id': task_group.upstream_join_id, + 'value': { + 'label': '', + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color};", + 'shape': 'circle', + } + }) + + if task_group.downstream_group_ids or task_group.downstream_task_ids: + # This is the join node used to reduce the number of edges between two TaskGroup. 
+ children.append({ + 'id': task_group.downstream_join_id, + 'value': { + 'label': '', + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color};", + 'shape': 'circle', + } + }) + + return { + "id": task_group.group_id, + 'value': { + 'label': task_group.label, + 'labelStyle': f"fill:{task_group.ui_fgcolor};", + 'style': f"fill:{task_group.ui_color}", + 'rx': 5, + 'ry': 5, + 'clusterLabelPos': 'top', + }, + 'tooltip': task_group.tooltip, + 'children': children + } + + +def dag_edges(dag): + """ + Create the list of edges needed to construct the Graph View. + + A special case is made if a TaskGroup is immediately upstream/downstream of another + TaskGroup or task. Two dummy nodes named upstream_join_id and downstream_join_id are + created for the TaskGroup. Instead of drawing an edge onto every task in the TaskGroup, + all edges are directed onto the dummy nodes. This is to cut down the number of edges on + the graph. + + For example: A DAG with TaskGroups group1 and group2: + group1: task1, task2, task3 + group2: task4, task5, task6 + + group2 is downstream of group1: + group1 >> group2 + + Edges to add (This avoids having to create edges between every task in group1 and group2): + task1 >> downstream_join_id + task2 >> downstream_join_id + task3 >> downstream_join_id + downstream_join_id >> upstream_join_id + upstream_join_id >> task4 + upstream_join_id >> task5 + upstream_join_id >> task6 + """ + + # Edges to add between TaskGroups + edges_to_add = set() + # Edges to remove between individual tasks that are replaced by edges_to_add. + edges_to_skip = set() + + task_group_map = dag.task_group.get_task_group_dict() + + def collect_edges(task_group): + """ + Update edges_to_add and edges_to_skip according to TaskGroups. + """ + if isinstance(task_group, BaseOperator): + return + + for target_id in task_group.downstream_group_ids: + # For every TaskGroup immediately downstream, add edges between downstream_join_id + # and upstream_join_id. Skip edges between individual tasks of the TaskGroups. + target_group = task_group_map[target_id] + edges_to_add.add((task_group.downstream_join_id, target_group.upstream_join_id)) + + for child in task_group.get_leaves(): + edges_to_add.add((child.task_id, task_group.downstream_join_id)) + for target in target_group.get_roots(): + edges_to_skip.add((child.task_id, target.task_id)) + edges_to_skip.add((child.task_id, target_group.upstream_join_id)) + + for child in target_group.get_roots(): + edges_to_add.add((target_group.upstream_join_id, child.task_id)) + edges_to_skip.add((task_group.downstream_join_id, child.task_id)) + + # For every individual task immediately downstream, add edges between downstream_join_id and + # the downstream task. Skip edges between individual tasks of the TaskGroup and the + # downstream task. + for target_id in task_group.downstream_task_ids: + edges_to_add.add((task_group.downstream_join_id, target_id)) + + for child in task_group.get_leaves(): + edges_to_add.add((child.task_id, task_group.downstream_join_id)) + edges_to_skip.add((child.task_id, target_id)) + + # For every individual task immediately upstream, add edges between the upstream task + # and upstream_join_id. Skip edges between the upstream task and individual tasks + # of the TaskGroup.
+ for source_id in task_group.upstream_task_ids: + edges_to_add.add((source_id, task_group.upstream_join_id)) + for child in task_group.get_roots(): + edges_to_add.add((task_group.upstream_join_id, child.task_id)) + edges_to_skip.add((source_id, child.task_id)) + + for child in task_group.children.values(): + collect_edges(child) + + collect_edges(dag.task_group) + + # Collect all the edges between individual tasks + edges = set() + + def get_downstream(task): + for child in task.downstream_list: + edge = (task.task_id, child.task_id) + if edge not in edges: + edges.add(edge) + get_downstream(child) + + for root in dag.roots: + get_downstream(root) + + return [{'source_id': source_id, 'target_id': target_id} + for source_id, target_id + in sorted(edges.union(edges_to_add) - edges_to_skip)] + + ###################################################################################### # Error handlers ###################################################################################### @@ -1608,32 +1766,8 @@ def graph(self, session=None): arrange = request.args.get('arrange', dag.orientation) - nodes = [] - edges = [] - for dag_task in dag.tasks: - nodes.append({ - 'id': dag_task.task_id, - 'value': { - 'label': dag_task.task_id, - 'labelStyle': "fill:{0};".format(dag_task.ui_fgcolor), - 'style': "fill:{0};".format(dag_task.ui_color), - 'rx': 5, - 'ry': 5, - } - }) - - def get_downstream(task): - for downstream_task in task.downstream_list: - edge = { - 'source_id': task.task_id, - 'target_id': downstream_task.task_id, - } - if edge not in edges: - edges.append(edge) - get_downstream(downstream_task) - - for dag_task in dag.roots: - get_downstream(dag_task) + nodes = task_group_to_dict(dag.task_group) + edges = dag_edges(dag) dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, session, dag) dt_nr_dr_data['arrange'] = arrange diff --git a/docs/concepts.rst b/docs/concepts.rst index 5d2be9739b33c..8e375825afa1f 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -939,6 +939,48 @@ See ``airflow/example_dags`` for a demonstration. Note that airflow pool is not honored by SubDagOperator. Hence resources could be consumed by SubdagOperators. + +TaskGroup +========= +TaskGroup can be used to organize tasks into hierarchical groups in Graph View. It is +useful for creating repeating patterns and cutting down visual clutter. Unlike SubDagOperator, +TaskGroup is a UI grouping concept. Tasks in TaskGroups live on the same original DAG. They +honor all the pool configurations. + +Dependency relationships can be applied across all tasks in a TaskGroup with the ``>>`` and ``<<`` +operators. For example, the following code puts ``task1`` and ``task2`` in TaskGroup ``group1`` +and then puts both tasks upstream of ``task3``: + +.. code-block:: python + + with TaskGroup("group1") as group1: + task1 = DummyOperator(task_id="task1") + task2 = DummyOperator(task_id="task2") + + task3 = DummyOperator(task_id="task3") + + group1 >> task3 + +.. note:: + By default, child tasks and TaskGroups have their task_id and group_id prefixed with the + group_id of their parent TaskGroup. This ensures uniqueness of group_id and task_id throughout + the DAG. To disable the prefixing, pass ``prefix_group_id=False`` when creating the TaskGroup. + This then gives the user full control over the actual group_id and task_id. They have to ensure + group_id and task_id are unique throughout the DAG. 
The option ``prefix_group_id=False`` is + mainly useful for putting tasks on existing DAGs into TaskGroup without altering their task_id. + +Here is a more complicated example DAG with multiple levels of nested TaskGroups: + +.. exampleinclude:: /../airflow/example_dags/example_task_group.py + :language: python + :start-after: [START howto_task_group] + :end-before: [END howto_task_group] + +This animated gif shows the UI interactions. TaskGroups are expanded or collapsed when clicked: + +.. image:: img/task_group.gif + + SLAs ==== diff --git a/docs/img/task_group.gif b/docs/img/task_group.gif new file mode 100644 index 0000000000000..ac4f6e943ca27 Binary files /dev/null and b/docs/img/task_group.gif differ diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 75c35e8f0263a..1b3a9930754c6 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -67,6 +67,17 @@ } }, "start_date": 1564617600.0, + '_task_group': {'_group_id': None, + 'prefix_group_id': True, + 'children': {'bash_task': ('operator', 'bash_task'), + 'custom_task': ('operator', 'custom_task')}, + 'tooltip': '', + 'ui_color': 'CornflowerBlue', + 'ui_fgcolor': '#000', + 'upstream_group_ids': [], + 'downstream_group_ids': [], + 'upstream_task_ids': [], + 'downstream_task_ids': []}, "is_paused_upon_creation": False, "_dag_id": "simple_dag", "fileloc": None, @@ -83,6 +94,7 @@ "ui_fgcolor": "#000", "template_fields": ['bash_command', 'env'], "bash_command": "echo {{ task.task_id }}", + 'label': 'bash_task', "_task_type": "BashOperator", "_task_module": "airflow.operators.bash", "pool": "default_pool", @@ -107,6 +119,7 @@ "_task_type": "CustomOperator", "_task_module": "tests.test_utils.mock_operators", "pool": "default_pool", + 'label': 'custom_task', }, ], "timezone": "UTC", @@ -329,6 +342,7 @@ def validate_deserialized_dag(self, serialized_dag, dag): # Need to check fields in it, to exclude functions 'default_args', + "_task_group" } for field in fields_to_check: assert getattr(serialized_dag, field) == getattr(dag, field), \ @@ -765,6 +779,7 @@ def test_no_new_fields_added_to_base_operator(self): 'execution_timeout': None, 'executor_config': {}, 'inlets': [], + 'label': '10', 'max_retry_delay': None, 'on_execute_callback': None, 'on_failure_callback': None, @@ -804,3 +819,51 @@ def test_no_new_fields_added_to_base_operator(self): !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! """ ) + + def test_task_group_serialization(self): + """ + Test TaskGroup serialization/deserialization. 
+ """ + from airflow.operators.dummy_operator import DummyOperator + from airflow.utils.task_group import TaskGroup + + execution_date = datetime(2020, 1, 1) + with DAG("test_task_group_serialization", start_date=execution_date) as dag: + task1 = DummyOperator(task_id="task1") + with TaskGroup("group234") as group234: + _ = DummyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + _ = DummyOperator(task_id="task3") + _ = DummyOperator(task_id="task4") + + task5 = DummyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + dag_dict = SerializedDAG.to_dict(dag) + SerializedDAG.validate_schema(dag_dict) + json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + self.validate_deserialized_dag(json_dag, dag) + + serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) + + assert serialized_dag.task_group.children + assert serialized_dag.task_group.children.keys() == dag.task_group.children.keys() + + def check_task_group(node): + try: + children = node.children.values() + except AttributeError: + # Round-trip serialization and check the result + expected_serialized = SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id)) + expected_deserialized = SerializedBaseOperator.deserialize_operator(expected_serialized) + expected_dict = SerializedBaseOperator.serialize_operator(expected_deserialized) + assert node + assert SerializedBaseOperator.serialize_operator(node) == expected_dict + return + + for child in children: + check_task_group(child) + + check_task_group(serialized_dag.task_group) diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py new file mode 100644 index 0000000000000..c4f7a125dec4c --- /dev/null +++ b/tests/utils/test_task_group.py @@ -0,0 +1,561 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pendulum +import pytest + +from airflow.models import DAG +from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.task_group import TaskGroup +from airflow.www.views import dag_edges, task_group_to_dict + +EXPECTED_JSON = { + 'id': None, + 'value': { + 'label': None, + 'labelStyle': 'fill:#000;', + 'style': 'fill:CornflowerBlue', + 'rx': 5, + 'ry': 5, + 'clusterLabelPos': 'top', + }, + 'tooltip': '', + 'children': [ + { + 'id': 'group234', + 'value': { + 'label': 'group234', + 'labelStyle': 'fill:#000;', + 'style': 'fill:CornflowerBlue', + 'rx': 5, + 'ry': 5, + 'clusterLabelPos': 'top', + }, + 'tooltip': '', + 'children': [ + { + 'id': 'group234.group34', + 'value': { + 'label': 'group34', + 'labelStyle': 'fill:#000;', + 'style': 'fill:CornflowerBlue', + 'rx': 5, + 'ry': 5, + 'clusterLabelPos': 'top', + }, + 'tooltip': '', + 'children': [ + { + 'id': 'group234.group34.task3', + 'value': { + 'label': 'task3', + 'labelStyle': 'fill:#000;', + 'style': 'fill:#e8f7e4;', + 'rx': 5, + 'ry': 5, + }, + }, + { + 'id': 'group234.group34.task4', + 'value': { + 'label': 'task4', + 'labelStyle': 'fill:#000;', + 'style': 'fill:#e8f7e4;', + 'rx': 5, + 'ry': 5, + }, + }, + { + 'id': 'group234.group34.downstream_join_id', + 'value': { + 'label': '', + 'labelStyle': 'fill:#000;', + 'style': 'fill:CornflowerBlue;', + 'shape': 'circle', + }, + }, + ], + }, + { + 'id': 'group234.task2', + 'value': { + 'label': 'task2', + 'labelStyle': 'fill:#000;', + 'style': 'fill:#e8f7e4;', + 'rx': 5, + 'ry': 5, + }, + }, + { + 'id': 'group234.upstream_join_id', + 'value': { + 'label': '', + 'labelStyle': 'fill:#000;', + 'style': 'fill:CornflowerBlue;', + 'shape': 'circle', + }, + }, + ], + }, + { + 'id': 'task1', + 'value': { + 'label': 'task1', + 'labelStyle': 'fill:#000;', + 'style': 'fill:#e8f7e4;', + 'rx': 5, + 'ry': 5, + }, + }, + { + 'id': 'task5', + 'value': { + 'label': 'task5', + 'labelStyle': 'fill:#000;', + 'style': 'fill:#e8f7e4;', + 'rx': 5, + 'ry': 5, + }, + }, + ], +} + + +def test_build_task_group_context_manager(): + execution_date = pendulum.parse("20200101") + with DAG("test_build_task_group_context_manager", start_date=execution_date) as dag: + task1 = DummyOperator(task_id="task1") + with TaskGroup("group234") as group234: + _ = DummyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + _ = DummyOperator(task_id="task3") + _ = DummyOperator(task_id="task4") + + task5 = DummyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + assert task1.get_direct_relative_ids(upstream=False) == { + 'group234.group34.task4', + 'group234.group34.task3', + 'group234.task2', + } + assert task5.get_direct_relative_ids(upstream=True) == { + 'group234.group34.task4', + 'group234.group34.task3', + } + + assert dag.task_group.group_id is None + assert dag.task_group.is_root + assert set(dag.task_group.children.keys()) == {"task1", "group234", "task5"} + assert group34.group_id == "group234.group34" + + assert task_group_to_dict(dag.task_group) == EXPECTED_JSON + + +def test_build_task_group(): + """ + This is an alternative syntax to use TaskGroup. It should result in the same TaskGroup + as using context manager. 
+ """ + execution_date = pendulum.parse("20200101") + dag = DAG("test_build_task_group", start_date=execution_date) + task1 = DummyOperator(task_id="task1", dag=dag) + group234 = TaskGroup("group234", dag=dag) + _ = DummyOperator(task_id="task2", dag=dag, task_group=group234) + group34 = TaskGroup("group34", dag=dag, parent_group=group234) + _ = DummyOperator(task_id="task3", dag=dag, task_group=group34) + _ = DummyOperator(task_id="task4", dag=dag, task_group=group34) + task5 = DummyOperator(task_id="task5", dag=dag) + + task1 >> group234 + group34 >> task5 + + assert task_group_to_dict(dag.task_group) == EXPECTED_JSON + + +def extract_node_id(node, include_label=False): + ret = {"id": node["id"]} + if include_label: + ret["label"] = node["value"]["label"] + if "children" in node: + children = [] + for child in node["children"]: + children.append(extract_node_id(child, include_label=include_label)) + + ret["children"] = children + + return ret + + +def test_build_task_group_with_prefix(): + """ + Tests that prefix_group_id turns on/off prefixing of task_id with group_id. + """ + execution_date = pendulum.parse("20200101") + with DAG("test_build_task_group_with_prefix", start_date=execution_date) as dag: + task1 = DummyOperator(task_id="task1") + with TaskGroup("group234", prefix_group_id=False) as group234: + task2 = DummyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + task3 = DummyOperator(task_id="task3") + + with TaskGroup("group4", prefix_group_id=False) as group4: + task4 = DummyOperator(task_id="task4") + + task5 = DummyOperator(task_id="task5") + task1 >> group234 + group34 >> task5 + + assert task2.task_id == "task2" + assert group34.group_id == "group34" + assert task3.task_id == "group34.task3" + assert group4.group_id == "group34.group4" + assert task4.task_id == "task4" + assert task5.task_id == "task5" + assert group234.get_child_by_label("task2") == task2 + assert group234.get_child_by_label("group34") == group34 + assert group4.get_child_by_label("task4") == task4 + + assert extract_node_id(task_group_to_dict(dag.task_group), include_label=True) == { + 'id': None, + 'label': None, + 'children': [ + { + 'id': 'group234', + 'label': 'group234', + 'children': [ + { + 'id': 'group34', + 'label': 'group34', + 'children': [ + { + 'id': 'group34.group4', + 'label': 'group4', + 'children': [{'id': 'task4', 'label': 'task4'}], + }, + {'id': 'group34.task3', 'label': 'task3'}, + {'id': 'group34.downstream_join_id', 'label': ''}, + ], + }, + {'id': 'task2', 'label': 'task2'}, + {'id': 'group234.upstream_join_id', 'label': ''}, + ], + }, + {'id': 'task1', 'label': 'task1'}, + {'id': 'task5', 'label': 'task5'}, + ], + } + + +def test_build_task_group_with_task_decorator(): + """ + Test that TaskGroup can be used with the @task decorator. 
+ """ + from airflow.operators.python import task + + @task + def task_1(): + print("task_1") + + @task + def task_2(): + return "task_2" + + @task + def task_3(): + return "task_3" + + @task + def task_4(task_2_output, task_3_output): + print(task_2_output, task_3_output) + + @task + def task_5(): + print("task_5") + + execution_date = pendulum.parse("20200101") + with DAG("test_build_task_group_with_task_decorator", start_date=execution_date) as dag: + tsk_1 = task_1() + + with TaskGroup("group234") as group234: + tsk_2 = task_2() + tsk_3 = task_3() + tsk_4 = task_4(tsk_2, tsk_3) + + tsk_5 = task_5() + + tsk_1 >> group234 >> tsk_5 + + # pylint: disable=no-member + assert tsk_1.operator in tsk_2.operator.upstream_list + assert tsk_1.operator in tsk_3.operator.upstream_list + assert tsk_5.operator in tsk_4.operator.downstream_list + # pylint: enable=no-member + + assert extract_node_id(task_group_to_dict(dag.task_group)) == { + 'id': None, + 'children': [ + { + 'id': 'group234', + 'children': [ + {'id': 'group234.task_2'}, + {'id': 'group234.task_3'}, + {'id': 'group234.task_4'}, + {'id': 'group234.upstream_join_id'}, + {'id': 'group234.downstream_join_id'}, + ], + }, + {'id': 'task_1'}, + {'id': 'task_5'}, + ], + } + + edges = dag_edges(dag) + assert sorted((e["source_id"], e["target_id"]) for e in edges) == [ + ('group234.downstream_join_id', 'task_5'), + ('group234.task_2', 'group234.task_4'), + ('group234.task_3', 'group234.task_4'), + ('group234.task_4', 'group234.downstream_join_id'), + ('group234.upstream_join_id', 'group234.task_2'), + ('group234.upstream_join_id', 'group234.task_3'), + ('task_1', 'group234.upstream_join_id'), + ] + + +def test_sub_dag_task_group(): + """ + Tests dag.sub_dag() updates task_group correctly. + """ + execution_date = pendulum.parse("20200101") + with DAG("test_test_task_group_sub_dag", start_date=execution_date) as dag: + task1 = DummyOperator(task_id="task1") + with TaskGroup("group234") as group234: + _ = DummyOperator(task_id="task2") + + with TaskGroup("group34") as group34: + _ = DummyOperator(task_id="task3") + _ = DummyOperator(task_id="task4") + + with TaskGroup("group6") as group6: + _ = DummyOperator(task_id="task6") + + task7 = DummyOperator(task_id="task7") + task5 = DummyOperator(task_id="task5") + + task1 >> group234 + group34 >> task5 + group234 >> group6 + group234 >> task7 + + subdag = dag.sub_dag(task_regex="task5", include_upstream=True, include_downstream=False) + + assert extract_node_id(task_group_to_dict(subdag.task_group)) == { + 'id': None, + 'children': [ + { + 'id': 'group234', + 'children': [ + { + 'id': 'group234.group34', + 'children': [ + {'id': 'group234.group34.task3'}, + {'id': 'group234.group34.task4'}, + {'id': 'group234.group34.downstream_join_id'}, + ], + }, + {'id': 'group234.upstream_join_id'}, + ], + }, + {'id': 'task1'}, + {'id': 'task5'}, + ], + } + + edges = dag_edges(subdag) + assert sorted((e["source_id"], e["target_id"]) for e in edges) == [ + ('group234.group34.downstream_join_id', 'task5'), + ('group234.group34.task3', 'group234.group34.downstream_join_id'), + ('group234.group34.task4', 'group234.group34.downstream_join_id'), + ('group234.upstream_join_id', 'group234.group34.task3'), + ('group234.upstream_join_id', 'group234.group34.task4'), + ('task1', 'group234.upstream_join_id'), + ] + + subdag_task_groups = subdag.task_group.get_task_group_dict() + assert subdag_task_groups.keys() == {None, "group234", "group234.group34"} + + included_group_ids = {"group234", "group234.group34"} + included_task_ids 
= {'group234.group34.task3', 'group234.group34.task4', 'task1', 'task5'} + + for task_group in subdag_task_groups.values(): + assert task_group.upstream_group_ids.issubset(included_group_ids) + assert task_group.downstream_group_ids.issubset(included_group_ids) + assert task_group.upstream_task_ids.issubset(included_task_ids) + assert task_group.downstream_task_ids.issubset(included_task_ids) + + for task in subdag.task_group: + assert task.upstream_task_ids.issubset(included_task_ids) + assert task.downstream_task_ids.issubset(included_task_ids) + + +def test_dag_edges(): + execution_date = pendulum.parse("20200101") + with DAG("test_dag_edges", start_date=execution_date) as dag: + task1 = DummyOperator(task_id="task1") + with TaskGroup("group_a") as group_a: + with TaskGroup("group_b") as group_b: + task2 = DummyOperator(task_id="task2") + task3 = DummyOperator(task_id="task3") + task4 = DummyOperator(task_id="task4") + task2 >> [task3, task4] + + task5 = DummyOperator(task_id="task5") + + task5 << group_b + + task1 >> group_a + + with TaskGroup("group_c") as group_c: + task6 = DummyOperator(task_id="task6") + task7 = DummyOperator(task_id="task7") + task8 = DummyOperator(task_id="task8") + [task6, task7] >> task8 + group_a >> group_c + + task5 >> task8 + + task9 = DummyOperator(task_id="task9") + task10 = DummyOperator(task_id="task10") + + group_c >> [task9, task10] + + with TaskGroup("group_d") as group_d: + task11 = DummyOperator(task_id="task11") + task12 = DummyOperator(task_id="task12") + task11 >> task12 + + group_d << group_c + + nodes = task_group_to_dict(dag.task_group) + edges = dag_edges(dag) + + assert extract_node_id(nodes) == { + 'id': None, + 'children': [ + { + 'id': 'group_a', + 'children': [ + { + 'id': 'group_a.group_b', + 'children': [ + {'id': 'group_a.group_b.task2'}, + {'id': 'group_a.group_b.task3'}, + {'id': 'group_a.group_b.task4'}, + {'id': 'group_a.group_b.downstream_join_id'}, + ], + }, + {'id': 'group_a.task5'}, + {'id': 'group_a.upstream_join_id'}, + {'id': 'group_a.downstream_join_id'}, + ], + }, + { + 'id': 'group_c', + 'children': [ + {'id': 'group_c.task6'}, + {'id': 'group_c.task7'}, + {'id': 'group_c.task8'}, + {'id': 'group_c.upstream_join_id'}, + {'id': 'group_c.downstream_join_id'}, + ], + }, + { + 'id': 'group_d', + 'children': [ + {'id': 'group_d.task11'}, + {'id': 'group_d.task12'}, + {'id': 'group_d.upstream_join_id'}, + ], + }, + {'id': 'task1'}, + {'id': 'task10'}, + {'id': 'task9'}, + ], + } + + assert sorted((e["source_id"], e["target_id"]) for e in edges) == [ + ('group_a.downstream_join_id', 'group_c.upstream_join_id'), + ('group_a.group_b.downstream_join_id', 'group_a.task5'), + ('group_a.group_b.task2', 'group_a.group_b.task3'), + ('group_a.group_b.task2', 'group_a.group_b.task4'), + ('group_a.group_b.task3', 'group_a.group_b.downstream_join_id'), + ('group_a.group_b.task4', 'group_a.group_b.downstream_join_id'), + ('group_a.task5', 'group_a.downstream_join_id'), + ('group_a.task5', 'group_c.task8'), + ('group_a.upstream_join_id', 'group_a.group_b.task2'), + ('group_c.downstream_join_id', 'group_d.upstream_join_id'), + ('group_c.downstream_join_id', 'task10'), + ('group_c.downstream_join_id', 'task9'), + ('group_c.task6', 'group_c.task8'), + ('group_c.task7', 'group_c.task8'), + ('group_c.task8', 'group_c.downstream_join_id'), + ('group_c.upstream_join_id', 'group_c.task6'), + ('group_c.upstream_join_id', 'group_c.task7'), + ('group_d.task11', 'group_d.task12'), + ('group_d.upstream_join_id', 'group_d.task11'), + ('task1', 
'group_a.upstream_join_id'), + ] + + +def test_duplicate_group_id(): + from airflow.exceptions import DuplicateTaskIdFound + + execution_date = pendulum.parse("20200101") + + with pytest.raises(DuplicateTaskIdFound, match=r".* 'task1' .*"): + with DAG("test_duplicate_group_id", start_date=execution_date): + _ = DummyOperator(task_id="task1") + with TaskGroup("task1"): + pass + + with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"): + with DAG("test_duplicate_group_id", start_date=execution_date): + _ = DummyOperator(task_id="task1") + with TaskGroup("group1", prefix_group_id=False): + with TaskGroup("group1"): + pass + + with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1' .*"): + with DAG("test_duplicate_group_id", start_date=execution_date): + with TaskGroup("group1", prefix_group_id=False): + _ = DummyOperator(task_id="group1") + + with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.downstream_join_id' .*"): + with DAG("test_duplicate_group_id", start_date=execution_date): + _ = DummyOperator(task_id="task1") + with TaskGroup("group1"): + _ = DummyOperator(task_id="downstream_join_id") + + with pytest.raises(DuplicateTaskIdFound, match=r".* 'group1.upstream_join_id' .*"): + with DAG("test_duplicate_group_id", start_date=execution_date): + _ = DummyOperator(task_id="task1") + with TaskGroup("group1"): + _ = DummyOperator(task_id="upstream_join_id")
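For a quick, self-contained illustration of the API this patch introduces — a minimal sketch, not part of the diff itself; the dag_id and task names below are invented for illustration, and only constructs shown above are used (TaskGroup, prefix_group_id, get_child_by_label, and the task_group_to_dict/dag_edges helpers):

import pendulum

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.task_group import TaskGroup
from airflow.www.views import dag_edges, task_group_to_dict

with DAG("taskgroup_usage_sketch", start_date=pendulum.parse("20200101")) as dag:
    start = DummyOperator(task_id="start")

    # Children are prefixed with their group_id by default, so this task
    # becomes "extract.fetch".
    with TaskGroup("extract", tooltip="Extraction tasks") as extract:
        fetch = DummyOperator(task_id="fetch")

    # With prefix_group_id=False the child keeps its bare task_id ("load").
    with TaskGroup("load_section", prefix_group_id=False) as load_section:
        load = DummyOperator(task_id="load")

    # Arrows applied to a TaskGroup fan out to its root/leaf tasks and also
    # record the group-level relationships that dag_edges() turns into join nodes.
    start >> extract >> load_section

assert fetch.task_id == "extract.fetch"
assert load.task_id == "load"
assert extract.get_child_by_label("fetch") is fetch

# The same helpers the Graph View uses can be invoked directly; the edge list
# contains the synthetic "extract.upstream_join_id"/"extract.downstream_join_id"
# nodes rather than one edge per task pair.
nodes = task_group_to_dict(dag.task_group)
edges = sorted((e["source_id"], e["target_id"]) for e in dag_edges(dag))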