Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow-core/src/airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@

from airflow.api_fastapi.core_api.app import (
init_config,
init_dag_bag,
init_error_handlers,
init_flask_plugins,
init_middlewares,
init_views,
)
from airflow.api_fastapi.core_api.init_dagbag import get_dag_bag
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
Expand Down Expand Up @@ -81,13 +81,16 @@ def create_app(apps: str = "all") -> FastAPI:
root_path=API_ROOT_PATH.removesuffix("/"),
)

dag_bag = get_dag_bag()

if "execution" in apps_list or "all" in apps_list:
task_exec_api_app = create_task_execution_api_app()
task_exec_api_app.state.dag_bag = dag_bag
init_error_handlers(task_exec_api_app)
app.mount("/execution", task_exec_api_app)

if "core" in apps_list or "all" in apps_list:
init_dag_bag(app)
app.state.dag_bag = dag_bag
init_plugins(app)
init_auth_manager(app)
init_flask_plugins(app)
Expand Down
10 changes: 0 additions & 10 deletions airflow-core/src/airflow/api_fastapi/core_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from starlette.templating import Jinja2Templates

from airflow.api_fastapi.auth.tokens import get_signing_key
from airflow.api_fastapi.core_api.init_dagbag import get_dag_bag
from airflow.api_fastapi.core_api.middleware import FlaskExceptionsMiddleware
from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -39,15 +38,6 @@
log = logging.getLogger(__name__)


def init_dag_bag(app: FastAPI) -> None:
"""
Create global DagBag for the FastAPI application.

To access it use ``request.app.state.dag_bag``.
"""
app.state.dag_bag = get_dag_bag()


def init_views(app: FastAPI) -> None:
"""Init views by registering the different routers."""
from airflow.api_fastapi.core_api.routes.public import public_router
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from uuid import UUID

from cadwyn import VersionedAPIRouter
from fastapi import Body, Depends, HTTPException, Query, status
from fastapi import Body, Depends, HTTPException, Query, Request, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
Expand All @@ -49,12 +49,12 @@
from airflow.api_fastapi.execution_api.deps import JWTBearer
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState

router = VersionedAPIRouter()

Expand Down Expand Up @@ -255,6 +255,7 @@ def ti_update_state(
task_instance_id: UUID,
ti_patch_payload: Annotated[TIStateUpdate, Body()],
session: SessionDep,
request: Request,
):
"""
Update the state of a TaskInstance.
Expand All @@ -267,12 +268,13 @@ def ti_update_state(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.try_number, TI.max_tries, TI.dag_id).where(TI.id == ti_id_str).with_for_update()
try:
(
previous_state,
try_number,
max_tries,
dag_id,
) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
Expand Down Expand Up @@ -308,6 +310,15 @@ def ti_update_state(
updated_state = ti_patch_payload.state
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=updated_state)

if updated_state == TerminalTIState.FAILED:
ti = session.get(TI, ti_id_str)
ser_dag = request.app.state.dag_bag.get_dag(dag_id)
if ser_dag and getattr(ser_dag, "fail_fast", False):
task_dict = getattr(ser_dag, "task_dict")
task_teardown_map = {k: v.is_teardown for k, v in task_dict.items()}
_stop_remaining_tasks(task_instance=ti, task_teardown_map=task_teardown_map, session=session)

elif isinstance(ti_patch_payload, TIRetryStatePayload):
from airflow.models.taskinstance import uuid7
from airflow.models.taskinstancehistory import TaskInstanceHistory
Expand Down
14 changes: 11 additions & 3 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
)


def _stop_remaining_tasks(*, task_instance: TaskInstance, session: Session):
def _stop_remaining_tasks(*, task_instance: TaskInstance, task_teardown_map=None, session: Session):
"""
Stop non-teardown tasks in dag.

