Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
cce50ec
Add TaskGroup and example
yuqian90 Aug 4, 2020
ad8e2b2
Fix test
yuqian90 Aug 4, 2020
1f0cea0
Remove circular reference between TaskGroup and BaseOperator to make …
yuqian90 Aug 5, 2020
2dcfd33
Add label
yuqian90 Aug 10, 2020
3ad1c62
Add build_task_group
yuqian90 Aug 10, 2020
4ec28d7
test_build_task_group
yuqian90 Aug 10, 2020
6ea4cde
task_group_to_dict
yuqian90 Aug 11, 2020
45afedb
Basic UI changes for TaskGroup
yuqian90 Aug 11, 2020
4e8a9c3
Minor fixes
yuqian90 Aug 11, 2020
8090264
TaskGroup serialization
yuqian90 Aug 12, 2020
14023ab
Fix search box and zoom
yuqian90 Aug 14, 2020
96281c2
Add tooltip to TaskGroup
yuqian90 Aug 16, 2020
866b6c2
Typing and lint
yuqian90 Aug 16, 2020
580384b
Made some attributes private
yuqian90 Aug 17, 2020
1888bd4
Hide tooltip when clicked
yuqian90 Aug 17, 2020
75de22b
Ignore drag when clicked
yuqian90 Aug 17, 2020
3dd1f98
Remove task_group_id
yuqian90 Aug 18, 2020
8dec3e8
Address comments
yuqian90 Aug 18, 2020
55ed52f
Use literal dict to create serialize_group
yuqian90 Aug 18, 2020
ca24002
Reduce size of example dag
yuqian90 Aug 19, 2020
e39e203
Adding TaskGroup to concepts.rst
yuqian90 Aug 19, 2020
c30668a
Change opacity
yuqian90 Aug 19, 2020
0c9ec6a
Typo
yuqian90 Aug 19, 2020
587ee03
Reduce number of edges between child nodes when TaskGroup depends on …
yuqian90 Aug 21, 2020
55d052b
Do not prefix group_id
yuqian90 Aug 29, 2020
2b544f4
Minor fixes
yuqian90 Sep 1, 2020
02dff21
Support both prefix and not prefix group_id
yuqian90 Sep 2, 2020
0317413
Simplify example_task_group
yuqian90 Sep 8, 2020
01ccf02
Address comments (add type hints)
yuqian90 Sep 8, 2020
d6fa4c3
Add test and type hints for XComArg
yuqian90 Sep 9, 2020
cffe049
Put TaskGroup import under TYPE_CHECKING
yuqian90 Sep 11, 2020
9157991
Update test_build_task_group_with_task_decorator now #10827 is merged
yuqian90 Sep 16, 2020
4c9464a
Use TaskMixin(#10930)
yuqian90 Sep 17, 2020
85c9f67
Fix dag.sub_dag not copying task_group tasks bug
yuqian90 Sep 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions airflow/example_dags/example_task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the TaskGroup."""

from airflow.models.dag import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup

# [START howto_task_group]
with DAG(dag_id="example_task_group", start_date=days_ago(2)) as dag:
    start = DummyOperator(task_id="start")

    # [START howto_task_group_section_1]
    # Operators created inside the ``with TaskGroup(...)`` block are added to that group.
    with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
        task_1 = DummyOperator(task_id="task_1")
        task_2 = DummyOperator(task_id="task_2")
        task_3 = DummyOperator(task_id="task_3")

        # Fan out: task_1 runs first, then task_2 and task_3.
        task_1 >> [task_2, task_3]
    # [END howto_task_group_section_1]

    # [START howto_task_group_section_2]
    # The short task ids below repeat the ones used in section_1. BaseOperator passes
    # task_id through the enclosing group's child_id(), which is what keeps them from
    # colliding in the DAG.
    with TaskGroup("section_2", tooltip="Tasks for section_2") as section_2:
        task_1 = DummyOperator(task_id="task_1")

        # [START howto_task_group_inner_section_2]
        # TaskGroups can be nested inside other TaskGroups.
        with TaskGroup("inner_section_2", tooltip="Tasks for inner_section_2") as inner_section_2:
            task_2 = DummyOperator(task_id="task_2")
            task_3 = DummyOperator(task_id="task_3")
            task_4 = DummyOperator(task_id="task_4")

            # Fan in: task_4 waits for both task_2 and task_3.
            [task_2, task_3] >> task_4
        # [END howto_task_group_inner_section_2]

    # [END howto_task_group_section_2]

    end = DummyOperator(task_id="end")

    # A TaskGroup can be used in a dependency chain just like a single task.
    start >> section_1 >> section_2 >> end
# [END howto_task_group]
34 changes: 25 additions & 9 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import (
Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union,
TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Sequence, Set, Tuple,
Type, Union,
)

import attr
Expand Down Expand Up @@ -58,6 +59,9 @@
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup # pylint: disable=cyclic-import

ScheduleInterval = Union[str, timedelta, relativedelta]

TaskStateChangeCallback = Callable[[Context], None]
Expand Down Expand Up @@ -360,9 +364,12 @@ def __init__(
do_xcom_push: bool = True,
inlets: Optional[Any] = None,
outlets: Optional[Any] = None,
task_group: Optional["TaskGroup"] = None,
**kwargs
):
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext

super().__init__()
if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
Expand All @@ -382,6 +389,11 @@ def __init__(
)
validate_key(task_id)
self.task_id = task_id
self.label = task_id
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
if task_group:
self.task_id = task_group.child_id(task_id)
task_group.add(self)
self.owner = owner
self.email = email
self.email_on_retry = email_on_retry
Expand Down Expand Up @@ -609,7 +621,7 @@ def dag(self, dag: Any):
elif self.task_id in dag.task_dict and dag.task_dict[self.task_id] is not self:
dag.add_task(self)

self._dag = dag # pylint: disable=attribute-defined-outside-init
self._dag = dag

def has_dag(self):
"""
Expand Down Expand Up @@ -1120,21 +1132,25 @@ def roots(self) -> List["BaseOperator"]:
"""Required by TaskMixin"""
return [self]

@property
def leaves(self) -> List["BaseOperator"]:
    """
    Return this operator wrapped in a single-element list.

    Required by TaskMixin: a lone operator is its own leaf.
    """
    leaf_tasks = [self]
    return leaf_tasks

def _set_relatives(
self,
task_or_task_list: Union[TaskMixin, Sequence[TaskMixin]],
upstream: bool = False,
) -> None:
"""Sets relatives for the task or task list."""

if isinstance(task_or_task_list, Sequence):
task_like_object_list = task_or_task_list
else:
task_like_object_list = [task_or_task_list]
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]

