diff --git a/airflow-core/src/airflow/api_fastapi/common/db/dags.py b/airflow-core/src/airflow/api_fastapi/common/db/dags.py index cefce66279883..7707f78d419bc 100644 --- a/airflow-core/src/airflow/api_fastapi/common/db/dags.py +++ b/airflow-core/src/airflow/api_fastapi/common/db/dags.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from sqlalchemy import func, select +from sqlalchemy.orm import selectinload from airflow.api_fastapi.common.db.common import ( apply_filters_to_select, @@ -33,7 +34,7 @@ def generate_dag_with_latest_run_query(max_run_filters: list[BaseParam], order_by: SortParam) -> Select: - query = select(DagModel) + query = select(DagModel).options(selectinload(DagModel.tags)) max_run_id_query = ( # ordering by id will not always be "latest run", but it's a simplifying assumption select(DagRun.dag_id, func.max(DagRun.id).label("max_dag_run_id")) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py index 7633cb94d03a5..96fa5e3e6732c 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py @@ -31,6 +31,7 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.asserts import count_queries from tests_common.test_utils.db import ( clear_db_assets, clear_db_connections, @@ -524,6 +525,71 @@ def test_get_dags_filter_has_import_errors(self, session, test_client, filter_va assert body["total_entries"] == 1 assert [dag["dag_id"] for dag in body["dags"]] == expected_ids + def test_get_dags_no_n_plus_one_queries(self, session, test_client): + """Test that fetching DAGs with tags doesn't trigger n+1 queries.""" + num_dags = 5 + for i in range(num_dags): + dag_id = f"test_dag_queries_{i}" + dag_model = DagModel( + dag_id=dag_id, + bundle_name="dag_maker", + fileloc=f"/tmp/{dag_id}.py", + is_stale=False, + ) + session.add(dag_model) + session.flush() + + for j in range(3): + tag = DagTag(name=f"tag_{i}_{j}", dag_id=dag_id) + session.add(tag) + + session.commit() + session.expire_all() + + with count_queries() as result: + response = test_client.get("/dags", params={"limit": 10}) + + assert response.status_code == 200 + body = response.json() + dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith("test_dag_queries_")] + assert len(dags_with_our_prefix) == num_dags + for dag in dags_with_our_prefix: + assert len(dag["tags"]) == 3 + + first_query_count = sum(result.values()) + + # Add more DAGs and verify query count doesn't scale linearly + for i in range(num_dags, num_dags + 3): + dag_id = f"test_dag_queries_{i}" + dag_model = DagModel( + dag_id=dag_id, + bundle_name="dag_maker", + fileloc=f"/tmp/{dag_id}.py", + is_stale=False, + ) + session.add(dag_model) + session.flush() + + for j in range(3): + tag = DagTag(name=f"tag_{i}_{j}", dag_id=dag_id) + session.add(tag) + + session.commit() + session.expire_all() + + with count_queries() as result2: + response = test_client.get("/dags", params={"limit": 15}) + + assert response.status_code == 200 + second_query_count = sum(result2.values()) + + # With n+1, adding 3 DAGs would add ~3 tag queries + # Without n+1, query count should remain nearly identical + assert second_query_count - first_query_count < 3, ( + f"Added 3 DAGs but query count increased by {second_query_count - first_query_count} " + f"({first_query_count} → {second_query_count}), suggesting n+1 queries for tags" + ) + class TestPatchDag(TestDagEndpoint): """Unit tests for Patch DAG.""" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py index 2134a565ffcf0..328643362a129 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py @@ -26,6 +26,7 @@ from sqlalchemy.orm import Session from airflow.models import DagRun +from airflow.models.dag import DagModel, DagTag from airflow.models.dag_favorite import DagFavorite from airflow.models.hitl import HITLDetail from airflow.sdk.timezone import utcnow @@ -33,6 +34,7 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType +from tests_common.test_utils.asserts import count_queries from unit.api_fastapi.core_api.routes.public.test_dags import ( DAG1_ID, DAG2_ID, @@ -233,6 +235,71 @@ def test_should_response_403(self, unauthorized_test_client): response = unauthorized_test_client.get("/dags", params={}) assert response.status_code == 403 + def test_get_dags_no_n_plus_one_queries(self, session, test_client): + """Test that fetching DAGs with tags doesn't trigger n+1 queries.""" + num_dags = 5 + for i in range(num_dags): + dag_id = f"test_dag_queries_ui_{i}" + dag_model = DagModel( + dag_id=dag_id, + bundle_name="dag_maker", + fileloc=f"/tmp/{dag_id}.py", + is_stale=False, + ) + session.add(dag_model) + session.flush() + + for j in range(3): + tag = DagTag(name=f"tag_ui_{i}_{j}", dag_id=dag_id) + session.add(tag) + + session.commit() + session.expire_all() + + with count_queries() as result: + response = test_client.get("/dags", params={"limit": 10}) + + assert response.status_code == 200 + body = response.json() + dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith("test_dag_queries_ui_")] + assert len(dags_with_our_prefix) == num_dags + for dag in dags_with_our_prefix: + assert len(dag["tags"]) == 3 + + first_query_count = sum(result.values()) + + # Add more DAGs and verify query count doesn't scale linearly + for i in range(num_dags, num_dags + 3): + dag_id = f"test_dag_queries_ui_{i}" + dag_model = DagModel( + dag_id=dag_id, + bundle_name="dag_maker", + fileloc=f"/tmp/{dag_id}.py", + is_stale=False, + ) + session.add(dag_model) + session.flush() + + for j in range(3): + tag = DagTag(name=f"tag_ui_{i}_{j}", dag_id=dag_id) + session.add(tag) + + session.commit() + session.expire_all() + + with count_queries() as result2: + response = test_client.get("/dags", params={"limit": 15}) + + assert response.status_code == 200 + second_query_count = sum(result2.values()) + + # With n+1, adding 3 DAGs would add ~3 tag queries + # Without n+1, query count should remain nearly identical + assert second_query_count - first_query_count < 3, ( + f"Added 3 DAGs but query count increased by {second_query_count - first_query_count} " + f"({first_query_count} → {second_query_count}), suggesting n+1 queries for tags" + ) + @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") def test_latest_run_should_return_200(self, test_client): response = test_client.get(f"/dags/{DAG1_ID}/latest_run")