diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index 3f4956660ddc1..5f2d3a16b22e3 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -27,12 +27,12 @@ from airflow.api_fastapi.core_api.app import ( init_config, - init_dag_bag, init_error_handlers, init_flask_plugins, init_middlewares, init_views, ) +from airflow.api_fastapi.core_api.init_dagbag import get_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -81,13 +81,16 @@ def create_app(apps: str = "all") -> FastAPI: root_path=API_ROOT_PATH.removesuffix("/"), ) + dag_bag = get_dag_bag() + if "execution" in apps_list or "all" in apps_list: task_exec_api_app = create_task_execution_api_app() + task_exec_api_app.state.dag_bag = dag_bag init_error_handlers(task_exec_api_app) app.mount("/execution", task_exec_api_app) if "core" in apps_list or "all" in apps_list: - init_dag_bag(app) + app.state.dag_bag = dag_bag init_plugins(app) init_auth_manager(app) init_flask_plugins(app) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/app.py b/airflow-core/src/airflow/api_fastapi/core_api/app.py index 22ac4aba56e1a..48cc2d25a52ca 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/app.py @@ -30,7 +30,6 @@ from starlette.templating import Jinja2Templates from airflow.api_fastapi.auth.tokens import get_signing_key -from airflow.api_fastapi.core_api.init_dagbag import get_dag_bag from airflow.api_fastapi.core_api.middleware import FlaskExceptionsMiddleware from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -39,15 +38,6 @@ log = logging.getLogger(__name__) -def init_dag_bag(app: FastAPI) -> None: - """ - Create global DagBag for the FastAPI application. - - To access it use ``request.app.state.dag_bag``. - """ - app.state.dag_bag = get_dag_bag() - - def init_views(app: FastAPI) -> None: """Init views by registering the different routers.""" from airflow.api_fastapi.core_api.routes.public import public_router diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 6611aeb56de46..e9d7738f7a319 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -23,7 +23,7 @@ from uuid import UUID from cadwyn import VersionedAPIRouter -from fastapi import Body, Depends, HTTPException, Query, status +from fastapi import Body, Depends, HTTPException, Query, Request, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError @@ -49,12 +49,12 @@ from airflow.api_fastapi.execution_api.deps import JWTBearer from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun as DR -from airflow.models.taskinstance import TaskInstance as TI, _update_rtif +from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks, _update_rtif from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.xcom import XComModel from airflow.utils import timezone -from airflow.utils.state import DagRunState, TaskInstanceState +from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState router = VersionedAPIRouter() @@ -255,6 +255,7 @@ def ti_update_state( task_instance_id: UUID, ti_patch_payload: Annotated[TIStateUpdate, Body()], session: SessionDep, + request: Request, ): """ Update the state of a TaskInstance. @@ -267,12 +268,13 @@ def ti_update_state( # We only use UUID above for validation purposes ti_id_str = str(task_instance_id) - old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update() + old = select(TI.state, TI.try_number, TI.max_tries, TI.dag_id).where(TI.id == ti_id_str).with_for_update() try: ( previous_state, try_number, max_tries, + dag_id, ) = session.execute(old).one() except NoResultFound: log.error("Task Instance %s not found", ti_id_str) @@ -308,6 +310,15 @@ def ti_update_state( updated_state = ti_patch_payload.state query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) query = query.values(state=updated_state) + + if updated_state == TerminalTIState.FAILED: + ti = session.get(TI, ti_id_str) + ser_dag = request.app.state.dag_bag.get_dag(dag_id) + if ser_dag and getattr(ser_dag, "fail_fast", False): + task_dict = getattr(ser_dag, "task_dict") + task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()} + _stop_remaining_tasks(task_instance=ti, task_teardown_map=task_teardown_map, session=session) + elif isinstance(ti_patch_payload, TIRetryStatePayload): from airflow.models.taskinstance import uuid7 from airflow.models.taskinstancehistory import TaskInstanceHistory diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index e35927df61e34..e2a5fa5405d16 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -392,7 +392,7 @@ def set_current_context(context: Context) -> Generator[Context, None, None]: ) -def _stop_remaining_tasks(*, task_instance: TaskInstance, session: Session): +def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None, session: Session): """ Stop non-teardown tasks in dag. @@ -411,13 +411,21 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, session: Session): TaskInstanceState.FAILED, ): continue - task = task_instance.task.dag.task_dict[ti.task_id] - if not task.is_teardown: + if task_teardown_map: + teardown = task_teardown_map[ti.task_id] + else: + task = task_instance.task.dag.task_dict[ti.task_id] + teardown = task.is_teardown + if not teardown: if ti.state == TaskInstanceState.RUNNING: log.info("Forcing task %s to fail due to dag's `fail_fast` setting", ti.task_id) + msg = "Forcing task to fail due to dag's `fail_fast` setting." + session.add(Log(event="fail task", extra=msg, task_instance=ti.key)) ti.error(session) else: log.info("Setting task %s to SKIPPED due to dag's `fail_fast` setting.", ti.task_id) + msg = "Skipping task due to dag's `fail_fast` setting." + session.add(Log(event="skip task", extra=msg, task_instance=ti.key)) ti.set_state(state=TaskInstanceState.SKIPPED, session=session) else: log.info("Not skipping teardown task '%s'", ti.task_id) diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json index a30d08c3f4474..0670acd588cbe 100644 --- a/airflow-core/src/airflow/serialization/schema.json +++ b/airflow-core/src/airflow/serialization/schema.json @@ -176,6 +176,7 @@ } }, "catchup": { "type": "boolean" }, + "fail_fast": { "type": "boolean" }, "fileloc": { "type" : "string"}, "relative_fileloc": { "type" : "string"}, "_processor_dags_folder": { diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 456c3f9fa2a48..ab4436cb34088 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -596,8 +596,8 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta mock.patch( "airflow.api_fastapi.common.db.common.Session.execute", side_effect=[ - mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "queued" - mock.Mock(one=lambda: ("running", 1, 0)), # Second call returns "queued" + mock.Mock(one=lambda: ("running", 1, 0, "dag")), # First call returns "queued" + mock.Mock(one=lambda: ("running", 1, 0, "dag")), # Second call returns "queued" SQLAlchemyError("Database error"), # Last call raises an error ], ), diff --git a/airflow-core/tests/unit/api_fastapi/test_app.py b/airflow-core/tests/unit/api_fastapi/test_app.py index 34ccfdfdd7b82..c2646ad0e50a7 100644 --- a/airflow-core/tests/unit/api_fastapi/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/test_app.py @@ -32,17 +32,13 @@ def test_main_app_lifespan(client): assert test_app.state.lifespan_called, "Lifespan not called on Execution API app." -@mock.patch("airflow.api_fastapi.app.init_dag_bag") @mock.patch("airflow.api_fastapi.app.init_views") @mock.patch("airflow.api_fastapi.app.init_plugins") @mock.patch("airflow.api_fastapi.app.create_task_execution_api_app") -def test_core_api_app( - mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client -): +def test_core_api_app(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client): test_app = client(apps="core").app # Assert that core-related functions were called - mock_init_dag_bag.assert_called_once_with(test_app) mock_init_views.assert_called_once_with(test_app) mock_init_plugins.assert_called_once_with(test_app) @@ -50,20 +46,16 @@ def test_core_api_app( mock_create_task_exec_api.assert_not_called() -@mock.patch("airflow.api_fastapi.app.init_dag_bag") @mock.patch("airflow.api_fastapi.app.init_views") @mock.patch("airflow.api_fastapi.app.init_plugins") @mock.patch("airflow.api_fastapi.app.create_task_execution_api_app") -def test_execution_api_app( - mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client -): +def test_execution_api_app(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client): client(apps="execution") # Assert that execution-related functions were called mock_create_task_exec_api.assert_called_once() # Assert that core-related functions were NOT called - mock_init_dag_bag.assert_not_called() mock_init_views.assert_not_called() mock_init_plugins.assert_not_called() @@ -78,15 +70,13 @@ def test_execution_api_app_lifespan(client): assert execution_app[0].state.lifespan_called, "Lifespan not called on Execution API app." -@mock.patch("airflow.api_fastapi.app.init_dag_bag") @mock.patch("airflow.api_fastapi.app.init_views") @mock.patch("airflow.api_fastapi.app.init_plugins") @mock.patch("airflow.api_fastapi.app.create_task_execution_api_app") -def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client): +def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client): test_app = client(apps="all").app # Assert that core-related functions were called - mock_init_dag_bag.assert_called_once_with(test_app) mock_init_views.assert_called_once_with(test_app) mock_init_plugins.assert_called_once_with(test_app) diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 9d98534a5e8f3..3adc2a98a0fb2 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1045,7 +1045,6 @@ def _validate_owner_links(self, _, owner_links): "has_on_success_callback", "has_on_failure_callback", "auto_register", - "fail_fast", "schedule", }