diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index b42d7e85bf3b2..adee9f0cc1545 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -31,7 +31,6 @@ init_middlewares, init_views, ) -from airflow.api_fastapi.core_api.init_dagbag import get_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -80,16 +79,12 @@ def create_app(apps: str = "all") -> FastAPI: version="2", ) - dag_bag = get_dag_bag() - if "execution" in apps_list or "all" in apps_list: task_exec_api_app = create_task_execution_api_app() - task_exec_api_app.state.dag_bag = dag_bag init_error_handlers(task_exec_api_app) app.mount("/execution", task_exec_api_app) if "core" in apps_list or "all" in apps_list: - app.state.dag_bag = dag_bag init_plugins(app) init_auth_manager(app) init_flask_plugins(app) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/init_dagbag.py b/airflow-core/src/airflow/api_fastapi/common/deps.py similarity index 84% rename from airflow-core/src/airflow/api_fastapi/core_api/init_dagbag.py rename to airflow-core/src/airflow/api_fastapi/common/deps.py index 720276d054b94..504cd0bbbd5c9 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/init_dagbag.py +++ b/airflow-core/src/airflow/api_fastapi/common/deps.py @@ -14,16 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations import os +from typing import Annotated + +from fastapi import Depends -from airflow.models import DagBag +from airflow.models.dagbag import DagBag from airflow.settings import DAGS_FOLDER -def get_dag_bag() -> DagBag: - """Instantiate the appropriate DagBag based on the ``SKIP_DAGS_PARSING`` environment variable.""" +def _get_dag_bag() -> DagBag: if os.environ.get("SKIP_DAGS_PARSING") == "True": return DagBag(os.devnull, include_examples=False) return DagBag(DAGS_FOLDER, read_dags_from_db=True) + + +DagBagDep = Annotated[DagBag, Depends(_get_dag_bag)] diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py index 2f08cbc71360c..3a1030b04d61b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py @@ -20,11 +20,12 @@ from datetime import datetime from typing import TYPE_CHECKING, Annotated -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from sqlalchemy import and_, delete, func, select from sqlalchemy.orm import joinedload, subqueryload from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( BaseParam, FilterParam, @@ -345,7 +346,7 @@ def create_asset_event( ) def materialize_asset( asset_id: int, - request: Request, + dag_bag: DagBagDep, session: SessionDep, ) -> DAGRunResponse: """Materialize an asset by triggering a DAG run that produces it.""" @@ -367,7 +368,7 @@ def materialize_asset( ) dag: DAG | None - if not (dag := request.app.state.dag_bag.get_dag(dag_id)): + if not (dag := dag_bag.get_dag(dag_id)): raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with ID `{dag_id}` was not found") return dag.create_dagrun( diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py index a3a70009360d1..b6d88aa170b2f 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -20,7 +20,7 @@ from typing import Annotated, Literal, cast import structlog -from fastapi import Depends, HTTPException, Query, Request, status +from fastapi import Depends, HTTPException, Query, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import select @@ -33,6 +33,7 @@ ) from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, FilterParam, @@ -147,7 +148,7 @@ def patch_dag_run( dag_run_id: str, patch_body: DAGRunPatchBody, session: SessionDep, - request: Request, + dag_bag: DagBagDep, user: GetUserDep, update_mask: list[str] | None = Query(None), ) -> DAGRunResponse: @@ -161,7 +162,7 @@ def patch_dag_run( f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found", ) - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") @@ -255,7 +256,7 @@ def clear_dag_run( dag_id: str, dag_run_id: str, body: DAGRunClearBody, - request: Request, + dag_bag: DagBagDep, session: SessionDep, ) -> TaskInstanceCollectionResponse | DAGRunResponse: dag_run = session.scalar( @@ -267,7 +268,7 @@ def clear_dag_run( f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found", ) - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if body.dry_run: task_instances = dag.clear( @@ -331,7 +332,7 @@ def get_dag_runs( ], readable_dag_runs_filter: ReadableDagRunsFilterDep, session: SessionDep, - request: Request, + dag_bag: DagBagDep, ) -> DAGRunCollectionResponse: """ Get all DAG Runs. @@ -341,7 +342,7 @@ def get_dag_runs( query = select(DagRun) if dag_id != "~": - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found") @@ -389,7 +390,7 @@ def get_dag_runs( def trigger_dag_run( dag_id, body: TriggerDAGRunPostBody, - request: Request, + dag_bag: DagBagDep, user: GetUserDep, session: SessionDep, ) -> DAGRunResponse: @@ -405,7 +406,7 @@ def trigger_dag_run( ) try: - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) params = body.validate_context(dag) dag_run = dag.create_dagrun( diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py index b41f203a7df88..af9e16a374998 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_versions.py @@ -18,11 +18,12 @@ from typing import Annotated -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from sqlalchemy import select from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( FilterParam, QueryLimit, @@ -80,10 +81,9 @@ def get_dag_version( ) def get_dag_versions( dag_id: str, + session: SessionDep, limit: QueryLimit, offset: QueryOffset, - session: SessionDep, - request: Request, version_number: Annotated[ FilterParam[int], Depends(filter_param_factory(DagVersion.version_number, int)) ], @@ -97,6 +97,7 @@ def get_dag_versions( SortParam(["id", "version_number", "bundle_name", "bundle_version"], DagVersion).dynamic_depends() ), ], + dag_bag: DagBagDep, ) -> DAGVersionCollectionResponse: """ Get all DAG Versions. @@ -106,7 +107,7 @@ def get_dag_versions( query = select(DagVersion) if dag_id != "~": - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py index a70034e835975..7473f334f9235 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dags.py @@ -19,7 +19,7 @@ from typing import Annotated -from fastapi import Depends, HTTPException, Query, Request, Response, status +from fastapi import Depends, HTTPException, Query, Response, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import select, update @@ -30,6 +30,7 @@ paginated_select, ) from airflow.api_fastapi.common.db.dags import generate_dag_with_latest_run_query +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, FilterParam, @@ -165,9 +166,13 @@ def get_dags( ), dependencies=[Depends(requires_access_dag(method="GET"))], ) -def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: +def get_dag( + dag_id: str, + session: SessionDep, + dag_bag: DagBagDep, +) -> DAGResponse: """Get basic information about a DAG.""" - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") @@ -192,9 +197,9 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: ), dependencies=[Depends(requires_access_dag(method="GET"))], ) -def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse: +def get_dag_details(dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> DAGDetailsResponse: """Get details of DAG.""" - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/extra_links.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/extra_links.py index 867bd8033fc4f..b214f33a4e75d 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/extra_links.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/extra_links.py @@ -19,10 +19,11 @@ from typing import TYPE_CHECKING -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from sqlalchemy.sql import select from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.extra_links import ExtraLinkCollectionResponse from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc @@ -49,13 +50,13 @@ def get_extra_links( dag_run_id: str, task_id: str, session: SessionDep, - request: Request, + dag_bag: DagBagDep, map_index: int = -1, ) -> ExtraLinkCollectionResponse: """Get extra links for task instance.""" from airflow.models.taskinstance import TaskInstance - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with ID = {dag_id} not found") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py index 5645776fb49bb..8a643f3a2af78 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/log.py @@ -27,6 +27,7 @@ from sqlalchemy.sql import select from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.headers import HeaderAcceptJsonOrText from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.common.types import Mimetype @@ -76,6 +77,7 @@ def get_log( try_number: PositiveInt, accept: HeaderAcceptJsonOrText, request: Request, + dag_bag: DagBagDep, session: SessionDep, full_content: bool = False, map_index: int = -1, @@ -129,7 +131,7 @@ def get_log( metadata["end_of_log"] = True raise HTTPException(status.HTTP_404_NOT_FOUND, "TaskInstance not found") - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if dag: with contextlib.suppress(TaskNotFound): ti.task = dag.get_task(ti.task_id) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index fa2387f63b947..429193e176527 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -20,7 +20,7 @@ from typing import Annotated, Literal, cast import structlog -from fastapi import Depends, HTTPException, Query, Request, status +from fastapi import Depends, HTTPException, Query, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import or_, select @@ -30,6 +30,7 @@ from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, FilterParam, @@ -122,7 +123,7 @@ def get_mapped_task_instances( dag_id: str, dag_run_id: str, task_id: str, - request: Request, + dag_bag: DagBagDep, run_after_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("run_after", TI))], logical_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("logical_date", TI))], start_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("start_date", TI))], @@ -177,7 +178,7 @@ def get_mapped_task_instances( # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 unfiltered_total_count = get_query_count(query, session=session) if unfiltered_total_count == 0: - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if not dag: error_message = f"DAG {dag_id} not found" raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) @@ -235,7 +236,7 @@ def get_task_instance_dependencies( dag_run_id: str, task_id: str, session: SessionDep, - request: Request, + dag_bag: DagBagDep, map_index: int = -1, ) -> TaskDependencyCollectionResponse: """Get dependencies blocking task from getting scheduled.""" @@ -254,7 +255,7 @@ def get_task_instance_dependencies( deps = [] if ti.state in [None, TaskInstanceState.SCHEDULED]: - dag = request.app.state.dag_bag.get_dag(ti.dag_id) + dag = dag_bag.get_dag(ti.dag_id) if dag: try: @@ -380,7 +381,7 @@ def get_mapped_task_instance( def get_task_instances( dag_id: str, dag_run_id: str, - request: Request, + dag_bag: DagBagDep, task_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(TI.task_id, str | None))], run_after_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("run_after", TI))], logical_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("logical_date", TI))], @@ -442,7 +443,7 @@ def get_task_instances( ) if dag_id != "~": - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: `{dag_id}` was not found") query = query.where(TI.dag_id == dag_id) @@ -645,12 +646,12 @@ def get_mapped_task_instance_try_details( ) def post_clear_task_instances( dag_id: str, - request: Request, + dag_bag: DagBagDep, body: ClearTaskInstancesBody, session: SessionDep, ) -> TaskInstanceCollectionResponse: """Clear task instances.""" - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if not dag: error_message = f"DAG {dag_id} not found" raise HTTPException(status.HTTP_404_NOT_FOUND, error_message) @@ -703,7 +704,7 @@ def post_clear_task_instances( dry_run=True, run_id=None if past or future else dag_run_id, task_ids=task_ids, - dag_bag=request.app.state.dag_bag, + dag_bag=dag_bag, session=session, **body.model_dump( include={ @@ -733,13 +734,13 @@ def _patch_ti_validate_request( dag_id: str, dag_run_id: str, task_id: str, - request: Request, + dag_bag: DagBagDep, body: PatchTaskInstanceBody, session: SessionDep, map_index: int = -1, update_mask: list[str] | None = Query(None), ) -> tuple[DAG, TI, dict]: - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found") @@ -800,7 +801,7 @@ def patch_task_instance_dry_run( dag_id: str, dag_run_id: str, task_id: str, - request: Request, + dag_bag: DagBagDep, body: PatchTaskInstanceBody, session: SessionDep, map_index: int = -1, @@ -808,7 +809,7 @@ def patch_task_instance_dry_run( ) -> TaskInstanceCollectionResponse: """Update a task instance dry_run mode.""" dag, ti, data = _patch_ti_validate_request( - dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask + dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask ) tis: list[TI] = [] @@ -870,7 +871,7 @@ def patch_task_instance( dag_id: str, dag_run_id: str, task_id: str, - request: Request, + dag_bag: DagBagDep, body: PatchTaskInstanceBody, user: GetUserDep, session: SessionDep, @@ -879,7 +880,7 @@ def patch_task_instance( ) -> TaskInstanceResponse: """Update a task instance.""" dag, ti, data = _patch_ti_validate_request( - dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask + dag_id, dag_run_id, task_id, dag_bag, body, session, map_index, update_mask ) for key, _ in data.items(): diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/tasks.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/tasks.py index 6cd46c4e67b01..9b88211bd2fcf 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/tasks.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/tasks.py @@ -20,9 +20,10 @@ from operator import attrgetter from typing import cast -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.tasks import TaskCollectionResponse, TaskResponse from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc @@ -45,11 +46,11 @@ ) def get_tasks( dag_id: str, - request: Request, + dag_bag: DagBagDep, order_by: str = "task_id", ) -> TaskCollectionResponse: """Get tasks for DAG.""" - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") try: @@ -72,9 +73,9 @@ def get_tasks( ), dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK))], ) -def get_task(dag_id: str, task_id, request: Request) -> TaskResponse: +def get_task(dag_id: str, task_id, dag_bag: DagBagDep) -> TaskResponse: """Get simplified representation of a task.""" - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") try: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py index 2dc746ddfe262..bb9c60ab3eeb3 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -19,12 +19,13 @@ import copy from typing import Annotated -from fastapi import Depends, HTTPException, Query, Request, status +from fastapi import Depends, HTTPException, Query, status from sqlalchemy import and_, select from sqlalchemy.orm import joinedload from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.xcom import ( @@ -185,11 +186,11 @@ def create_xcom_entry( dag_run_id: str, request_body: XComCreateBody, session: SessionDep, - request: Request, + dag_bag: DagBagDep, ) -> XComResponseNative: """Create an XCom entry.""" # Validate DAG ID - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with ID: `{dag_id}` was not found") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py index caa3d2b0a12d5..f95bfb012dafb 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/assets.py @@ -17,10 +17,11 @@ from __future__ import annotations -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from sqlalchemy import and_, func, select from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.security import requires_access_asset, requires_access_dag from airflow.models import DagModel @@ -35,10 +36,10 @@ ) def next_run_assets( dag_id: str, - request: Request, + dag_bag: DagBagDep, session: SessionDep, ) -> dict: - dag = request.app.state.dag_bag.get_dag(dag_id) + dag = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"can't find dag {dag_id}") diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py index d1fe586c61fa3..e6de9751428c7 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py @@ -22,13 +22,14 @@ from typing import Annotated import structlog -from fastapi import Depends, HTTPException, Request, status +from fastapi import Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.orm import joinedload from airflow import DAG from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.parameters import ( QueryDagRunRunTypesFilter, QueryDagRunStateFilter, @@ -75,7 +76,7 @@ def grid_data( dag_id: str, session: SessionDep, offset: QueryOffset, - request: Request, + dag_bag: DagBagDep, run_type: QueryDagRunRunTypesFilter, state: QueryDagRunStateFilter, limit: QueryLimit, @@ -90,7 +91,7 @@ def grid_data( root: str | None = None, ) -> GridResponse: """Return grid data.""" - dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + dag: DAG = dag_bag.get_dag(dag_id) if not dag: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 00a5ea9a5e627..1dfc9bb10e073 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -25,7 +25,7 @@ import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, Depends, HTTPException, Query, Request, status +from fastapi import Body, Depends, HTTPException, Query, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError @@ -34,6 +34,7 @@ from structlog.contextvars import bind_contextvars from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.common.deps import DagBagDep from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( PrevSuccessfulDagRunResponse, @@ -91,7 +92,7 @@ def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep, - request: Request, + dag_bag: DagBagDep, ) -> TIRunContext: """ Run a TaskInstance. @@ -242,7 +243,7 @@ def ti_run( or 0 ) - if dag := request.app.state.dag_bag.get_dag(ti.dag_id): + if dag := dag_bag.get_dag(ti.dag_id): upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index)) else: upstream_map_indexes = None @@ -306,7 +307,7 @@ def ti_update_state( task_instance_id: UUID, ti_patch_payload: Annotated[TIStateUpdate, Body()], session: SessionDep, - request: Request, + dag_bag: DagBagDep, ): """ Update the state of a TaskInstance. @@ -371,7 +372,7 @@ def ti_update_state( if updated_state == TerminalTIState.FAILED: ti = session.get(TI, ti_id_str) - ser_dag = request.app.state.dag_bag.get_dag(dag_id) + ser_dag = dag_bag.get_dag(dag_id) if ser_dag and getattr(ser_dag, "fail_fast", False): task_dict = getattr(ser_dag, "task_dict") task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()} diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py index 4366d8bbd9d25..91da4b46de721 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_extra_links.py @@ -20,6 +20,7 @@ import pytest +from airflow.api_fastapi.common.deps import _get_dag_bag from airflow.api_fastapi.core_api.datamodels.extra_links import ExtraLinkCollectionResponse from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.models.dagbag import DagBag @@ -94,7 +95,8 @@ def setup(self, test_client, dag_maker, request, session) -> None: DagBundlesManager().sync_bundles_to_db() dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {self.dag.dag_id: self.dag} - test_client.app.state.dag_bag = dag_bag + + test_client.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag dag_bag.sync_to_db("dags-folder", None) self.dag.create_dagrun( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py index d8908f68b946a..0aecbe916a406 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_log.py @@ -27,6 +27,7 @@ from itsdangerous.url_safe import URLSafeSerializer from uuid6 import uuid7 +from airflow.api_fastapi.common.deps import _get_dag_bag from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.models.dag import DAG from airflow.providers.standard.operators.empty import EmptyOperator @@ -71,8 +72,6 @@ def add_one(x: int): start_date=timezone.parse(self.default_time), ) - self.app.state.dag_bag.bag_dag(dag) - for ti in dr.task_instances: ti.try_number = 1 ti.hostname = "localhost" @@ -96,7 +95,6 @@ def add_one(x: int): logical_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), ) - self.app.state.dag_bag.bag_dag(dummy_dag) for ti in dr2.task_instances: ti.try_number = 1 @@ -111,7 +109,10 @@ def add_one(x: int): session.flush() session.flush() - ... + dagbag = _get_dag_bag() + dagbag.bag_dag(dag) + dagbag.bag_dag(dummy_dag) + test_client.app.dependency_overrides[_get_dag_bag] = lambda: dagbag @pytest.fixture def configure_loggers(self, tmp_path, create_log_template): @@ -265,11 +266,12 @@ def test_get_logs_of_removed_task(self, request_url, expected_filename, extra_qu expected_filename = expected_filename.replace("LOG_DIR", str(self.log_dir)) # Recreate DAG without tasks - dagbag = self.app.state.dag_bag + dagbag = _get_dag_bag() dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.parse(self.default_time)) - del dagbag.dags[self.DAG_ID] dagbag.bag_dag(dag=dag) + self.app.dependency_overrides[_get_dag_bag] = lambda: dagbag + key = self.app.state.secret_key serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py index 7052aeaedb631..b3a040b3250c8 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_tasks.py @@ -17,11 +17,11 @@ from __future__ import annotations import os -import unittest from datetime import datetime import pytest +from airflow.api_fastapi.common.deps import _get_dag_bag from airflow.models.dag import DAG from airflow.models.dagbag import DagBag from airflow.models.serialized_dag import SerializedDagModel @@ -70,7 +70,7 @@ def create_dags(self, test_client): mapped_dag.dag_id: mapped_dag, unscheduled_dag.dag_id: unscheduled_dag, } - test_client.app.state.dag_bag = dag_bag + test_client.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag @staticmethod def clear_db(): @@ -229,13 +229,18 @@ def test_unscheduled_task(self, test_client): def test_should_respond_200_serialized(self, test_client, testing_dag_bundle): # Get the dag out of the dagbag before we patch it to an empty one - dag = test_client.app.state.dag_bag.get_dag(self.dag_id) + + with DAG(self.dag_id, schedule=None, start_date=self.task1_start_date, doc_md="details") as dag: + task1 = EmptyOperator(task_id=self.task_id, params={"foo": "bar"}) + task2 = EmptyOperator(task_id=self.task_id2, start_date=self.task2_start_date) + + task1 >> task2 + dag.sync_to_db() SerializedDagModel.write_dag(dag, bundle_name="test_bundle") dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) - patcher = unittest.mock.patch.object(test_client.app.state, "dag_bag", dag_bag) - patcher.start() + test_client.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag expected = { "class_ref": { @@ -281,7 +286,6 @@ def test_should_respond_200_serialized(self, test_client, testing_dag_bundle): ) assert response.status_code == 200 assert response.json() == expected - patcher.stop() def test_should_respond_404(self, test_client): task_id = "xxxx_not_existing"