diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py index f83bc53846fc0..62f289f4e48b3 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py @@ -35,6 +35,7 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.db import clear_db_assets, clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests_common.test_utils.mock_operators import MockOperator @@ -273,7 +274,8 @@ def _freeze_time_for_dagruns(time_machine): @pytest.mark.usefixtures("_freeze_time_for_dagruns") class TestGetGridDataEndpoint: def test_should_response_200(self, test_client): - response = test_client.get(f"/grid/runs/{DAG_ID}") + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID}") assert response.status_code == 200 assert response.json() == [ GRID_RUN_1, @@ -314,7 +316,8 @@ def test_should_response_200(self, test_client): ], ) def test_should_response_200_order_by(self, test_client, order_by, expected): - response = test_client.get(f"/grid/runs/{DAG_ID}", params={"order_by": order_by}) + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID}", params={"order_by": order_by}) assert response.status_code == 200 assert response.json() == expected @@ -332,7 +335,8 @@ def test_should_response_200_order_by(self, test_client, order_by, expected): ], ) def test_should_response_200_limit(self, test_client, limit, expected): - response = test_client.get(f"/grid/runs/{DAG_ID}", params={"limit": limit}) + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID}", params={"limit": limit}) assert response.status_code == 200 assert response.json() == expected @@ -356,15 +360,16 @@ def test_should_response_200_limit(self, test_client, limit, expected): ], ) def test_runs_should_response_200_date_filters(self, test_client, params, expected): - response = test_client.get( - f"/grid/runs/{DAG_ID}", - params=params, - ) + with assert_queries_count(5): + response = test_client.get( + f"/grid/runs/{DAG_ID}", + params=params, + ) assert response.status_code == 200 assert response.json() == expected @pytest.mark.parametrize( - ("params", "expected"), + ("params, expected, expected_queries_count"), [ ( { @@ -372,6 +377,7 @@ def test_runs_should_response_200_date_filters(self, test_client, params, expect "run_after_lte": timezone.datetime(2024, 11, 30), }, GRID_NODES, + 7, ), ( { @@ -379,14 +385,18 @@ def test_runs_should_response_200_date_filters(self, test_client, params, expect "run_after_lte": timezone.datetime(2024, 10, 30), }, GRID_NODES, + 5, ), ], ) - def test_structure_should_response_200_date_filters(self, test_client, params, expected): - response = test_client.get( - f"/grid/structure/{DAG_ID}", - params=params, - ) + def test_structure_should_response_200_date_filters( + self, test_client, params, expected, expected_queries_count + ): + with assert_queries_count(expected_queries_count): + response = test_client.get( + f"/grid/structure/{DAG_ID}", + params=params, + ) assert response.status_code == 200 assert response.json() == expected @@ -407,12 +417,14 @@ def test_should_response_404(self, test_client, endpoint): assert response.json() == {"detail": "Dag with id invalid_dag was not found"} def test_structure_should_response_200_without_dag_run(self, test_client): - response = test_client.get(f"/grid/structure/{DAG_ID_2}") + with assert_queries_count(5): + response = test_client.get(f"/grid/structure/{DAG_ID_2}") assert response.status_code == 200 assert response.json() == [{"id": "task2", "label": "task2"}] def test_runs_should_response_200_without_dag_run(self, test_client): - response = test_client.get(f"/grid/runs/{DAG_ID_2}") + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID_2}") assert response.status_code == 200 assert response.json() == [] @@ -426,7 +438,8 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test ti.dag_version = session.scalar(select(DagModel).where(DagModel.dag_id == DAG_ID_3)).dag_versions[-1] session.commit() - response = test_client.get(f"/grid/structure/{DAG_ID_3}") + with assert_queries_count(7): + response = test_client.get(f"/grid/structure/{DAG_ID_3}") assert response.status_code == 200 assert response.json() == [ {"id": "task3", "label": "task3"}, @@ -439,7 +452,8 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test ] # Also verify that TI summaries include a leaf entry for the removed task - ti_resp = test_client.get(f"/grid/ti_summaries/{DAG_ID_3}/run_3") + with assert_queries_count(4): + ti_resp = test_client.get(f"/grid/ti_summaries/{DAG_ID_3}/run_3") assert ti_resp.status_code == 200 ti_payload = ti_resp.json() assert ti_payload["dag_id"] == DAG_ID_3 @@ -473,7 +487,9 @@ def test_should_response_200_with_deleted_task_and_taskgroup(self, session, test def test_get_dag_structure(self, session, test_client): session.commit() - response = test_client.get(f"/grid/structure/{DAG_ID}?limit=5") + + with assert_queries_count(7): + response = test_client.get(f"/grid/structure/{DAG_ID}?limit=5") assert response.status_code == 200 assert response.json() == [ { @@ -506,7 +522,8 @@ def test_get_dag_structure(self, session, test_client): def test_get_grid_runs(self, session, test_client): session.commit() - response = test_client.get(f"/grid/runs/{DAG_ID}?limit=5") + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID}?limit=5") assert response.status_code == 200 assert response.json() == [ { @@ -562,7 +579,8 @@ def test_filter_by_triggering_user(self, session, test_client, endpoint, trigger def test_get_grid_runs_filter_by_run_type_and_triggering_user(self, session, test_client): session.commit() - response = test_client.get(f"/grid/runs/{DAG_ID}?run_type=manual&triggering_user=user2") + with assert_queries_count(5): + response = test_client.get(f"/grid/runs/{DAG_ID}?run_type=manual&triggering_user=user2") assert response.status_code == 200 assert response.json() == [GRID_RUN_2] @@ -585,7 +603,9 @@ def test_filter_by_state(self, session, test_client, endpoint, state, expected): def test_grid_ti_summaries_group(self, session, test_client): run_id = "run_4-1" session.commit() - response = test_client.get(f"/grid/ti_summaries/{DAG_ID_4}/{run_id}") + + with assert_queries_count(4): + response = test_client.get(f"/grid/ti_summaries/{DAG_ID_4}/{run_id}") assert response.status_code == 200 actual = response.json() expected = { @@ -665,7 +685,9 @@ def test_grid_ti_summaries_group(self, session, test_client): def test_grid_ti_summaries_mapped(self, session, test_client): run_id = "run_2" session.commit() - response = test_client.get(f"/grid/ti_summaries/{DAG_ID}/{run_id}") + + with assert_queries_count(4): + response = test_client.get(f"/grid/ti_summaries/{DAG_ID}/{run_id}") assert response.status_code == 200 data = response.json() actual = data["task_instances"] @@ -742,7 +764,9 @@ def sort_dict(in_dict): def test_structure_includes_historical_removed_task_with_proper_shape(self, session, test_client): # Ensure the structure endpoint returns synthetic node for historical/removed task - response = test_client.get(f"/grid/structure/{DAG_ID_3}") + + with assert_queries_count(7): + response = test_client.get(f"/grid/structure/{DAG_ID_3}") assert response.status_code == 200 nodes = response.json() # Find the historical removed task id