Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/utils/dag_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

from typing import TYPE_CHECKING, TypeAlias, cast

from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator

if TYPE_CHECKING:
from collections.abc import Iterable

from airflow.models.mappedoperator import MappedOperator
from airflow.sdk import DAG

Operator: TypeAlias = MappedOperator | SerializedBaseOperator
Expand Down Expand Up @@ -65,7 +65,7 @@ def dag_edges(dag: DAG):

def collect_edges(task_group):
"""Update edges_to_add and edges_to_skip according to TaskGroups."""
if isinstance(task_group, (AbstractOperator, SerializedBaseOperator)):
if isinstance(task_group, (AbstractOperator, SerializedBaseOperator, MappedOperator)):
return

for target_id in task_group.downstream_group_ids:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,71 @@ def test_should_return_404_when_dag_version_not_found(self, test_client):
response.json()["detail"]
== "Dag with id dag_with_multiple_versions and version number 999 was not found"
)

def test_mapped_operator_graph_view(self, dag_maker, test_client, session):
"""
Ensures structure_data endpoint handles MappedOperator without AttributeError.
"""
from airflow.providers.standard.operators.bash import BashOperator

with dag_maker(
dag_id="test_mapped_operator_dag",
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):
task1 = EmptyOperator(task_id="task1")
mapped_task = BashOperator.partial(
task_id="mapped_bash_task",
do_xcom_push=False,
).expand(bash_command=["echo 1", "echo 2", "echo 3"])
task2 = EmptyOperator(task_id="task2")

task1 >> mapped_task >> task2

dag_maker.sync_dagbag_to_db()
response = test_client.get("/structure/structure_data", params={"dag_id": "test_mapped_operator_dag"})
assert response.status_code == 200
data = response.json()

mapped_node = next(node for node in data["nodes"] if node["id"] == "mapped_bash_task")
assert mapped_node["is_mapped"] is True
assert mapped_node["operator"] == "BashOperator"
assert len(data["edges"]) == 2

def test_mapped_operator_in_task_group(self, dag_maker, test_client, session):
"""
Test that mapped operators within task groups are handled correctly.
Specifically tests task_group_to_dict function with MappedOperator instances.
"""
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.taskgroup import TaskGroup

with dag_maker(
dag_id="test_mapped_in_group_dag",
serialized=True,
session=session,
start_date=pendulum.DateTime(2023, 2, 1, 0, 0, 0, tzinfo=pendulum.UTC),
):
with TaskGroup(group_id="processing_group"):
prep = EmptyOperator(task_id="prep")
mapped = PythonOperator.partial(
task_id="process",
python_callable=lambda x: print(f"Processing {x}"),
).expand(op_args=[[1], [2], [3], [4]])

prep >> mapped

dag_maker.sync_dagbag_to_db()
response = test_client.get("/structure/structure_data", params={"dag_id": "test_mapped_in_group_dag"})

assert response.status_code == 200
data = response.json()
group_node = next(node for node in data["nodes"] if node["id"] == "processing_group")
assert group_node["children"] is not None

mapped_in_group = next(
child for child in group_node["children"] if child["id"] == "processing_group.process"
)
assert mapped_in_group["is_mapped"] is True
assert mapped_in_group["operator"] == "PythonOperator"
4 changes: 2 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,11 @@ def get_task_group_children_getter() -> Callable:

def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
"""Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator

if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)):
if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator, MappedOperator)):
node_operator = {
"id": task.task_id,
"label": task.label,
Expand Down
Loading