From af32d1477cc7488166c9f4ab77cd2423abea4b8c Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Thu, 10 Jul 2025 20:11:34 +0530
Subject: [PATCH] Run Task failure callbacks on DAG Processor when task is
 externally killed (#53058)

Until https://github.com/apache/airflow/issues/44354 is implemented, when tasks
are killed externally or the supervisor process dies unexpectedly, users have
no way of knowing this happened. This has been a blocker for Airflow 3.0
adoption for some:

- https://github.com/apache/airflow/issues/44354
- https://apache-airflow.slack.com/archives/C07813CNKA8/p1751057525231389

https://github.com/apache/airflow/issues/44354 is more involved and we might
not get to it for Airflow 3.1 -- so this is a good interim fix until then,
similar to how we run the Dag Run callbacks.

(cherry picked from commit a5211f2efd5ccc565cbc16baee6144dba09918bc)
---
 .../execution_api/datamodels/taskinstance.py  |   4 +-
 .../airflow/callbacks/callback_requests.py    |   2 +
 .../src/airflow/dag_processing/processor.py   |  72 +++-
 .../src/airflow/jobs/scheduler_job_runner.py  |  18 +-
 .../unit/callbacks/test_callback_requests.py  |  29 +-
 .../unit/dag_processing/test_processor.py     | 328 +++++++++++++++---
 .../tests/unit/jobs/test_scheduler_job.py     | 141 ++++----
 .../airflow/sdk/api/datamodels/_generated.py  |   2 +-
 8 files changed, 472 insertions(+), 124 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c43c931f3e28a..2d7968bbc625e 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -302,7 +302,7 @@ class TIRunContext(BaseModel):
     dag_run: DagRun
     """DAG run information for the task instance."""

-    task_reschedule_count: Annotated[int, Field(default=0)]
+    task_reschedule_count: int = 0
     """How many times the task has been rescheduled."""

     max_tries: int
@@ -328,7 +328,7 @@ class TIRunContext(BaseModel):
     xcom_keys_to_clear: Annotated[list[str], Field(default_factory=list)]
     """List of Xcom keys that need to be cleared and purged on by the worker."""

-    should_retry: bool
+    should_retry: bool = False
     """If the ti encounters an error, whether it should enter retry or failed state."""

diff --git a/airflow-core/src/airflow/callbacks/callback_requests.py b/airflow-core/src/airflow/callbacks/callback_requests.py
index 8cf8c77035737..3220497a20900 100644
--- a/airflow-core/src/airflow/callbacks/callback_requests.py
+++ b/airflow-core/src/airflow/callbacks/callback_requests.py
@@ -61,6 +61,8 @@ class TaskCallbackRequest(BaseCallbackRequest):
     """Simplified Task Instance representation"""
     task_callback_type: TaskInstanceState | None = None
     """Whether on success, on failure, on retry"""
+    context_from_server: ti_datamodel.TIRunContext | None = None
+    """Task execution context from the Server"""
     type: Literal["TaskCallbackRequest"] = "TaskCallbackRequest"

     @property
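The context_from_server field added above is what lets the DAG processor rebuild
a template context for a task that never reported back. A minimal sketch of the
user-facing behavior this patch restores (illustrative DAG and callback, not
part of the change): a task-level on_failure_callback now also fires when the
task is killed externally.

    from airflow.providers.standard.operators.bash import BashOperator
    from airflow.sdk import DAG

    def notify(context):
        # Runs on the DAG processor, with the context rebuilt from
        # TaskCallbackRequest.context_from_server when it is present.
        ti = context["ti"]
        print(f"Task {ti.task_id} failed: killed externally or heartbeat timed out")

    with DAG(dag_id="example_callback_dag"):  # hypothetical dag_id
        BashOperator(task_id="sleeper", bash_command="sleep 600", on_failure_callback=notify)
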
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index a54c56ddbe27e..36022c61dfd35 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -16,12 +16,14 @@
 # under the License.
 from __future__ import annotations

+import contextlib
 import importlib
 import os
 import sys
 import traceback
+from collections.abc import Callable, Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, BinaryIO, Callable, ClassVar, Literal, Union
+from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Literal, Union

 import attrs
 from pydantic import BaseModel, Field, TypeAdapter
@@ -44,9 +46,11 @@
     VariableResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
+from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
 from airflow.stats import Stats
 from airflow.utils.file import iter_airflow_imports
+from airflow.utils.state import TaskInstanceState

 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger
@@ -200,10 +204,7 @@ def _execute_callbacks(
     for request in callback_requests:
         log.debug("Processing Callback Request", request=request.to_json())
         if isinstance(request, TaskCallbackRequest):
-            raise NotImplementedError(
-                "Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
-            )
-            # _execute_task_callbacks(dagbag, request)
+            _execute_task_callbacks(dagbag, request, log)
         if isinstance(request, DagCallbackRequest):
             _execute_dag_callbacks(dagbag, request, log)

@@ -237,6 +238,67 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: FilteringBoundLogger) -> None:
             Stats.incr("dag.callback_exceptions", tags={"dag_id": request.dag_id})


+def _execute_task_callbacks(dagbag: DagBag, request: TaskCallbackRequest, log: FilteringBoundLogger) -> None:
+    if not request.is_failure_callback:
+        log.warning(
+            "Task callback requested but is not a failure callback",
+            dag_id=request.ti.dag_id,
+            task_id=request.ti.task_id,
+            run_id=request.ti.run_id,
+        )
+        return
+
+    dag = dagbag.dags[request.ti.dag_id]
+    task = dag.get_task(request.ti.task_id)
+
+    if request.task_callback_type is TaskInstanceState.UP_FOR_RETRY:
+        callbacks = task.on_retry_callback
+    else:
+        callbacks = task.on_failure_callback
+
+    if not callbacks:
+        log.warning(
+            "Callback requested but no callback found",
+            dag_id=request.ti.dag_id,
+            task_id=request.ti.task_id,
+            run_id=request.ti.run_id,
+            ti_id=request.ti.id,
+        )
+        return
+
+    callbacks = callbacks if isinstance(callbacks, Sequence) else [callbacks]
+    ctx_from_server = request.context_from_server
+
+    if ctx_from_server is not None:
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **request.ti.model_dump(exclude_unset=True),
+            task=task,
+            _ti_context_from_server=ctx_from_server,
+            max_tries=ctx_from_server.max_tries,
+        )
+    else:
+        runtime_ti = RuntimeTaskInstance.model_construct(
+            **request.ti.model_dump(exclude_unset=True),
+            task=task,
+        )
+    context = runtime_ti.get_template_context()
+
+    def get_callback_representation(callback):
+        with contextlib.suppress(AttributeError):
+            return callback.__name__
+        with contextlib.suppress(AttributeError):
+            return callback.__class__.__name__
+        return callback
+
+    for idx, callback in enumerate(callbacks):
+        callback_repr = get_callback_representation(callback)
+        log.info("Executing Task callback at index %d: %s", idx, callback_repr)
+        try:
+            callback(context)
+        except Exception:
+            log.exception("Error in callback at index %d: %s", idx, callback_repr)
+
+
 def in_process_api_server() -> InProcessExecutionAPI:
     from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
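_execute_task_callbacks above resolves a printable name for each callback via
get_callback_representation: plain functions expose __name__, while callable
objects fall back to their class name. A standalone usage sketch of that helper
(stdlib only; the Notifier class is hypothetical):

    import contextlib

    def get_callback_representation(callback):
        # Plain functions have __name__; a callable object (a class instance
        # with __call__) raises AttributeError and falls through to its class name.
        with contextlib.suppress(AttributeError):
            return callback.__name__
        with contextlib.suppress(AttributeError):
            return callback.__class__.__name__
        return callback

    class Notifier:
        def __call__(self, context): ...

    assert get_callback_representation(len) == "len"
    assert get_callback_representation(Notifier()) == "Notifier"
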
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 4185cbf6d4425..0fa598be412e1 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -38,6 +38,7 @@
 from sqlalchemy.sql import expression

 from airflow import settings
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
 from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
 from airflow.configuration import conf
 from airflow.dag_processing.bundles.base import BundleUsageTrackingManager
@@ -945,10 +946,16 @@ def process_executor_events(
                     bundle_version=ti.dag_version.bundle_version,
                     ti=ti,
                     msg=msg,
+                    context_from_server=TIRunContext(
+                        dag_run=ti.dag_run,
+                        max_tries=ti.max_tries,
+                        variables=[],
+                        connections=[],
+                        xcom_keys_to_clear=[],
+                    ),
                 )
                 executor.send_callback(request)
-            else:
-                ti.handle_failure(error=msg, session=session)
+            ti.handle_failure(error=msg, session=session)

         return len(event_buffer)

@@ -2296,6 +2303,13 @@ def _purge_task_instances_without_heartbeats(
                     bundle_version=ti.dag_run.bundle_version,
                     ti=ti,
                     msg=str(task_instance_heartbeat_timeout_message_details),
+                    context_from_server=TIRunContext(
+                        dag_run=ti.dag_run,
+                        max_tries=ti.max_tries,
+                        variables=[],
+                        connections=[],
+                        xcom_keys_to_clear=[],
+                    ),
                 )
                 session.add(
                     Log(
diff --git a/airflow-core/tests/unit/callbacks/test_callback_requests.py b/airflow-core/tests/unit/callbacks/test_callback_requests.py
index 37a7a3023d872..d27b7ee343c66 100644
--- a/airflow-core/tests/unit/callbacks/test_callback_requests.py
+++ b/airflow-core/tests/unit/callbacks/test_callback_requests.py
@@ -28,7 +28,7 @@
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.standard.operators.bash import BashOperator
 from airflow.utils import timezone
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState

 pytestmark = pytest.mark.db_test

@@ -85,3 +85,30 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
         json_str = input.to_json()
         result = TaskCallbackRequest.from_json(json_str)
         assert input == result
+
+    @pytest.mark.parametrize(
+        "task_callback_type,expected_is_failure",
+        [
+            (None, True),
+            (TaskInstanceState.FAILED, True),
+            (TaskInstanceState.UP_FOR_RETRY, True),
+            (TaskInstanceState.UPSTREAM_FAILED, True),
+            (TaskInstanceState.SUCCESS, False),
+            (TaskInstanceState.RUNNING, False),
+        ],
+    )
+    def test_is_failure_callback_property(
+        self, task_callback_type, expected_is_failure, create_task_instance
+    ):
+        """Test is_failure_callback property with different task callback types"""
+        ti = create_task_instance()
+
+        request = TaskCallbackRequest(
+            filepath="filepath",
+            ti=ti,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=task_callback_type,
+        )
+
+        assert request.is_failure_callback == expected_is_failure
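The parametrized cases above pin down the is_failure_callback semantics: None
(legacy requests with no explicit type) and the failure-ish states count as
failure callbacks, while SUCCESS and RUNNING do not. A standalone sketch of
logic consistent with that table (an assumed equivalent; the real property
lives on TaskCallbackRequest):

    from airflow.utils.state import TaskInstanceState

    _FAILURE_CALLBACK_TYPES = {
        None,  # requests without an explicit type are treated as failures
        TaskInstanceState.FAILED,
        TaskInstanceState.UP_FOR_RETRY,
        TaskInstanceState.UPSTREAM_FAILED,
    }

    def is_failure_callback(task_callback_type):
        return task_callback_type in _FAILURE_CALLBACK_TYPES
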
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index e8ab15a02227c..71b601dfdd8da 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -21,8 +21,10 @@
 import pathlib
 import sys
 import textwrap
+import uuid
+from collections.abc import Callable
 from socket import socketpair
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch

 import pytest
@@ -30,24 +32,27 @@
 from pydantic import TypeAdapter

 from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TaskInstance as TIDataModel,
+    TIRunContext,
+)
 from airflow.callbacks.callback_requests import CallbackRequest, DagCallbackRequest, TaskCallbackRequest
-from airflow.configuration import conf
 from airflow.dag_processing.processor import (
     DagFileParseRequest,
     DagFileParsingResult,
     DagFileProcessorProcess,
+    _execute_task_callbacks,
     _parse_file,
     _pre_import_airflow_modules,
 )
-from airflow.models import DagBag, TaskInstance
+from airflow.models import DagBag, DagRun
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.serialized_dag import SerializedDagModel
+from airflow.sdk import DAG
 from airflow.sdk.api.client import Client
 from airflow.sdk.execution_time import comms
 from airflow.utils import timezone
 from airflow.utils.session import create_session
-from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.types import DagRunTriggeredByType, DagRunType
+from airflow.utils.state import TaskInstanceState
 from tests_common.test_utils.config import conf_vars, env_vars

@@ -93,42 +98,6 @@ def _process_file(
             log=structlog.get_logger(),
         )

-    @pytest.mark.xfail(reason="TODO: AIP-72")
-    @pytest.mark.parametrize(
-        ["has_serialized_dag"],
-        [pytest.param(True, id="dag_in_db"), pytest.param(False, id="no_dag_found")],
-    )
-    @patch.object(TaskInstance, "handle_failure")
-    def test_execute_on_failure_callbacks_without_dag(self, mock_ti_handle_failure, has_serialized_dag):
-        dagbag = DagBag(dag_folder="/dev/null", include_examples=True, read_dags_from_db=False)
-        with create_session() as session:
-            session.query(TaskInstance).delete()
-            dag = dagbag.get_dag("example_branch_operator")
-            assert dag is not None
-            dag.sync_to_db()
-            dagrun = dag.create_dagrun(
-                state=DagRunState.RUNNING,
-                logical_date=DEFAULT_DATE,
-                run_type=DagRunType.SCHEDULED,
-                data_interval=dag.infer_automated_data_interval(DEFAULT_DATE),
-                run_after=DEFAULT_DATE,
-                triggered_by=DagRunTriggeredByType.TEST,
-                session=session,
-            )
-            task = dag.get_task(task_id="run_this_first")
-            ti = TaskInstance(task, run_id=dagrun.run_id, state=TaskInstanceState.QUEUED)
-            session.add(ti)
-
-            if has_serialized_dag:
-                assert SerializedDagModel.write_dag(dag, bundle_name="testing", session=session) is True
-                session.flush()
-
-            requests = [TaskCallbackRequest(full_filepath="A", ti=ti, msg="Message")]
-            self._process_file(dag.fileloc, requests)
-            mock_ti_handle_failure.assert_called_once_with(
-                error="Message", test_mode=conf.getboolean("core", "unit_test_mode"), session=session
-            )
-
     def test_dagbag_import_errors_captured(self, spy_agency: SpyAgency):
         @spy_agency.spy_for(DagBag.collect_dags, owner=DagBag)
         def fake_collect_dags(dagbag: DagBag, *args, **kwargs):
@@ -554,10 +523,7 @@ def fake_collect_dags(self, *args, **kwargs):
     assert called is True


-@pytest.mark.xfail(reason="TODO: AIP-72: Task level callbacks not yet supported")
 def test_parse_file_with_task_callbacks(spy_agency):
-    from airflow import DAG
-
     called = False

     def on_failure(context):
@@ -572,15 +538,283 @@ def fake_collect_dags(self, *args, **kwargs):

     spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)

+    # Create a minimal TaskInstance for the request
+    ti_data = TIDataModel(
+        id=uuid.uuid4(),
+        dag_id="a",
+        task_id="b",
+        run_id="test_run",
+        map_index=-1,
+        try_number=1,
+        dag_version_id=uuid.uuid4(),
+    )
+
     requests = [
         TaskCallbackRequest(
             filepath="A",
             msg="Message",
-            ti=None,
+            ti=ti_data,
             bundle_name="testing",
             bundle_version=None,
         )
     ]

-    _parse_file(DagFileParseRequest(file="A", callback_requests=requests), log=structlog.get_logger())
+    _parse_file(
+        DagFileParseRequest(file="A", bundle_path="test", callback_requests=requests),
+        log=structlog.get_logger(),
+    )

     assert called is True
+
+
+class TestExecuteTaskCallbacks:
+    """Test the _execute_task_callbacks function"""
+
+    def test_execute_task_callbacks_failure_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes failure callbacks"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task failed",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.FAILED,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_retry_callback(self, spy_agency):
+        """Test _execute_task_callbacks executes retry callbacks"""
+        called = False
+        context_received = None
+
+        def on_retry(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_retry_callback=on_retry)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        ti_data = TIDataModel(
+            id=uuid.uuid4(),
+            dag_id="test_dag",
+            task_id="test_task",
+            run_id="test_run",
+            map_index=-1,
+            try_number=1,
+            dag_version_id=uuid.uuid4(),
+            state=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        request = TaskCallbackRequest(
+            filepath="test.py",
+            msg="Task retrying",
+            ti=ti_data,
+            bundle_name="testing",
+            bundle_version=None,
+            task_callback_type=TaskInstanceState.UP_FOR_RETRY,
+        )
+
+        log = structlog.get_logger()
+        _execute_task_callbacks(dagbag, request, log)
+
+        assert called is True
+        assert context_received is not None
+        assert context_received["dag"] == dag
+        assert "ti" in context_received
+
+    def test_execute_task_callbacks_with_context_from_server(self, spy_agency):
+        """Test _execute_task_callbacks with context_from_server creates full context"""
+        called = False
+        context_received = None
+
+        def on_failure(context):
+            nonlocal called, context_received
+            called = True
+            context_received = context
+
+        with DAG(dag_id="test_dag") as dag:
+            BaseOperator(task_id="test_task", on_failure_callback=on_failure)
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        dag_run = DagRun(
dag_id="test_dag", + run_id="test_run", + logical_date=timezone.utcnow(), + start_date=timezone.utcnow(), + run_type="manual", + ) + dag_run.run_after = timezone.utcnow() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + ) + + context_from_server = TIRunContext( + dag_run=dag_run, + max_tries=3, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + context_from_server=context_from_server, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert called is True + assert context_received is not None + # When context_from_server is provided, we get a full RuntimeTaskInstance context + assert "dag_run" in context_received + assert "logical_date" in context_received + + def test_execute_task_callbacks_not_failure_callback(self, spy_agency): + """Test _execute_task_callbacks when request is not a failure callback""" + called = False + + def on_failure(context): + nonlocal called + called = True + + with DAG(dag_id="test_dag") as dag: + BaseOperator(task_id="test_task", on_failure_callback=on_failure) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.SUCCESS, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task succeeded", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.SUCCESS, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + # Should not call the callback since it's not a failure callback + assert called is False + + def test_execute_task_callbacks_multiple_callbacks(self, spy_agency): + """Test _execute_task_callbacks with multiple callbacks""" + call_count = 0 + + def on_failure_1(context): + nonlocal call_count + call_count += 1 + + def on_failure_2(context): + nonlocal call_count + call_count += 1 + + with DAG(dag_id="test_dag") as dag: + BaseOperator(task_id="test_task", on_failure_callback=[on_failure_1, on_failure_2]) + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + ti_data = TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + try_number=1, + dag_version_id=uuid.uuid4(), + state=TaskInstanceState.FAILED, + ) + + request = TaskCallbackRequest( + filepath="test.py", + msg="Task failed", + ti=ti_data, + bundle_name="testing", + bundle_version=None, + task_callback_type=TaskInstanceState.FAILED, + ) + + log = structlog.get_logger() + _execute_task_callbacks(dagbag, request, log) + + assert call_count == 2 diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index bced9e15be491..30a3ba9800769 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -408,9 +408,7 @@ def test_process_executor_events_with_callback( 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index bced9e15be491..30a3ba9800769 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -408,9 +408,7 @@ def test_process_executor_events_with_callback(
         self.job_runner._process_executor_events(executor=executor, session=session)
         ti1.refresh_from_db()
-        # The state will remain in queued here and
-        # will be set to failed in dag parsing process
-        assert ti1.state == State.QUEUED
+        assert ti1.state == State.FAILED
         mock_task_callback.assert_called_once_with(
             filepath=dag.relative_fileloc,
             ti=mock.ANY,
@@ -420,10 +418,11 @@
             " "
             "finished with state failed, but the task instance's state attribute is queued. "
             "Learn more: https://airflow.apache.org/docs/apache-airflow/stable/troubleshooting.html#task-state-changed-externally",
+            context_from_server=mock.ANY,
         )
         scheduler_job.executor.callback_sink.send.assert_called_once_with(task_callback)
         scheduler_job.executor.callback_sink.reset_mock()
-        mock_stats_incr.assert_called_once_with(
+        mock_stats_incr.assert_any_call(
             "scheduler.tasks.killed_externally",
             tags={
                 "dag_id": "test_process_executor_events_with_callback",
@@ -5880,6 +5879,11 @@ def test_find_and_purge_task_instances_without_heartbeats(self, session, create_
         assert callback_request.ti.run_id == ti.run_id
         assert callback_request.ti.map_index == ti.map_index

+        # Verify context_from_server is passed
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == ti.dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
     @pytest.mark.usefixtures("testing_dag_bundle")
     def test_task_instance_heartbeat_timeout_message(self, session, create_dagrun):
         """
@@ -5947,68 +5951,6 @@
             "External Executor Id": "abcdefg",
         }

-    @pytest.mark.usefixtures("testing_dag_bundle")
-    def test_find_task_instances_without_heartbeats_handle_failure_callbacks_are_correctly_passed_to_dag_processor(
-        self, create_dagrun, session
-    ):
-        """
-        Check that the same set of failure callbacks for task instances without heartbeats are passed to the dag
-        file processors until the next task instance heartbeat timeout detection logic is invoked.
- """ - with conf_vars({("core", "load_examples"): "False"}): - dagbag = DagBag( - dag_folder=os.path.join(settings.DAGS_FOLDER, "test_example_bash_operator.py"), - read_dags_from_db=False, - ) - session.query(Job).delete() - dag = dagbag.get_dag("test_example_bash_operator") - DAG.bulk_write_to_db("testing", None, [dag]) - SerializedDagModel.write_dag(dag=dag, bundle_name="testing") - data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE) - dag_run = create_dagrun( - dag, - state=DagRunState.RUNNING, - logical_date=DEFAULT_DATE, - run_type=DagRunType.SCHEDULED, - data_interval=data_interval, - ) - task = dag.get_task(task_id="run_this_last") - dag_version_id = DagVersion.get_latest_version(dag.dag_id).id - ti = TaskInstance(task, run_id=dag_run.run_id, state=State.RUNNING, dag_version_id=dag_version_id) - ti.last_heartbeat_at = timezone.utcnow() - timedelta(minutes=6) - ti.start_date = timezone.utcnow() - timedelta(minutes=10) - - # TODO: If there was an actual Relationship between TI and Job - # we wouldn't need this extra commit - session.add(ti) - session.flush() - - scheduler_job = Job(executor=self.null_exec) - self.job_runner = SchedulerJobRunner(job=scheduler_job) - - self.job_runner._find_and_purge_task_instances_without_heartbeats() - - scheduler_job.executor.callback_sink.send.assert_called_once() - - expected_failure_callback_requests = [ - TaskCallbackRequest( - filepath=dag.relative_fileloc, - ti=ti, - msg=str(self.job_runner._generate_task_instance_heartbeat_timeout_message_details(ti)), - bundle_name="testing", - bundle_version=dag_run.bundle_version, - ) - ] - callback_requests = scheduler_job.executor.callback_sink.send.call_args.args - assert len(callback_requests) == 1 - assert { - task_instances_without_heartbeats.ti.id - for task_instances_without_heartbeats in expected_failure_callback_requests - } == {result.ti.id for result in callback_requests} - expected_failure_callback_requests[0].ti = None - callback_requests[0].ti = None - assert expected_failure_callback_requests[0] == callback_requests[0] - @mock.patch.object(settings, "USE_JOB_SCHEDULE", False) def run_scheduler_until_dagrun_terminal(self): """ @@ -6519,6 +6461,73 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag( for i in range(100): assert f"it's duplicate {i}" in dag_warning.message + def test_scheduler_passes_context_from_server_on_heartbeat_timeout(self, dag_maker, session): + """Test that scheduler passes context_from_server when handling heartbeat timeouts.""" + with dag_maker(dag_id="test_dag", session=session): + EmptyOperator(task_id="test_task") + + dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING) + + mock_executor = MagicMock() + scheduler_job = Job(executor=mock_executor) + self.job_runner = SchedulerJobRunner(scheduler_job) + + # Create a task instance that appears to be running but hasn't heartbeat + ti = dag_run.get_task_instance(task_id="test_task") + ti.state = TaskInstanceState.RUNNING + ti.queued_by_job_id = scheduler_job.id + # Set last_heartbeat_at to a time that would trigger timeout + ti.last_heartbeat_at = timezone.utcnow() - timedelta(seconds=600) # 10 minutes ago + session.merge(ti) + session.commit() + + # Run the heartbeat timeout check + self.job_runner._find_and_purge_task_instances_without_heartbeats() + + # Verify TaskCallbackRequest was created with context_from_server + mock_executor.send_callback.assert_called_once() + callback_request = mock_executor.send_callback.call_args[0][0] + + assert 
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+
+    def test_scheduler_passes_context_from_server_on_task_failure(self, dag_maker, session):
+        """Test that scheduler passes context_from_server when handling task failures."""
+        with dag_maker(dag_id="test_dag", session=session):
+            EmptyOperator(task_id="test_task", on_failure_callback=lambda: print("failure"))
+
+        dag_run = dag_maker.create_dagrun(run_id="test_run", state=DagRunState.RUNNING)
+
+        # Create a task instance that's running
+        ti = dag_run.get_task_instance(task_id="test_task")
+        ti.state = TaskInstanceState.RUNNING
+        session.merge(ti)
+        session.commit()
+
+        # Mock the executor to simulate a task failure
+        mock_executor = MagicMock(spec=BaseExecutor)
+        mock_executor.has_task = mock.MagicMock(return_value=False)
+        scheduler_job = Job(executor=mock_executor)
+        self.job_runner = SchedulerJobRunner(scheduler_job)
+
+        # Simulate executor reporting task as failed
+        executor_event = {ti.key: (TaskInstanceState.FAILED, None)}
+        mock_executor.get_event_buffer.return_value = executor_event
+
+        # Process the executor events
+        self.job_runner._process_executor_events(mock_executor, session)
+
+        # Verify TaskCallbackRequest was created with context_from_server
+        mock_executor.send_callback.assert_called_once()
+        callback_request = mock_executor.send_callback.call_args[0][0]
+
+        assert isinstance(callback_request, TaskCallbackRequest)
+        assert callback_request.context_from_server is not None
+        assert callback_request.context_from_server.dag_run.logical_date == dag_run.logical_date
+        assert callback_request.context_from_server.max_tries == ti.max_tries
+

 @pytest.mark.need_serialized_dag
 def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index ac1e51d5e55c9..1dabd8c90228d 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -518,7 +518,7 @@ class TIRunContext(BaseModel):
     next_method: Annotated[str | None, Field(title="Next Method")] = None
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
     xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None
-    should_retry: Annotated[bool, Field(title="Should Retry")]
+    should_retry: Annotated[bool | None, Field(title="Should Retry")] = False


 class TITerminalStatePayload(BaseModel):
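
With the regenerated SDK model above, should_retry is now optional and defaults
to False, so payloads from API servers that omit the field still validate. A
quick check of the new default via pydantic v2's public field metadata
(introspection only; no server payload required):

    from airflow.sdk.api.datamodels._generated import TIRunContext

    # The field now carries an explicit default instead of being required.
    assert TIRunContext.model_fields["should_retry"].default is False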