Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Annotated
from typing import TYPE_CHECKING, Annotated, Any

import structlog
from fastapi import Depends, HTTPException, status
Expand Down Expand Up @@ -48,6 +48,7 @@
from airflow.api_fastapi.core_api.security import requires_access_dag
from airflow.api_fastapi.core_api.services.ui.grid import (
_find_aggregates,
_get_aggs_for_node,
_merge_node_dicts,
)
from airflow.api_fastapi.core_api.services.ui.task_group import (
Expand Down Expand Up @@ -156,7 +157,7 @@ def get_dag_structure(
task_group_sort = get_task_group_children_getter()
if not run_ids:
nodes = [task_group_to_dict_grid(x) for x in task_group_sort(latest_dag.task_group)]
return nodes
return [GridNodeResponse(**n) for n in nodes]

serdags = session.scalars(
select(SerializedDagModel).where(
Expand All @@ -170,7 +171,7 @@ def get_dag_structure(
)
)
)
merged_nodes: list[GridNodeResponse] = []
merged_nodes: list[dict[str, Any]] = []
dags = [latest_dag]
for serdag in serdags:
if serdag:
Expand All @@ -179,7 +180,37 @@ def get_dag_structure(
nodes = [task_group_to_dict_grid(x) for x in task_group_sort(dag.task_group)]
_merge_node_dicts(merged_nodes, nodes)

return merged_nodes
# Ensure historical tasks (e.g. removed) that exist in TIs for the selected runs are represented
def _collect_ids(nodes: list[dict[str, Any]]) -> set[str]:
ids: set[str] = set()
for n in nodes:
nid = n.get("id")
if nid:
ids.add(nid)
children = n.get("children")
if children:
ids |= _collect_ids(children) # recurse
return ids

existing_ids = _collect_ids(merged_nodes)
historical_task_ids = session.scalars(
select(TaskInstance.task_id)
.join(TaskInstance.dag_run)
.where(TaskInstance.dag_id == dag_id, DagRun.id.in_(run_ids))
.distinct()
)
for task_id in historical_task_ids:
if task_id not in existing_ids:
merged_nodes.append(
{
"id": task_id,
"label": task_id,
"is_mapped": None,
"children": None,
}
)

return [GridNodeResponse(**n) for n in merged_nodes]


@grid_router.get(
Expand Down Expand Up @@ -345,19 +376,47 @@ def get_grid_ti_summaries(
assert serdag

def get_node_sumaries():
yielded_task_ids: set[str] = set()

# Yield all nodes discoverable from the serialized DAG structure
for node in _find_aggregates(
node=serdag.dag.task_group,
parent_node=None,
ti_details=ti_details,
):
if node["type"] == "task":
node["child_states"] = None
node["min_start_date"] = None
node["max_end_date"] = None
if node["type"] in {"task", "mapped_task"}:
yielded_task_ids.add(node["task_id"])
if node["type"] == "task":
node["child_states"] = None
node["min_start_date"] = None
node["max_end_date"] = None
yield node

# For good history: add synthetic leaf nodes for task_ids that have TIs in this run
# but are not present in the current DAG structure (e.g. removed tasks)
missing_task_ids = set(ti_details.keys()) - yielded_task_ids
for task_id in sorted(missing_task_ids):
detail = ti_details[task_id]
# Create a leaf task node with aggregated state from its TIs
agg = _get_aggs_for_node(detail)
yield {
"task_id": task_id,
"type": "task",
"parent_id": None,
**agg,
# Align with leaf behavior
"child_states": None,
"min_start_date": None,
"max_end_date": None,
}

task_instances = list(get_node_sumaries())
# If a group id and a task id collide, prefer the group record
group_ids = {n.get("task_id") for n in task_instances if n.get("type") == "group"}
filtered = [n for n in task_instances if not (n.get("type") == "task" and n.get("task_id") in group_ids)]

return { # type: ignore[return-value]
"run_id": run_id,
"dag_id": dag_id,
"task_instances": list(get_node_sumaries()),
"task_instances": filtered,
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,19 @@ def _find_aggregates(
"""Recursively fill the Task Group Map."""
node_id = node.node_id
parent_id = parent_node.node_id if parent_node else None
details = ti_details[node_id]
# Do not mutate ti_details by accidental key creation
details = ti_details.get(node_id, [])

if node is None:
return
if isinstance(node, MappedOperator):
# For unmapped tasks, reflect a single None state so UI shows one square
mapped_details = details or [{"state": None, "start_date": None, "end_date": None}]
yield {
"task_id": node_id,
"type": "mapped_task",
"parent_id": parent_id,
**_get_aggs_for_node(details),
**_get_aggs_for_node(mapped_details),
}

return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,39 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test
},
]

# Also verify that TI summaries include a leaf entry for the removed task
ti_resp = test_client.get(f"/grid/ti_summaries/{DAG_ID_3}/run_3")
assert ti_resp.status_code == 200
ti_payload = ti_resp.json()
assert ti_payload["dag_id"] == DAG_ID_3
assert ti_payload["run_id"] == "run_3"
# Find the removed task summary; it should exist even if not in current serialized DAG structure
removed_ti = next(
(
n
for n in ti_payload["task_instances"]
if n["task_id"] == TASK_ID_4 and n["child_states"] is None
),
None,
)
assert removed_ti is not None
# Its state should be the aggregated state of its TIs, which includes 'removed'
assert removed_ti["state"] in (
"removed",
None,
"skipped",
"success",
"failed",
"running",
"queued",
"scheduled",
"deferred",
"restarting",
"up_for_retry",
"up_for_reschedule",
"upstream_failed",
)

def test_get_dag_structure(self, session, test_client):
session.commit()
response = test_client.get(f"/grid/structure/{DAG_ID}?limit=5")
Expand Down Expand Up @@ -690,3 +723,16 @@ def sort_dict(in_dict):
expected = sort_dict(expected)
actual = sort_dict(actual)
assert actual == expected

def test_structure_includes_historical_removed_task_with_proper_shape(self, session, test_client):
# Ensure the structure endpoint returns synthetic node for historical/removed task
response = test_client.get(f"/grid/structure/{DAG_ID_3}")
assert response.status_code == 200
nodes = response.json()
# Find the historical removed task id
t4 = next((n for n in nodes if n["id"] == TASK_ID_4), None)
assert t4 is not None
assert t4["label"] == TASK_ID_4
# Optional None fields are excluded from response due to response_model_exclude_none=True
assert "is_mapped" not in t4
assert "children" not in t4