Skip to content

Commit

Permalink
TaskGroup add default_args (#16557)
Browse files Browse the repository at this point in the history
* TaskGroup add default_args

* test case && pylint

* TaskGroup default_args docs

* Update docs/apache-airflow/concepts/dags.rst

Co-authored-by: Xinbin Huang <bin.huangxb@gmail.com>

Co-authored-by: Xinbin Huang <bin.huangxb@gmail.com>
  • Loading branch information
penggongkui and xinbinhuang authored Jun 23, 2021
1 parent c5e9141 commit c8d0b0e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
4 changes: 4 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def warn(self, message, category=None, stacklevel=1, source=None):
@functools.wraps(func)
def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext

if len(args) > 0:
raise AirflowException("Use keyword arguments when initializing operators")
Expand All @@ -146,6 +147,9 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
if dag:
dag_args = copy.copy(dag.default_args) or {}
dag_params = copy.copy(dag.params) or {}
task_group = TaskGroupContext.get_current_task_group(dag)
if task_group:
dag_args.update(task_group.default_args)

params = kwargs.get('params', {}) or {}
dag_params.update(params)
Expand Down
11 changes: 11 additions & 0 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
import copy
import re
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union

Expand Down Expand Up @@ -48,6 +49,14 @@ class TaskGroup(TaskMixin):
:type parent_group: TaskGroup
:param dag: The DAG that this TaskGroup belongs to.
:type dag: airflow.models.DAG
:param default_args: A dictionary of default parameters to be used
as constructor keyword parameters when initialising operators.
Values given here override the ``default_args`` defined at the DAG level.
Operators accept ``default_args`` as well, and theirs take precedence
over those defined here: if this dict contains `'depends_on_past': True`
and the operator's own `default_args` contains `'depends_on_past': False`,
the actual value will be `False`.
:type default_args: dict
:param tooltip: The tooltip of the TaskGroup node when displayed in the UI
:type tooltip: str
:param ui_color: The fill color of the TaskGroup node when displayed in the UI
Expand All @@ -65,6 +74,7 @@ def __init__(
prefix_group_id: bool = True,
parent_group: Optional["TaskGroup"] = None,
dag: Optional["DAG"] = None,
default_args: Optional[Dict] = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
Expand All @@ -73,6 +83,7 @@ def __init__(
from airflow.models.dag import DagContext

self.prefix_group_id = prefix_group_id
self.default_args = copy.deepcopy(default_args or {})

if group_id is None:
# This creates a root TaskGroup.
Expand Down
9 changes: 9 additions & 0 deletions docs/apache-airflow/concepts/dags.rst
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,15 @@ Dependency relationships can be applied across all tasks in a TaskGroup with the

group1 >> task3

TaskGroup also supports ``default_args``, like DAG; values set on the TaskGroup override the DAG-level ``default_args``::

with DAG(dag_id='dag1', default_args={'start_date': datetime(2016, 1, 1), 'owner': 'dag'}):
with TaskGroup('group1', default_args={'owner': 'group'}):
task1 = DummyOperator(task_id='task1')
task2 = DummyOperator(task_id='task2', owner='task2')
print(task1.owner) # "group"
print(task2.owner) # "task2"

If you want to see a more advanced use of TaskGroup, you can look at the ``example_task_group.py`` example DAG that comes with Airflow.

.. note::
Expand Down
21 changes: 21 additions & 0 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,27 @@ def section_2(value):
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


def test_default_args():
    """Check how TaskGroup ``default_args`` rank against DAG and operator settings."""

    start = pendulum.parse("20201109")
    dag_kwargs = {
        "dag_id": 'example_task_group_default_args',
        "start_date": start,
        "default_args": {"owner": "dag"},
    }

    with DAG(**dag_kwargs):
        with TaskGroup("group1", default_args={"owner": "group"}):
            # No owner supplied on the operator itself.
            plain_task = DummyOperator(task_id='task_1')
            # Owner passed directly as a keyword argument.
            kwarg_task = DummyOperator(task_id='task_2', owner='task')
            # Owner passed through the operator's own default_args.
            own_defaults_task = DummyOperator(task_id='task_3', default_args={"owner": "task"})

    # The group's value beats the DAG's; anything set on the operator beats the group's.
    expected_owners = (
        (plain_task, 'group'),
        (kwarg_task, 'task'),
        (own_defaults_task, 'task'),
    )
    for task, owner in expected_owners:
        assert task.owner == owner


def test_duplicate_task_group_id():
"""Testing automatic suffix assignment for duplicate group_id"""

Expand Down

0 comments on commit c8d0b0e

Please sign in to comment.