From 4cf2a340772a344832a6e69a4fca949f8ef79a9a Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 17 Jul 2025 12:26:41 +0100 Subject: [PATCH 1/3] Deprecate and move `airflow.utils.task_group` to SDK Some part of the module is already moved to SDK. This completes the move --- .../api_fastapi/core_api/routes/ui/grid.py | 2 +- .../core_api/routes/ui/structure.py | 2 +- .../api_fastapi/core_api/services/ui/grid.py | 3 +- .../example_dags/example_setup_teardown.py | 2 +- .../example_dags/example_task_group.py | 2 +- .../src/airflow/models/taskinstance.py | 3 +- .../airflow/ti_deps/deps/trigger_rule_dep.py | 2 +- airflow-core/src/airflow/utils/__init__.py | 6 + .../src/airflow/utils/dot_renderer.py | 2 +- airflow-core/src/airflow/utils/task_group.py | 128 ------------------ .../core_api/routes/ui/test_grid.py | 2 +- airflow-core/tests/unit/models/test_dagrun.py | 2 +- .../tests/unit/models/test_mappedoperator.py | 2 +- .../tests/unit/models/test_taskinstance.py | 2 +- .../serialization/test_dag_serialization.py | 2 +- .../serialization/test_serialized_objects.py | 2 +- .../tests/unit/utils/test_dag_cycle.py | 2 +- .../tests/unit/utils/test_dot_renderer.py | 2 +- .../tests/unit/utils/test_edgemodifier.py | 2 +- .../tests/unit/utils/test_task_group.py | 2 +- .../src/airflow/sdk/definitions/taskgroup.py | 85 +++++++++--- 21 files changed, 90 insertions(+), 167 deletions(-) delete mode 100644 airflow-core/src/airflow/utils/task_group.py diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index 5c7494d3b704c..f4c6d7d558cbd 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -52,7 +52,7 @@ from airflow.models.dagrun import DagRun from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance -from airflow.utils.task_group import ( +from airflow.sdk.definitions.taskgroup import ( get_task_group_children_getter, task_group_to_dict_grid, ) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py index fff7325f41aba..c308ae214322c 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py @@ -32,8 +32,8 @@ ) from airflow.models.dag_version import DagVersion from airflow.models.serialized_dag import SerializedDagModel +from airflow.sdk.definitions.taskgroup import task_group_to_dict from airflow.utils.dag_edges import dag_edges -from airflow.utils.task_group import task_group_to_dict structure_router = AirflowRouter(tags=["Structure"], prefix="/structure") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py index 3bf22517da246..b8e569f877e70 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py @@ -25,9 +25,8 @@ from airflow.api_fastapi.common.parameters import state_priority from airflow.models.taskmap import TaskMap from airflow.sdk.definitions.mappedoperator import MappedOperator -from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup +from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup, get_task_group_children_getter from airflow.serialization.serialized_objects import SerializedBaseOperator -from airflow.utils.task_group import get_task_group_children_getter log = structlog.get_logger(logger_name=__name__) diff --git a/airflow-core/src/airflow/example_dags/example_setup_teardown.py b/airflow-core/src/airflow/example_dags/example_setup_teardown.py index a36e79a55e5f5..052377736ea59 100644 --- a/airflow-core/src/airflow/example_dags/example_setup_teardown.py +++ b/airflow-core/src/airflow/example_dags/example_setup_teardown.py @@ -23,7 +23,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.sdk import DAG -from airflow.utils.task_group import TaskGroup +from airflow.sdk.definitions.taskgroup import TaskGroup with DAG( dag_id="example_setup_teardown", diff --git a/airflow-core/src/airflow/example_dags/example_task_group.py b/airflow-core/src/airflow/example_dags/example_task_group.py index e83ac2e9989cf..c882c269c476b 100644 --- a/airflow-core/src/airflow/example_dags/example_task_group.py +++ b/airflow-core/src/airflow/example_dags/example_task_group.py @@ -24,7 +24,7 @@ from airflow.providers.standard.operators.bash import BashOperator from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import DAG -from airflow.utils.task_group import TaskGroup +from airflow.sdk.definitions.taskgroup import TaskGroup # [START howto_task_group] with DAG( diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 40ad420d43047..ae55f6a65ad11 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -124,11 +124,10 @@ from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey, AssetUriRef from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.sdk.definitions.taskgroup import MappedTaskGroup + from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.sdk.types import RuntimeTaskInstanceProtocol from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import Context - from airflow.utils.task_group import TaskGroup Operator: TypeAlias = BaseOperator | MappedOperator diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 9a4f30c8dff09..9a0096f569d5c 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -26,9 +26,9 @@ from sqlalchemy import and_, func, or_, select from airflow.models.taskinstance import PAST_DEPENDS_MET +from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule as TR if TYPE_CHECKING: diff --git a/airflow-core/src/airflow/utils/__init__.py b/airflow-core/src/airflow/utils/__init__.py index 39a213c77c64f..7bd56814052d6 100644 --- a/airflow-core/src/airflow/utils/__init__.py +++ b/airflow-core/src/airflow/utils/__init__.py @@ -46,5 +46,11 @@ def __getattr__(name: str): "xcom": { "XCOM_RETURN_KEY": "airflow.models.xcom.XCOM_RETURN_KEY", }, + "task_group": { + "TaskGroup": "airflow.sdk.definitions.taskgroup.TaskGroup", + "MappedTaskGroup": "airflow.sdk.definitions.taskgroup.MappedOperator", + "get_task_group_children_getter": "airflow.sdk.definitions.taskgroup.get_task_group_children_getter", + "task_group_to_dict": "airflow.sdk.definitions.taskgroup.task_group_to_dict", + }, } add_deprecated_classes(__deprecated_classes, __name__) diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py index 5b624d3053463..50911572d3971 100644 --- a/airflow-core/src/airflow/utils/dot_renderer.py +++ b/airflow-core/src/airflow/utils/dot_renderer.py @@ -26,10 +26,10 @@ from airflow.exceptions import AirflowException from airflow.sdk import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.dag_edges import dag_edges from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup if TYPE_CHECKING: import graphviz diff --git a/airflow-core/src/airflow/utils/task_group.py b/airflow-core/src/airflow/utils/task_group.py deleted file mode 100644 index 675f9f19faf97..0000000000000 --- a/airflow-core/src/airflow/utils/task_group.py +++ /dev/null @@ -1,128 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""A collection of closely related tasks on the same DAG that should be grouped together visually.""" - -from __future__ import annotations - -from collections.abc import Callable -from functools import cache -from operator import methodcaller -from typing import TYPE_CHECKING - -import airflow.sdk.definitions.taskgroup -from airflow.configuration import conf - -if TYPE_CHECKING: - from typing import TypeAlias - -TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup -MappedTaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.MappedTaskGroup - - -@cache -def get_task_group_children_getter() -> Callable: - """Get the Task Group Children Getter for the DAG.""" - sort_order = conf.get("api", "grid_view_sorting_order") - if sort_order == "topological": - return methodcaller("topological_sort") - return methodcaller("hierarchical_alphabetical_sort") - - -def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): - """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.serialization.serialized_objects import SerializedBaseOperator - - if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)): - node_operator = { - "id": task.task_id, - "label": task.label, - "operator": task.operator_name, - "type": "task", - } - if task.is_setup: - node_operator["setup_teardown_type"] = "setup" - elif task.is_teardown: - node_operator["setup_teardown_type"] = "teardown" - if isinstance(task, MappedOperator) or parent_group_is_mapped: - node_operator["is_mapped"] = True - return node_operator - - task_group = task_item_or_group - is_mapped = isinstance(task_group, MappedTaskGroup) - children = [ - task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped) - for child in get_task_group_children_getter()(task_group) - ] - - if task_group.upstream_group_ids or task_group.upstream_task_ids: - # This is the join node used to reduce the number of edges between two TaskGroup. - children.append({"id": task_group.upstream_join_id, "label": "", "type": "join"}) - - if task_group.downstream_group_ids or task_group.downstream_task_ids: - # This is the join node used to reduce the number of edges between two TaskGroup. - children.append({"id": task_group.downstream_join_id, "label": "", "type": "join"}) - - return { - "id": task_group.group_id, - "label": task_group.label, - "tooltip": task_group.tooltip, - "is_mapped": is_mapped, - "children": children, - "type": "task", - } - - -def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False): - """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" - from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator - from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.serialization.serialized_objects import SerializedBaseOperator - - if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)): - is_mapped = None - if isinstance(task, MappedOperator) or parent_group_is_mapped: - is_mapped = True - setup_teardown_type = None - if task.is_setup is True: - setup_teardown_type = "setup" - elif task.is_teardown is True: - setup_teardown_type = "teardown" - return { - "id": task.task_id, - "label": task.label, - "is_mapped": is_mapped, - "children": None, - "setup_teardown_type": setup_teardown_type, - } - - task_group = task_item_or_group - task_group_sort = get_task_group_children_getter() - is_mapped_group = isinstance(task_group, MappedTaskGroup) - children = [ - task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group) - for x in task_group_sort(task_group) - ] - - return { - "id": task_group.group_id, - "label": task_group.label, - "is_mapped": is_mapped_group or None, - "children": children or None, - } diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py index 314c6bf843752..9db616f4dbde2 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py @@ -29,10 +29,10 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import task_group +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, clear_db_serialized_dags diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 32445dd295884..50440d4f44ef2 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -2774,7 +2774,7 @@ def test_teardown_and_fail_fast(dag_maker): in this case, the second teardown skips because its setup skips. """ from airflow.sdk import task as task_decorator - from airflow.utils.task_group import TaskGroup + from airflow.sdk.definitions.taskgroup import TaskGroup with dag_maker(fail_fast=True) as dag: for num in (1, 2): diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py index 2922ee1c23e50..ea9f35b63f765 100644 --- a/airflow-core/tests/unit/models/test_mappedoperator.py +++ b/airflow-core/tests/unit/models/test_mappedoperator.py @@ -34,8 +34,8 @@ from airflow.models.taskmap import TaskMap from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import setup, task, task_group, teardown +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule from tests_common.test_utils.mapping import expand_mapped_task diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index d1b3a750db305..7e041917056ed 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -67,6 +67,7 @@ from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse from airflow.sdk.definitions.asset import Asset, AssetAlias from airflow.sdk.definitions.param import process_params +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.execution_time.comms import ( AssetEventsResult, ) @@ -83,7 +84,6 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.span_status import SpanStatus from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunTriggeredByType, DagRunType from tests_common.test_utils import db diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 3e11e91499472..8f2e54d6c275a 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -68,6 +68,7 @@ from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY from airflow.sdk.definitions.asset import Asset, AssetUniqueKey from airflow.sdk.definitions.param import Param, ParamsDict +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.security import permissions from airflow.serialization.enums import Encoding from airflow.serialization.json_schema import load_dag_schema_dict @@ -84,7 +85,6 @@ from airflow.utils import timezone from airflow.utils.module_loading import qualname from airflow.utils.operator_resources import Resources -from airflow.utils.task_group import TaskGroup from tests_common.test_utils.config import conf_vars from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 402c169cc08b2..dc944035c75a9 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -59,6 +59,7 @@ from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineAlertFields, DeadlineReference from airflow.sdk.definitions.decorators import task from airflow.sdk.definitions.param import Param +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG @@ -68,7 +69,6 @@ from airflow.utils.db import LazySelectSequence from airflow.utils.operator_resources import Resources from airflow.utils.state import DagRunState, State -from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType from unit.models import DEFAULT_DATE diff --git a/airflow-core/tests/unit/utils/test_dag_cycle.py b/airflow-core/tests/unit/utils/test_dag_cycle.py index c436af01c7dc3..e17ff7c5f3cf2 100644 --- a/airflow-core/tests/unit/utils/test_dag_cycle.py +++ b/airflow-core/tests/unit/utils/test_dag_cycle.py @@ -22,8 +22,8 @@ from airflow.models.dag import DAG from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk import Label +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils.dag_cycle_tester import check_cycle -from airflow.utils.task_group import TaskGroup from unit.models import DEFAULT_DATE diff --git a/airflow-core/tests/unit/utils/test_dot_renderer.py b/airflow-core/tests/unit/utils/test_dot_renderer.py index d3ba7acaab796..240876ec7f43e 100644 --- a/airflow-core/tests/unit/utils/test_dot_renderer.py +++ b/airflow-core/tests/unit/utils/test_dot_renderer.py @@ -24,10 +24,10 @@ from airflow.models.dag import DAG from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.serialization.dag_dependency import DagDependency from airflow.utils import dot_renderer, timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from tests_common.test_utils.compat import BashOperator from tests_common.test_utils.db import clear_db_dags diff --git a/airflow-core/tests/unit/utils/test_edgemodifier.py b/airflow-core/tests/unit/utils/test_edgemodifier.py index 0885230e113ec..98ea514af3c1d 100644 --- a/airflow-core/tests/unit/utils/test_edgemodifier.py +++ b/airflow-core/tests/unit/utils/test_edgemodifier.py @@ -25,8 +25,8 @@ from airflow.providers.standard.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import Label +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils.dag_edges import dag_edges -from airflow.utils.task_group import TaskGroup DEFAULT_ARGS = { "owner": "test", diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index 8d6e6c0b33430..7af3b3326f40d 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -34,8 +34,8 @@ task_group as task_group_decorator, teardown, ) +from airflow.sdk.definitions.taskgroup import TaskGroup, task_group_to_dict from airflow.utils.dag_edges import dag_edges -from airflow.utils.task_group import TaskGroup, task_group_to_dict from tests_common.test_utils.compat import BashOperator, PythonOperator from unit.models import DEFAULT_DATE diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index c5afe65e0d839..ed79ae4202ead 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -24,12 +24,15 @@ import operator import re import weakref -from collections.abc import Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterator, Sequence +from functools import cache +from operator import methodcaller from typing import TYPE_CHECKING, Any import attrs import methodtools +from airflow.configuration import conf from airflow.exceptions import ( AirflowDagCycleException, AirflowException, @@ -669,36 +672,41 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]: yield op -def task_group_to_dict(task_item_or_group): +@cache +def get_task_group_children_getter() -> Callable: + """Get the Task Group Children Getter for the DAG.""" + sort_order = conf.get("api", "grid_view_sorting_order") + if sort_order == "topological": + return methodcaller("topological_sort") + return methodcaller("hierarchical_alphabetical_sort") + + +def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)): - setup_teardown_type = {} - is_mapped = {} - node_type = {"type": "task"} - if task.is_setup is True: - setup_teardown_type["setup_teardown_type"] = "setup" - elif task.is_teardown is True: - setup_teardown_type["setup_teardown_type"] = "teardown" - if isinstance(task, MappedOperator): - is_mapped["is_mapped"] = True - if getattr(task, "_is_sensor", False): - node_type["type"] = "sensor" - return { + node_operator = { "id": task.task_id, "label": task.label, - **is_mapped, - **setup_teardown_type, - **node_type, + "operator": task.operator_name, + "type": "task", } + if task.is_setup: + node_operator["setup_teardown_type"] = "setup" + elif task.is_teardown: + node_operator["setup_teardown_type"] = "teardown" + if isinstance(task, MappedOperator) or parent_group_is_mapped: + node_operator["is_mapped"] = True + return node_operator task_group = task_item_or_group is_mapped = isinstance(task_group, MappedTaskGroup) children = [ - task_group_to_dict(child) for child in sorted(task_group.children.values(), key=lambda t: t.label) + task_group_to_dict(child, parent_group_is_mapped=parent_group_is_mapped or is_mapped) + for child in get_task_group_children_getter()(task_group) ] if task_group.upstream_group_ids or task_group.upstream_task_ids: @@ -715,5 +723,44 @@ def task_group_to_dict(task_item_or_group): "tooltip": task_group.tooltip, "is_mapped": is_mapped, "children": children, - "type": "task_group", + "type": "task", + } + + +def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False): + """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator + from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.serialization.serialized_objects import SerializedBaseOperator + + if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)): + is_mapped = None + if isinstance(task, MappedOperator) or parent_group_is_mapped: + is_mapped = True + setup_teardown_type = None + if task.is_setup is True: + setup_teardown_type = "setup" + elif task.is_teardown is True: + setup_teardown_type = "teardown" + return { + "id": task.task_id, + "label": task.label, + "is_mapped": is_mapped, + "children": None, + "setup_teardown_type": setup_teardown_type, + } + + task_group = task_item_or_group + task_group_sort = get_task_group_children_getter() + is_mapped_group = isinstance(task_group, MappedTaskGroup) + children = [ + task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group) + for x in task_group_sort(task_group) + ] + + return { + "id": task_group.group_id, + "label": task_group.label, + "is_mapped": is_mapped_group or None, + "children": children or None, } From a7aa377da17d28e8aeb03ab7985c85927cb7035a Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 17 Jul 2025 13:53:03 +0100 Subject: [PATCH 2/3] fixup! Deprecate and move `airflow.utils.task_group` to SDK --- .../tests/unit/standard/decorators/test_python.py | 9 ++++++++- .../unit/standard/operators/test_branch_operator.py | 7 ++++++- .../unit/standard/sensors/test_external_task_sensor.py | 8 +++++++- .../tests/unit/standard/utils/test_sensor_helper.py | 2 +- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 1d65c02f533bb..180101fb26fde 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -37,6 +37,7 @@ from airflow.sdk.bases.decorator import DecoratedMappedOperator from airflow.sdk.definitions._internal.expandinput import DictOfListsExpandInput from airflow.sdk.definitions.mappedoperator import MappedOperator + else: from airflow.decorators import setup, task as task_decorator, teardown from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef] @@ -45,7 +46,13 @@ from airflow.models.expandinput import DictOfListsExpandInput from airflow.models.mappedoperator import MappedOperator from airflow.models.xcom_arg import XComArg - from airflow.utils.task_group import TaskGroup + from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] + +try: + from airflow.sdk.definitions.taskgroup import TaskGroup +except ImportError: + # Fallback for Airflow < 3.1 + from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] pytestmark = pytest.mark.db_test diff --git a/providers/standard/tests/unit/standard/operators/test_branch_operator.py b/providers/standard/tests/unit/standard/operators/test_branch_operator.py index 670ce77415ba0..bcecbd26b75f5 100644 --- a/providers/standard/tests/unit/standard/operators/test_branch_operator.py +++ b/providers/standard/tests/unit/standard/operators/test_branch_operator.py @@ -28,9 +28,14 @@ from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import State -from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType +try: + from airflow.sdk.definitions.taskgroup import TaskGroup +except ImportError: + # Fallback for Airflow < 3.1 + from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS if AIRFLOW_V_3_0_PLUS: diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py index daff66dcb13f8..cb881617faf35 100644 --- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py +++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py @@ -66,7 +66,6 @@ from airflow.timetables.base import DataInterval from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.timezone import coerce_datetime, datetime from airflow.utils.types import DagRunType @@ -80,6 +79,13 @@ from airflow.utils.types import DagRunTriggeredByType else: from airflow.decorators import task as task_deco + +try: + from airflow.sdk.definitions.taskgroup import TaskGroup +except ImportError: + # Fallback for Airflow < 3.1 + from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] + pytestmark = pytest.mark.db_test diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py index 346e956a981bc..927bf100a4967 100644 --- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py +++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py @@ -33,9 +33,9 @@ _get_count, _get_external_task_group_task_ids, ) +from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType from tests_common.test_utils import db From 633970193afa4ac602dacae93552b6526e4a8c13 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Thu, 17 Jul 2025 15:32:27 +0100 Subject: [PATCH 3/3] fixup! fixup! Deprecate and move `airflow.utils.task_group` to SDK --- .../tests/unit/standard/decorators/test_python.py | 6 ------ .../tests/unit/standard/utils/test_sensor_helper.py | 8 +++++++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py index 180101fb26fde..3300c51a9fdda 100644 --- a/providers/standard/tests/unit/standard/decorators/test_python.py +++ b/providers/standard/tests/unit/standard/decorators/test_python.py @@ -48,12 +48,6 @@ from airflow.models.xcom_arg import XComArg from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] -try: - from airflow.sdk.definitions.taskgroup import TaskGroup -except ImportError: - # Fallback for Airflow < 3.1 - from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] - pytestmark = pytest.mark.db_test diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py index 927bf100a4967..89735e206db5a 100644 --- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py +++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py @@ -33,7 +33,6 @@ _get_count, _get_external_task_group_task_ids, ) -from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunType @@ -41,9 +40,16 @@ from tests_common.test_utils import db from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS +try: + from airflow.sdk.definitions.taskgroup import TaskGroup +except ImportError: + # Fallback for Airflow < 3.1 + from airflow.utils.task_group import TaskGroup # type: ignore[no-redef] + if TYPE_CHECKING: from sqlalchemy.orm.session import Session + TI = TaskInstance