task_list: List["BaseOperator"] = []
for task_object in task_like_object_list:
task_list.extend(task_object.roots)
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @turbaszek, I added an update_relative method to TaskMixin that defaults to a no-op. It's called here. TaskGroup overrides this method in order to keep track of its direct upstream/downstream TaskGroup or BaseOperator for UI optimization. BaseOperator and XComArg don't need to override it.

The following is the roots vs. leaves distinction that I mentioned on the TaskMixin PR. I use it here so that we don't need to hardcode another `if isinstance(task_object, TaskGroup)` check here. Please see if this looks okay to you.

relatives = task_object.leaves if upstream else task_object.roots
task_list.extend(relatives)

for task in task_list:
if not isinstance(task, BaseOperator):
Expand Down
51 changes: 47 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import warnings
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast
from typing import (
TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast,
)

import jinja2
import pendulum
Expand Down Expand Up @@ -59,6 +61,9 @@
from airflow.utils.state import State
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from airflow.utils.task_group import TaskGroup

log = logging.getLogger(__name__)

ScheduleInterval = Union[str, timedelta, relativedelta]
Expand Down Expand Up @@ -238,6 +243,8 @@ def __init__(
jinja_environment_kwargs: Optional[Dict] = None,
tags: Optional[List[str]] = None
):
from airflow.utils.task_group import TaskGroup

self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
self.default_args = copy.deepcopy(default_args or {})
Expand Down Expand Up @@ -329,6 +336,7 @@ def __init__(

self.jinja_environment_kwargs = jinja_environment_kwargs
self.tags = tags
self._task_group = TaskGroup.create_root(self)

def __repr__(self):
return "<DAG: {self.dag_id}>".format(self=self)
Expand Down Expand Up @@ -570,6 +578,10 @@ def tasks(self, val):
def task_ids(self) -> List[str]:
return list(self.task_dict.keys())

@property
def task_group(self) -> "TaskGroup":
    """Return the TaskGroup held in ``self._task_group`` (read-only accessor)."""
    stored_group = self._task_group
    return stored_group

@property
def filepath(self) -> str:
"""
Expand Down Expand Up @@ -1240,7 +1252,6 @@ def sub_dag(self, task_regex, include_downstream=False,
based on a regex that should match one or many tasks, and includes
upstream and downstream neighbours based on the flag passed.
"""

# deep-copying self.task_dict takes a long time, and we don't want all
# the tasks anyway, so we copy the tasks manually later
task_dict = self.task_dict
Expand All @@ -1261,9 +1272,38 @@ def sub_dag(self, task_regex, include_downstream=False,
# Make sure to not recursively deepcopy the dag while copying the task
dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag})
for t in regex_match + also_include}

# Remove tasks not included in the subdag from task_group
def remove_excluded(group: "TaskGroup") -> None:
    """
    Prune ``group`` in place so it only refers to tasks kept in ``dag.task_dict``.

    Walks ``group.children`` recursively. Operators whose ``task_id`` is missing from
    ``dag.task_dict`` are dropped; the remaining ones are re-pointed at the objects
    stored in ``dag.task_dict``. Nested groups that end up with no children are
    removed from their parent as well.

    :param group: group whose ``children`` mapping is modified
    """
    # Iterate over a snapshot because group.children is mutated inside the loop.
    for child in list(group.children.values()):
        if isinstance(child, BaseOperator):
            if child.task_id not in dag.task_dict:
                group.children.pop(child.task_id)
            else:
                # The tasks in the subdag are a copy of tasks in the original dag
                # so update the reference in the TaskGroups too.
                group.children[child.task_id] = dag.task_dict[child.task_id]
        else:
            # Anything that is not a BaseOperator is handled as a nested group.
            remove_excluded(child)

            # Remove this TaskGroup if it doesn't contain any tasks in this subdag
            if not child.children:
                group.children.pop(child.group_id)