Expand All @@ -411,13 +411,21 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, session: Session):
TaskInstanceState.FAILED,
):
continue
task = task_instance.task.dag.task_dict[ti.task_id]
if not task.is_teardown:
if task_teardown_map:
teardown = task_teardown_map[ti.task_id]
else:
task = task_instance.task.dag.task_dict[ti.task_id]
teardown = task.is_teardown
if not teardown:
if ti.state == TaskInstanceState.RUNNING:
log.info("Forcing task %s to fail due to dag's `fail_fast` setting", ti.task_id)
msg = "Forcing task to fail due to dag's `fail_fast` setting."
session.add(Log(event="fail task", extra=msg, task_instance=ti.key))
ti.error(session)
else:
log.info("Setting task %s to SKIPPED due to dag's `fail_fast` setting.", ti.task_id)
msg = "Skipping task due to dag's `fail_fast` setting."
session.add(Log(event="skip task", extra=msg, task_instance=ti.key))
ti.set_state(state=TaskInstanceState.SKIPPED, session=session)
else:
log.info("Not skipping teardown task '%s'", ti.task_id)
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
}
},
"catchup": { "type": "boolean" },
"fail_fast": { "type": "boolean" },
"fileloc": { "type" : "string"},
"relative_fileloc": { "type" : "string"},
"_processor_dags_folder": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,8 +596,8 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0)), # Second call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # Second call returns "queued"
SQLAlchemyError("Database error"), # Last call raises an error
],
),
Expand Down
16 changes: 3 additions & 13 deletions airflow-core/tests/unit/api_fastapi/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,30 @@ def test_main_app_lifespan(client):
assert test_app.state.lifespan_called, "Lifespan not called on Execution API app."


@mock.patch("airflow.api_fastapi.app.init_dag_bag")
@mock.patch("airflow.api_fastapi.app.init_views")
@mock.patch("airflow.api_fastapi.app.init_plugins")
@mock.patch("airflow.api_fastapi.app.create_task_execution_api_app")
def test_core_api_app(
mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client
):
def test_core_api_app(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client):
test_app = client(apps="core").app

# Assert that core-related functions were called
mock_init_dag_bag.assert_called_once_with(test_app)
mock_init_views.assert_called_once_with(test_app)
mock_init_plugins.assert_called_once_with(test_app)

# Assert that execution-related functions were NOT called
mock_create_task_exec_api.assert_not_called()


@mock.patch("airflow.api_fastapi.app.init_dag_bag")
@mock.patch("airflow.api_fastapi.app.init_views")
@mock.patch("airflow.api_fastapi.app.init_plugins")
@mock.patch("airflow.api_fastapi.app.create_task_execution_api_app")
def test_execution_api_app(
mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client
):
def test_execution_api_app(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client):
client(apps="execution")

# Assert that execution-related functions were called
mock_create_task_exec_api.assert_called_once()

# Assert that core-related functions were NOT called
mock_init_dag_bag.assert_not_called()
mock_init_views.assert_not_called()
mock_init_plugins.assert_not_called()

Expand All @@ -78,15 +70,13 @@ def test_execution_api_app_lifespan(client):
assert execution_app[0].state.lifespan_called, "Lifespan not called on Execution API app."


@mock.patch("airflow.api_fastapi.app.init_dag_bag")
@mock.patch("airflow.api_fastapi.app.init_views")
@mock.patch("airflow.api_fastapi.app.init_plugins")
@mock.patch("airflow.api_fastapi.app.create_task_execution_api_app")
def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client):
def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views, client):
test_app = client(apps="all").app

# Assert that core-related functions were called
mock_init_dag_bag.assert_called_once_with(test_app)
mock_init_views.assert_called_once_with(test_app)
mock_init_plugins.assert_called_once_with(test_app)

Expand Down
1 change: 0 additions & 1 deletion task-sdk/src/airflow/sdk/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,6 @@ def _validate_owner_links(self, _, owner_links):
"has_on_success_callback",
"has_on_failure_callback",
"auto_register",
"fail_fast",
"schedule",
}

Expand Down
Loading