Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, clear_db_serialized_dags
from tests_common.test_utils.mock_operators import MockOperator

Expand Down Expand Up @@ -273,7 +274,8 @@ def _freeze_time_for_dagruns(time_machine):
@pytest.mark.usefixtures("_freeze_time_for_dagruns")
class TestGetGridDataEndpoint:
def test_should_response_200(self, test_client):
response = test_client.get(f"/grid/runs/{DAG_ID}")
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID}")
assert response.status_code == 200
assert response.json() == [
GRID_RUN_1,
Expand Down Expand Up @@ -314,7 +316,8 @@ def test_should_response_200(self, test_client):
],
)
def test_should_response_200_order_by(self, test_client, order_by, expected):
response = test_client.get(f"/grid/runs/{DAG_ID}", params={"order_by": order_by})
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID}", params={"order_by": order_by})
assert response.status_code == 200
assert response.json() == expected

Expand All @@ -332,7 +335,8 @@ def test_should_response_200_order_by(self, test_client, order_by, expected):
],
)
def test_should_response_200_limit(self, test_client, limit, expected):
response = test_client.get(f"/grid/runs/{DAG_ID}", params={"limit": limit})
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID}", params={"limit": limit})
assert response.status_code == 200
assert response.json() == expected

Expand All @@ -356,37 +360,43 @@ def test_should_response_200_limit(self, test_client, limit, expected):
],
)
def test_runs_should_response_200_date_filters(self, test_client, params, expected):
response = test_client.get(
f"/grid/runs/{DAG_ID}",
params=params,
)
with assert_queries_count(5):
response = test_client.get(
f"/grid/runs/{DAG_ID}",
params=params,
)
assert response.status_code == 200
assert response.json() == expected

@pytest.mark.parametrize(
("params", "expected"),
("params, expected, expected_queries_count"),
[
(
{
"run_after_gte": timezone.datetime(2024, 11, 30),
"run_after_lte": timezone.datetime(2024, 11, 30),
},
GRID_NODES,
7,
),
(
{
"run_after_gte": timezone.datetime(2024, 10, 30),
"run_after_lte": timezone.datetime(2024, 10, 30),
},
GRID_NODES,
5,
),
],
)
def test_structure_should_response_200_date_filters(self, test_client, params, expected):
response = test_client.get(
f"/grid/structure/{DAG_ID}",
params=params,
)
def test_structure_should_response_200_date_filters(
self, test_client, params, expected, expected_queries_count
):
with assert_queries_count(expected_queries_count):
response = test_client.get(
f"/grid/structure/{DAG_ID}",
params=params,
)
assert response.status_code == 200
assert response.json() == expected

Expand All @@ -407,12 +417,14 @@ def test_should_response_404(self, test_client, endpoint):
assert response.json() == {"detail": "Dag with id invalid_dag was not found"}

def test_structure_should_response_200_without_dag_run(self, test_client):
response = test_client.get(f"/grid/structure/{DAG_ID_2}")
with assert_queries_count(5):
response = test_client.get(f"/grid/structure/{DAG_ID_2}")
assert response.status_code == 200
assert response.json() == [{"id": "task2", "label": "task2"}]

def test_runs_should_response_200_without_dag_run(self, test_client):
response = test_client.get(f"/grid/runs/{DAG_ID_2}")
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID_2}")
assert response.status_code == 200
assert response.json() == []

Expand All @@ -426,7 +438,8 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test
ti.dag_version = session.scalar(select(DagModel).where(DagModel.dag_id == DAG_ID_3)).dag_versions[-1]
session.commit()

response = test_client.get(f"/grid/structure/{DAG_ID_3}")
with assert_queries_count(7):
response = test_client.get(f"/grid/structure/{DAG_ID_3}")
assert response.status_code == 200
assert response.json() == [
{"id": "task3", "label": "task3"},
Expand All @@ -439,7 +452,8 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test
]

# Also verify that TI summaries include a leaf entry for the removed task
ti_resp = test_client.get(f"/grid/ti_summaries/{DAG_ID_3}/run_3")
with assert_queries_count(4):
ti_resp = test_client.get(f"/grid/ti_summaries/{DAG_ID_3}/run_3")
assert ti_resp.status_code == 200
ti_payload = ti_resp.json()
assert ti_payload["dag_id"] == DAG_ID_3
Expand Down Expand Up @@ -473,7 +487,9 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test

def test_get_dag_structure(self, session, test_client):
session.commit()
response = test_client.get(f"/grid/structure/{DAG_ID}?limit=5")

with assert_queries_count(7):
response = test_client.get(f"/grid/structure/{DAG_ID}?limit=5")
assert response.status_code == 200
assert response.json() == [
{
Expand Down Expand Up @@ -506,7 +522,8 @@ def test_get_dag_structure(self, session, test_client):

def test_get_grid_runs(self, session, test_client):
session.commit()
response = test_client.get(f"/grid/runs/{DAG_ID}?limit=5")
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID}?limit=5")
assert response.status_code == 200
assert response.json() == [
{
Expand Down Expand Up @@ -562,7 +579,8 @@ def test_filter_by_triggering_user(self, session, test_client, endpoint, trigger

def test_get_grid_runs_filter_by_run_type_and_triggering_user(self, session, test_client):
session.commit()
response = test_client.get(f"/grid/runs/{DAG_ID}?run_type=manual&triggering_user=user2")
with assert_queries_count(5):
response = test_client.get(f"/grid/runs/{DAG_ID}?run_type=manual&triggering_user=user2")
assert response.status_code == 200
assert response.json() == [GRID_RUN_2]

Expand All @@ -585,7 +603,9 @@ def test_filter_by_state(self, session, test_client, endpoint, state, expected):
def test_grid_ti_summaries_group(self, session, test_client):
run_id = "run_4-1"
session.commit()
response = test_client.get(f"/grid/ti_summaries/{DAG_ID_4}/{run_id}")

with assert_queries_count(4):
response = test_client.get(f"/grid/ti_summaries/{DAG_ID_4}/{run_id}")
assert response.status_code == 200
actual = response.json()
expected = {
Expand Down Expand Up @@ -665,7 +685,9 @@ def test_grid_ti_summaries_group(self, session, test_client):
def test_grid_ti_summaries_mapped(self, session, test_client):
run_id = "run_2"
session.commit()
response = test_client.get(f"/grid/ti_summaries/{DAG_ID}/{run_id}")

with assert_queries_count(4):
response = test_client.get(f"/grid/ti_summaries/{DAG_ID}/{run_id}")
assert response.status_code == 200
data = response.json()
actual = data["task_instances"]
Expand Down Expand Up @@ -742,7 +764,9 @@ def sort_dict(in_dict):

def test_structure_includes_historical_removed_task_with_proper_shape(self, session, test_client):
# Ensure the structure endpoint returns synthetic node for historical/removed task
response = test_client.get(f"/grid/structure/{DAG_ID_3}")

with assert_queries_count(7):
response = test_client.get(f"/grid/structure/{DAG_ID_3}")
assert response.status_code == 200
nodes = response.json()
# Find the historical removed task id
Expand Down
Loading