Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TaskGroup add default_args #16557

Merged
merged 5 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def warn(self, message, category=None, stacklevel=1, source=None):
@functools.wraps(func)
def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext

if len(args) > 0:
raise AirflowException("Use keyword arguments when initializing operators")
Expand All @@ -146,6 +147,9 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
if dag:
dag_args = copy.copy(dag.default_args) or {}
dag_params = copy.copy(dag.params) or {}
task_group = TaskGroupContext.get_current_task_group(dag)
if task_group:
dag_args.update(task_group.default_args)

params = kwargs.get('params', {}) or {}
dag_params.update(params)
Expand Down
11 changes: 11 additions & 0 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
import copy
import re
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union

Expand Down Expand Up @@ -48,6 +49,14 @@ class TaskGroup(TaskMixin):
:type parent_group: TaskGroup
:param dag: The DAG that this TaskGroup belongs to.
:type dag: airflow.models.DAG
:param default_args: A dictionary of default parameters to be used
as constructor keyword parameters when initialising operators,
will override default_args defined in the DAG level.
Note that operators have the same hook, and precede those defined
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add a note here that this will overwrite the default_args defined in the DAG level?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for suggestion, added it

here, meaning that if your dict contains `'depends_on_past': True`
here and `'depends_on_past': False` in the operator's call
`default_args`, the actual value will be `False`.
:type default_args: dict
:param tooltip: The tooltip of the TaskGroup node when displayed in the UI
:type tooltip: str
:param ui_color: The fill color of the TaskGroup node when displayed in the UI
Expand All @@ -65,6 +74,7 @@ def __init__(
prefix_group_id: bool = True,
parent_group: Optional["TaskGroup"] = None,
dag: Optional["DAG"] = None,
default_args: Optional[Dict] = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
Expand All @@ -73,6 +83,7 @@ def __init__(
from airflow.models.dag import DagContext

self.prefix_group_id = prefix_group_id
self.default_args = copy.deepcopy(default_args or {})

if group_id is None:
# This creates a root TaskGroup.
Expand Down
9 changes: 9 additions & 0 deletions docs/apache-airflow/concepts/dags.rst
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,15 @@ Dependency relationships can be applied across all tasks in a TaskGroup with the

group1 >> task3

TaskGroup also supports ``default_args`` like DAG, it will overwrite the ``default_args`` in DAG level::

with DAG(dag_id='dag1', default_args={'start_date': datetime(2016, 1, 1), 'owner': 'dag'}):
with TaskGroup('group1', default_args={'owner': 'group'}):
task1 = DummyOperator(task_id='task1')
task2 = DummyOperator(task_id='task2', owner='task2')
print(task1.owner) # "group"
print(task2.owner) # "task2"

If you want to see a more advanced use of TaskGroup, you can look at the ``example_task_group.py`` example DAG that comes with Airflow.

.. note::
Expand Down
21 changes: 21 additions & 0 deletions tests/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,27 @@ def section_2(value):
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids


def test_default_args():
"""Testing TaskGroup with default_args"""

execution_date = pendulum.parse("20201109")
with DAG(
dag_id='example_task_group_default_args',
start_date=execution_date,
default_args={
"owner": "dag",
},
):
with TaskGroup("group1", default_args={"owner": "group"}):
task_1 = DummyOperator(task_id='task_1')
task_2 = DummyOperator(task_id='task_2', owner='task')
task_3 = DummyOperator(task_id='task_3', default_args={"owner": "task"})

assert task_1.owner == 'group'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it would be better to add additional test verifying that setting a parameter on the task level overwrite the default args of the TaskGroup.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense, added some test case

assert task_2.owner == 'task'
assert task_3.owner == 'task'


def test_duplicate_task_group_id():
"""Testing automatic suffix assignment for duplicate group_id"""

Expand Down