8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -877,6 +877,13 @@ class SetRenderedMapIndex(BaseModel):
type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex"


class TaskExecutionTimeout(BaseModel):
"""Payload for communicating task execution timeout to supervisor."""

timeout_seconds: float
type: Literal["TaskExecutionTimeout"] = "TaskExecutionTimeout"


class TriggerDagRun(TriggerDAGRunPayload):
dag_id: str
run_id: Annotated[str, Field(title="Dag Run Id")]
@@ -1046,6 +1053,7 @@ class MaskSecret(BaseModel):
| RetryTask
| SetRenderedFields
| SetRenderedMapIndex
| TaskExecutionTimeout
| SetXCom
| SkipDownstreamTasks
| SucceedTask
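For orientation only, not part of the diff: the new message is a plain Pydantic v2 model, so it round-trips through the same TypeAdapter-based decoding the supervisor already uses for the ToSupervisor union. A minimal standalone sketch, re-declaring the model locally:

from typing import Literal

from pydantic import BaseModel, TypeAdapter

class TaskExecutionTimeout(BaseModel):
    """Payload for communicating task execution timeout to supervisor."""

    timeout_seconds: float
    type: Literal["TaskExecutionTimeout"] = "TaskExecutionTimeout"

msg = TaskExecutionTimeout(timeout_seconds=300.0)
payload = msg.model_dump_json()  # roughly: {"timeout_seconds":300.0,"type":"TaskExecutionTimeout"}
decoded = TypeAdapter(TaskExecutionTimeout).validate_json(payload)
assert decoded == msg and decoded.timeout_seconds == 300.0
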
55 changes: 55 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -108,6 +108,7 @@
StartupDetails,
SucceedTask,
TaskBreadcrumbsResult,
TaskExecutionTimeout,
TaskState,
TaskStatesResult,
ToSupervisor,
@@ -961,6 +962,16 @@ class ActivitySubprocess(WatchedSubprocess):
_task_end_time_monotonic: float | None = attrs.field(default=None, init=False)
_rendered_map_index: str | None = attrs.field(default=None, init=False)

# Execution timeout tracking
_execution_timeout_seconds: float | None = attrs.field(default=None, init=False)
"""Task execution timeout in seconds, received from the task process."""

_task_execution_start_time: float | None = attrs.field(default=None, init=False)
"""Monotonic time when task execution actually started (after parsing DAG)."""

_timeout_sigterm_sent_at: float | None = attrs.field(default=None, init=False)
"""Monotonic time when SIGTERM was sent due to timeout."""
Comment on lines +965 to +973
Member: This raises the question of whether 0.0 in any of these values affects the logic. You use is None / is not None in some checks, but simple boolean checks in others, so this is not immediately clear. This needs to be treated carefully.

@qwe-kev (Author), Dec 22, 2025:
You're right! I've updated the code to use explicit is None checks throughout, to handle the edge case where these float values could technically be 0.0 (which is falsy in Python). I made the following changes (a short illustrative sketch of the 0.0-vs-None distinction follows after this list):

  1. supervisor.py - Changed falsy checks to explicit None checks:

    • Line 1135: if not self._execution_timeout_seconds → if self._execution_timeout_seconds is None

    • Line 1138: if not self._task_execution_start_time → if self._task_execution_start_time is None

  2. task_runner.py - Added validation to only send positive timeouts:

    • Line 761: Added if timeout_seconds > 0: check before sending timeout to supervisor
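Not part of the PR, just a minimal runnable illustration of the point raised in this thread: with a timeout of 0.0, a truthiness check and an explicit None check disagree.

timeout_seconds = 0.0  # configured, but zero -- falsy in Python

if not timeout_seconds:
    # Truthiness check: a 0.0 timeout is silently treated as "not configured".
    print("skipped by `if not timeout_seconds`")

if timeout_seconds is not None:
    # Explicit None check: a 0.0 timeout is still treated as configured.
    print("handled by `if timeout_seconds is not None`")

Both print statements execute, which is exactly the ambiguity the explicit is None checks remove.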


decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)

ti: RuntimeTI | None = None
@@ -1122,6 +1133,42 @@ def _monitor_subprocess(self):
self._send_heartbeat_if_needed()

self._handle_process_overtime_if_needed()
self._handle_execution_timeout_if_needed()

def _handle_execution_timeout_if_needed(self):
"""Handle task execution timeout by sending SIGTERM, then SIGKILL if needed."""
# Only check timeout if we have received timeout configuration and task is still running
if self._execution_timeout_seconds is None or self._terminal_state:
return

if self._task_execution_start_time is None:
return

elapsed_time = time.monotonic() - self._task_execution_start_time

# Check if we've exceeded the execution timeout
if elapsed_time > self._execution_timeout_seconds:
# Grace period for SIGKILL after SIGTERM (5 seconds)
SIGKILL_GRACE_PERIOD = 5.0

if self._timeout_sigterm_sent_at is None:
# First timeout - send SIGTERM
log.warning(
"Task execution timeout exceeded; sending SIGTERM",
ti_id=self.id,
timeout_seconds=self._execution_timeout_seconds,
elapsed_seconds=elapsed_time,
)
self.kill(signal.SIGTERM)
self._timeout_sigterm_sent_at = time.monotonic()
elif (time.monotonic() - self._timeout_sigterm_sent_at) > SIGKILL_GRACE_PERIOD:
# SIGTERM didn't work, escalate to SIGKILL
log.error(
"Task did not respond to SIGTERM; sending SIGKILL",
ti_id=self.id,
grace_period_seconds=SIGKILL_GRACE_PERIOD,
)
self.kill(signal.SIGKILL, force=True)

def _handle_process_overtime_if_needed(self):
"""Handle termination of auxiliary processes if the task exceeds the configured overtime."""
@@ -1234,6 +1281,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
self._rendered_map_index = msg.rendered_map_index
elif isinstance(msg, TaskExecutionTimeout):
# Task has sent us its execution timeout configuration
self._execution_timeout_seconds = msg.timeout_seconds
self._task_execution_start_time = time.monotonic()
log.info(
"Received task execution timeout from task process",
timeout_seconds=msg.timeout_seconds,
)
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
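To make the escalation flow above easier to follow, here is a self-contained sketch (not the supervisor code itself) that reproduces the same decision logic with explicit monotonic timestamps; the 5-second constant mirrors SIGKILL_GRACE_PERIOD in the hunk above, while the decide() helper is purely illustrative.

import signal

SIGKILL_GRACE_PERIOD = 5.0

def decide(now, start, timeout_s, sigterm_sent_at):
    """Return the signal to send at time `now`, or None if no action is needed."""
    if timeout_s is None or start is None:
        return None
    if now - start <= timeout_s:
        return None  # still within the execution timeout
    if sigterm_sent_at is None:
        return signal.SIGTERM  # first breach: ask the task to stop
    if now - sigterm_sent_at > SIGKILL_GRACE_PERIOD:
        return signal.SIGKILL  # SIGTERM ignored past the grace period
    return None  # SIGTERM already sent, still inside the grace period

# A 300-second timeout whose task started at t=0:
assert decide(100.0, 0.0, 300.0, None) is None
assert decide(301.0, 0.0, 300.0, None) == signal.SIGTERM
assert decide(304.0, 0.0, 300.0, 301.0) is None
assert decide(307.0, 0.0, 300.0, 301.0) == signal.SIGKILL
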
31 changes: 13 additions & 18 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -27,7 +27,7 @@
import time
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import suppress
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
@@ -95,6 +95,7 @@
StartupDetails,
SucceedTask,
TaskBreadcrumbsResult,
TaskExecutionTimeout,
TaskRescheduleStartDate,
TaskState,
TaskStatesResult,
@@ -830,6 +831,14 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
ti = parse(msg, log)
log.debug("Dag file parsed", file=msg.dag_rel_path)

# Send execution timeout to supervisor if configured
timeout = ti.task.execution_timeout
if isinstance(timeout, timedelta):
timeout_seconds = timeout.total_seconds()
if timeout_seconds > 0:
SUPERVISOR_COMMS.send(TaskExecutionTimeout(timeout_seconds=timeout_seconds))
log.debug("Sent execution timeout to supervisor", timeout_seconds=timeout_seconds)

run_as_user = getattr(ti.task, "run_as_user", None) or conf.get(
"core", "default_impersonation", fallback=None
)
@@ -1520,23 +1529,9 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):

_run_task_state_change_callbacks(task, "on_execute_callback", context, log)

if task.execution_timeout:
from airflow.sdk.execution_time.timeout import timeout

# TODO: handle timeout in case of deferral
timeout_seconds = task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ctx.run(execute, context=context)
except AirflowTaskTimeout:
task.on_kill()
raise
else:
result = ctx.run(execute, context=context)
# Timeout is now handled at supervisor level, not in the task process
# The supervisor monitors execution time and sends SIGTERM/SIGKILL if exceeded
result = ctx.run(execute, context=context)

if (post_execute_hook := task._post_execute_hook) is not None:
create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result)
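Not part of the diff: a small sketch restating the guard added in startup() above, i.e. only a positive execution_timeout ever becomes a TaskExecutionTimeout message.

from datetime import timedelta

def timeout_to_send(execution_timeout):
    """Mirror of the startup() guard: only positive timedeltas are forwarded."""
    if isinstance(execution_timeout, timedelta):
        seconds = execution_timeout.total_seconds()
        if seconds > 0:
            return seconds
    return None

assert timeout_to_send(timedelta(minutes=5)) == 300.0
assert timeout_to_send(timedelta(0)) is None  # zero timeout: nothing is sent
assert timeout_to_send(None) is None          # no timeout configured: nothing is sent
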
7 changes: 7 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -117,6 +117,7 @@
SkipDownstreamTasks,
SucceedTask,
TaskBreadcrumbsResult,
TaskExecutionTimeout,
TaskRescheduleStartDate,
TaskState,
TaskStatesResult,
@@ -2461,6 +2462,12 @@ class RequestTestCase:
},
test_id="get_task_breadcrumbs",
),
RequestTestCase(
message=TaskExecutionTimeout(timeout_seconds=300.0),
test_id="task_execution_timeout",
client_mock=None,
expected_body=None,
),
]


43 changes: 0 additions & 43 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -67,7 +67,6 @@
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTerminated,
AirflowTaskTimeout,
DownstreamTasksSkipped,
ErrorType,
TaskDeferred,
@@ -656,47 +655,6 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe
mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant))


def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms):
"""Test running a basic task that times out."""
from time import sleep

task = PythonOperator(
task_id="sleep",
execution_timeout=timedelta(milliseconds=10),
python_callable=lambda: sleep(2),
)

ti = create_runtime_ti(task=task, dag_id="basic_dag_time_out")

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

run(ti, context=ti.get_template_context(), log=mock.MagicMock())

assert ti.state == TaskInstanceState.FAILED

# this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout
mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant))


def test_execution_timeout(create_runtime_ti):
def sleep_and_catch_other_exceptions():
with contextlib.suppress(Exception):
# Catching Exception should NOT catch AirflowTaskTimeout
time.sleep(5)

op = PythonOperator(
task_id="test_timeout",
execution_timeout=timedelta(seconds=1),
python_callable=sleep_and_catch_other_exceptions,
)

ti = create_runtime_ti(task=op, dag_id="dag_execution_timeout")

with pytest.raises(AirflowTaskTimeout):
_execute_task(context=ti.get_template_context(), ti=ti, log=mock.MagicMock())


def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency):
"""Test running a Dag with templated task."""
from airflow.providers.standard.operators.bash import BashOperator
@@ -3753,7 +3711,6 @@ def execute(self, context):

def test_task_runner_both_callbacks_have_timing_info(self, create_runtime_ti):
"""Test that both success and failure callbacks receive accurate timing information."""
import time

success_data = {}
failure_data = {}
Expand Down