remove_excluded(dag.task_group)

# Removing upstream/downstream references to tasks and TaskGroups that did not make
# the cut.
subdag_task_groups = dag.task_group.get_task_group_dict()
for group in subdag_task_groups.values():
group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys())
group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys())
group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys())
group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys())

for t in dag.tasks:
# Removing upstream/downstream references to tasks that did not
# made the cut
# make the cut
t._upstream_task_ids = t.upstream_task_ids.intersection(dag.task_dict.keys())
t._downstream_task_ids = t.downstream_task_ids.intersection(
dag.task_dict.keys())
Expand Down Expand Up @@ -1357,12 +1397,15 @@ def add_task(self, task):
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)

if task.task_id in self.task_dict and self.task_dict[task.task_id] is not task:
if ((task.task_id in self.task_dict and self.task_dict[task.task_id] is not task)
or task.task_id in self._task_group.used_group_ids):
raise DuplicateTaskIdFound(
"Task id '{}' has already been added to the DAG".format(task.task_id))
else:
self.task_dict[task.task_id] = task
task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
self._task_group.used_group_ids.add(task.task_id)

self.task_count = len(self.task_dict)

Expand Down
11 changes: 11 additions & 0 deletions airflow/models/taskmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def roots(self):
"""Should return list of root operator List[BaseOperator]"""
raise NotImplementedError()

@property
def leaves(self):
    """
    Leaf operators of this object, expected as a List[BaseOperator].

    The base implementation always raises; implementers provide the real value.
    """
    raise NotImplementedError

@abstractmethod
def set_upstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Expand All @@ -47,6 +52,12 @@ def set_downstream(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
raise NotImplementedError()

def update_relative(self, other: "TaskMixin", upstream: bool = True) -> None:
    """
    Update relationship information about another TaskMixin. Default is no-op.
    Override if necessary.

    :param other: the TaskMixin this object is being related to
    :param upstream: direction flag for the relationship; presumably True means
        ``other`` is upstream of this object -- confirm against overriding classes
    """

def __lshift__(self, other: Union["TaskMixin", Sequence["TaskMixin"]]):
"""
Implements Task << Task
Expand Down
5 changes: 5 additions & 0 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ def roots(self) -> List[BaseOperator]:
"""Required by TaskMixin"""
return [self._operator]

@property
def leaves(self) -> List[BaseOperator]:
    """
    Return the wrapped operator as a single-element list.

    Required by TaskMixin.
    """
    wrapped_operator = self._operator
    return [wrapped_operator]

@property
def key(self) -> str:
"""Returns keys of this XComArg"""
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,4 @@ class DagAttributeTypes(str, Enum):
SET = 'set'
TUPLE = 'tuple'
POD = 'k8s.V1Pod'
TASK_GROUP = 'taskgroup'
48 changes: 47 additions & 1 deletion airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@
"_default_view": { "type" : "string"},
"_access_control": {"$ref": "#/definitions/dict" },
"is_paused_upon_creation": { "type": "boolean" },
"tags": { "type": "array" }
"tags": { "type": "array" },
"_task_group": {"anyOf": [
{ "type": "null" },
{ "$ref": "#/definitions/task_group" }
]}
},
"required": [
"_dag_id",
Expand Down Expand Up @@ -125,6 +129,7 @@
"_task_module": { "type": "string" },
"_operator_extra_links": { "$ref": "#/definitions/extra_links" },
"task_id": { "type": "string" },
"label": { "type": "string" },
"owner": { "type": "string" },
"start_date": { "$ref": "#/definitions/datetime" },
"end_date": { "$ref": "#/definitions/datetime" },
Expand Down Expand Up @@ -156,6 +161,47 @@
}
},
"additionalProperties": true
},
"task_group": {
"$comment": "A TaskGroup containing tasks",
"type": "object",
"required": [
"_group_id",
"prefix_group_id",
"children",
"tooltip",
"ui_color",
"ui_fgcolor",
"upstream_group_ids",
"downstream_group_ids",
"upstream_task_ids",
"downstream_task_ids"
],
"properties": {
"_group_id": {"anyOf": [{"type": "null"}, { "type": "string" }]},
"prefix_group_id": { "type": "boolean" },
"children": { "$ref": "#/definitions/dict" },
"tooltip": { "type": "string" },
"ui_color": { "type": "string" },
"ui_fgcolor": { "type": "string" },
"upstream_group_ids": {
"type": "array",
"items": { "type": "string" }
},
"downstream_group_ids": {
"type": "array",
"items": { "type": "string" }
},
"upstream_task_ids": {
"type": "array",
"items": { "type": "string" }
},
"downstream_task_ids": {
"type": "array",
"items": { "type": "string" }
}
},
"additionalProperties": false
}
},

Expand Down
Loading