From c8d0b0e1fb037345855598042d426c6ccc9968b3 Mon Sep 17 00:00:00 2001 From: penggongkui <49138427+penggongkui@users.noreply.github.com> Date: Wed, 23 Jun 2021 13:13:04 +0800 Subject: [PATCH] TaskGroup add default_args (#16557) * TaskGroup add default_args * test case && pylint * TaskGroup default_args docs * Update docs/apache-airflow/concepts/dags.rst Co-authored-by: Xinbin Huang Co-authored-by: Xinbin Huang --- airflow/models/baseoperator.py | 4 ++++ airflow/utils/task_group.py | 11 +++++++++++ docs/apache-airflow/concepts/dags.rst | 9 +++++++++ tests/utils/test_task_group.py | 21 +++++++++++++++++++++ 4 files changed, 45 insertions(+) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7af23d3ed26f2..57b561c24099f 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -136,6 +136,7 @@ def warn(self, message, category=None, stacklevel=1, source=None): @functools.wraps(func) def apply_defaults(self, *args: Any, **kwargs: Any) -> Any: from airflow.models.dag import DagContext + from airflow.utils.task_group import TaskGroupContext if len(args) > 0: raise AirflowException("Use keyword arguments when initializing operators") @@ -146,6 +147,9 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any: if dag: dag_args = copy.copy(dag.default_args) or {} dag_params = copy.copy(dag.params) or {} + task_group = TaskGroupContext.get_current_task_group(dag) + if task_group: + dag_args.update(task_group.default_args) params = kwargs.get('params', {}) or {} dag_params.update(params) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 01a7132631801..71c2ae4addfbf 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -19,6 +19,7 @@ A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped together when the DAG is displayed graphically. 
""" +import copy import re from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union @@ -48,6 +49,14 @@ class TaskGroup(TaskMixin): :type parent_group: TaskGroup :param dag: The DAG that this TaskGroup belongs to. :type dag: airflow.models.DAG + :param default_args: A dictionary of default parameters to be used + as constructor keyword parameters when initialising operators, + will override default_args defined in the DAG level. + Note that operators have the same hook, and precede those defined + here, meaning that if your dict contains `'depends_on_past': True` + here and `'depends_on_past': False` in the operator's call + `default_args`, the actual value will be `False`. + :type default_args: dict :param tooltip: The tooltip of the TaskGroup node when displayed in the UI :type tooltip: str :param ui_color: The fill color of the TaskGroup node when displayed in the UI @@ -65,6 +74,7 @@ def __init__( prefix_group_id: bool = True, parent_group: Optional["TaskGroup"] = None, dag: Optional["DAG"] = None, + default_args: Optional[Dict] = None, tooltip: str = "", ui_color: str = "CornflowerBlue", ui_fgcolor: str = "#000", @@ -73,6 +83,7 @@ def __init__( from airflow.models.dag import DagContext self.prefix_group_id = prefix_group_id + self.default_args = copy.deepcopy(default_args or {}) if group_id is None: # This creates a root TaskGroup. 
diff --git a/docs/apache-airflow/concepts/dags.rst b/docs/apache-airflow/concepts/dags.rst index c250e0c25820a..89bde8bfcf589 100644 --- a/docs/apache-airflow/concepts/dags.rst +++ b/docs/apache-airflow/concepts/dags.rst @@ -441,6 +441,15 @@ Dependency relationships can be applied across all tasks in a TaskGroup with the group1 >> task3 +TaskGroup also supports ``default_args``, like DAG; values set on the TaskGroup override the ``default_args`` set at the DAG level:: + + with DAG(dag_id='dag1', default_args={'start_date': datetime(2016, 1, 1), 'owner': 'dag'}): + with TaskGroup('group1', default_args={'owner': 'group'}): + task1 = DummyOperator(task_id='task1') + task2 = DummyOperator(task_id='task2', owner='task2') + print(task1.owner) # "group" + print(task2.owner) # "task2" + If you want to see a more advanced use of TaskGroup, you can look at the ``example_task_group.py`` example DAG that comes with Airflow. .. note:: diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 2574ed6905674..cd84aee49ea57 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -823,6 +823,27 @@ def section_2(value): assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids +def test_default_args(): + """Testing TaskGroup with default_args""" + + execution_date = pendulum.parse("20201109") + with DAG( + dag_id='example_task_group_default_args', + start_date=execution_date, + default_args={ + "owner": "dag", + }, + ): + with TaskGroup("group1", default_args={"owner": "group"}): + task_1 = DummyOperator(task_id='task_1') + task_2 = DummyOperator(task_id='task_2', owner='task') + task_3 = DummyOperator(task_id='task_3', default_args={"owner": "task"}) + + assert task_1.owner == 'group' + assert task_2.owner == 'task' + assert task_3.owner == 'task' + + def test_duplicate_task_group_id(): """Testing automatic suffix assignment for duplicate group_id"""