diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 15755e640d97e..3cececdb14df9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -877,6 +877,13 @@ class SetRenderedMapIndex(BaseModel): type: Literal["SetRenderedMapIndex"] = "SetRenderedMapIndex" +class TaskExecutionTimeout(BaseModel): + """Payload for communicating task execution timeout to supervisor.""" + + timeout_seconds: float + type: Literal["TaskExecutionTimeout"] = "TaskExecutionTimeout" + + class TriggerDagRun(TriggerDAGRunPayload): dag_id: str run_id: Annotated[str, Field(title="Dag Run Id")] @@ -1046,6 +1053,7 @@ class MaskSecret(BaseModel): | RetryTask | SetRenderedFields | SetRenderedMapIndex + | TaskExecutionTimeout | SetXCom | SkipDownstreamTasks | SucceedTask diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index b87131aa7336d..bc356f7fd4d64 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -108,6 +108,7 @@ StartupDetails, SucceedTask, TaskBreadcrumbsResult, + TaskExecutionTimeout, TaskState, TaskStatesResult, ToSupervisor, @@ -961,6 +962,16 @@ class ActivitySubprocess(WatchedSubprocess): _task_end_time_monotonic: float | None = attrs.field(default=None, init=False) _rendered_map_index: str | None = attrs.field(default=None, init=False) + # Execution timeout tracking + _execution_timeout_seconds: float | None = attrs.field(default=None, init=False) + """Task execution timeout in seconds, received from the task process.""" + + _task_execution_start_time: float | None = attrs.field(default=None, init=False) + """Monotonic time when task execution actually started (after parsing DAG).""" + + _timeout_sigterm_sent_at: float | None = attrs.field(default=None, init=False) + """Monotonic time when SIGTERM was sent due 
to timeout.""" + decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor) ti: RuntimeTI | None = None @@ -1122,6 +1133,42 @@ def _monitor_subprocess(self): self._send_heartbeat_if_needed() self._handle_process_overtime_if_needed() + self._handle_execution_timeout_if_needed() + + def _handle_execution_timeout_if_needed(self): + """Handle task execution timeout by sending SIGTERM, then SIGKILL if needed.""" + # Only check timeout if we have received timeout configuration and task is still running + if self._execution_timeout_seconds is None or self._terminal_state: + return + + if self._task_execution_start_time is None: + return + + elapsed_time = time.monotonic() - self._task_execution_start_time + + # Check if we've exceeded the execution timeout + if elapsed_time > self._execution_timeout_seconds: + # Grace period for SIGKILL after SIGTERM (5 seconds) + SIGKILL_GRACE_PERIOD = 5.0 + + if self._timeout_sigterm_sent_at is None: + # First timeout - send SIGTERM + log.warning( + "Task execution timeout exceeded; sending SIGTERM", + ti_id=self.id, + timeout_seconds=self._execution_timeout_seconds, + elapsed_seconds=elapsed_time, + ) + self.kill(signal.SIGTERM) + self._timeout_sigterm_sent_at = time.monotonic() + elif (time.monotonic() - self._timeout_sigterm_sent_at) > SIGKILL_GRACE_PERIOD: + # SIGTERM didn't work, escalate to SIGKILL + log.error( + "Task did not respond to SIGTERM; sending SIGKILL", + ti_id=self.id, + grace_period_seconds=SIGKILL_GRACE_PERIOD, + ) + self.kill(signal.SIGKILL, force=True) def _handle_process_overtime_if_needed(self): """Handle termination of auxiliary processes if the task exceeds the configured overtime.""" @@ -1234,6 +1281,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() self._rendered_map_index = msg.rendered_map_index + elif isinstance(msg, TaskExecutionTimeout): + # Task has sent us its execution 
timeout configuration + self._execution_timeout_seconds = msg.timeout_seconds + self._task_execution_start_time = time.monotonic() + log.info( + "Received task execution timeout from task process", + timeout_seconds=msg.timeout_seconds, + ) elif isinstance(msg, SucceedTask): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 95665c43f3b7a..d770ac4921cf1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -27,7 +27,7 @@ import time from collections.abc import Callable, Iterable, Iterator, Mapping from contextlib import suppress -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from itertools import product from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal @@ -95,6 +95,7 @@ StartupDetails, SucceedTask, TaskBreadcrumbsResult, + TaskExecutionTimeout, TaskRescheduleStartDate, TaskState, TaskStatesResult, @@ -830,6 +831,14 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: ti = parse(msg, log) log.debug("Dag file parsed", file=msg.dag_rel_path) + # Send execution timeout to supervisor if configured + timeout = ti.task.execution_timeout + if isinstance(timeout, timedelta): + timeout_seconds = timeout.total_seconds() + if timeout_seconds > 0: + SUPERVISOR_COMMS.send(TaskExecutionTimeout(timeout_seconds=timeout_seconds)) + log.debug("Sent execution timeout to supervisor", timeout_seconds=timeout_seconds) + run_as_user = getattr(ti.task, "run_as_user", None) or conf.get( "core", "default_impersonation", fallback=None ) @@ -1520,23 +1529,9 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): _run_task_state_change_callbacks(task, "on_execute_callback", context, log) - if task.execution_timeout: - from 
airflow.sdk.execution_time.timeout import timeout - - # TODO: handle timeout in case of deferral - timeout_seconds = task.execution_timeout.total_seconds() - try: - # It's possible we're already timed out, so fast-fail if true - if timeout_seconds <= 0: - raise AirflowTaskTimeout() - # Run task in timeout wrapper - with timeout(timeout_seconds): - result = ctx.run(execute, context=context) - except AirflowTaskTimeout: - task.on_kill() - raise - else: - result = ctx.run(execute, context=context) + # Timeout is now handled at supervisor level, not in the task process + # The supervisor monitors execution time and sends SIGTERM/SIGKILL if exceeded + result = ctx.run(execute, context=context) if (post_execute_hook := task._post_execute_hook) is not None: create_executable_runner(post_execute_hook, outlet_events, logger=log).run(context, result) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b45f57cd4256e..560442df07441 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -117,6 +117,7 @@ SkipDownstreamTasks, SucceedTask, TaskBreadcrumbsResult, + TaskExecutionTimeout, TaskRescheduleStartDate, TaskState, TaskStatesResult, @@ -2461,6 +2462,12 @@ class RequestTestCase: }, test_id="get_task_breadcrumbs", ), + RequestTestCase( + message=TaskExecutionTimeout(timeout_seconds=300.0), + test_id="task_execution_timeout", + client_mock=None, + expected_body=None, + ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index f743abdd9a61b..e32f78c49bd6e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -67,7 +67,6 @@ AirflowSensorTimeout, AirflowSkipException, AirflowTaskTerminated, - AirflowTaskTimeout, DownstreamTasksSkipped, ErrorType, 
TaskDeferred, @@ -656,47 +655,6 @@ def test_run_raises_airflow_exception(time_machine, create_runtime_ti, mock_supe mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) -def test_run_task_timeout(time_machine, create_runtime_ti, mock_supervisor_comms): - """Test running a basic task that times out.""" - from time import sleep - - task = PythonOperator( - task_id="sleep", - execution_timeout=timedelta(milliseconds=10), - python_callable=lambda: sleep(2), - ) - - ti = create_runtime_ti(task=task, dag_id="basic_dag_time_out") - - instant = timezone.datetime(2024, 12, 3, 10, 0) - time_machine.move_to(instant, tick=False) - - run(ti, context=ti.get_template_context(), log=mock.MagicMock()) - - assert ti.state == TaskInstanceState.FAILED - - # this state can only be reached if the try block passed down the exception to handler of AirflowTaskTimeout - mock_supervisor_comms.send.assert_called_with(TaskState(state=TaskInstanceState.FAILED, end_date=instant)) - - -def test_execution_timeout(create_runtime_ti): - def sleep_and_catch_other_exceptions(): - with contextlib.suppress(Exception): - # Catching Exception should NOT catch AirflowTaskTimeout - time.sleep(5) - - op = PythonOperator( - task_id="test_timeout", - execution_timeout=timedelta(seconds=1), - python_callable=sleep_and_catch_other_exceptions, - ) - - ti = create_runtime_ti(task=op, dag_id="dag_execution_timeout") - - with pytest.raises(AirflowTaskTimeout): - _execute_task(context=ti.get_template_context(), ti=ti, log=mock.MagicMock()) - - def test_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency): """Test running a Dag with templated task.""" from airflow.providers.standard.operators.bash import BashOperator @@ -3753,7 +3711,6 @@ def execute(self, context): def test_task_runner_both_callbacks_have_timing_info(self, create_runtime_ti): """Test that both success and failure callbacks receive accurate timing information.""" - 
import time success_data = {} failure_data = {} diff --git a/task-sdk/tests/task_sdk/execution_time/test_timeout.py b/task-sdk/tests/task_sdk/execution_time/test_timeout.py new file mode 100644 index 0000000000000..d99f93fdb29ac --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_timeout.py @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import signal +import time +from unittest.mock import MagicMock, patch + +import pytest +from task_sdk import FAKE_BUNDLE + +from airflow.sdk.execution_time.comms import TaskExecutionTimeout +from airflow.sdk.execution_time.supervisor import ActivitySubprocess + + +@pytest.fixture +def client_with_ti_start(make_ti_context): + """Create a mock client with task instance start configured.""" + from unittest.mock import MagicMock + + from airflow.sdk.api import client as sdk_client + + client = MagicMock(spec=sdk_client.Client) + client.task_instances.start.return_value = make_ti_context() + return client + + +class TestExecutionTimeout: + """Test cases for execution timeout handling in supervisor.""" + + @pytest.fixture + def mock_subprocess(self): + """Create a mock ActivitySubprocess for testing.""" + subprocess = MagicMock(spec=ActivitySubprocess) + subprocess._execution_timeout_seconds = None + subprocess._task_execution_start_time = None + subprocess._timeout_sigterm_sent_at = None + subprocess._terminal_state = None + subprocess.kill = MagicMock() + return subprocess + + def test_handle_timeout_message(self, mock_subprocess, monkeypatch): + """Test that supervisor receives and stores timeout configuration.""" + # Replace the real timeout check with a no-op for the duration of this test + monkeypatch.setattr(ActivitySubprocess, "_handle_execution_timeout_if_needed", lambda self: None) + + # Create timeout message + timeout_msg = TaskExecutionTimeout(timeout_seconds=10.0) + + # Simulate receiving the message + with patch("time.monotonic", return_value=100.0): + mock_subprocess._execution_timeout_seconds = timeout_msg.timeout_seconds + mock_subprocess._task_execution_start_time = time.monotonic() + + assert mock_subprocess._execution_timeout_seconds == 10.0 + assert mock_subprocess._task_execution_start_time == 100.0 + + def test_timeout_not_checked_without_config(self, mock_subprocess): + """Test that timeout is not checked if no timeout was configured.""" + #
Call the real method + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + # Should not call kill if no timeout configured + mock_subprocess.kill.assert_not_called() + + def test_timeout_not_checked_for_terminal_state(self, mock_subprocess): + """Test that timeout is not checked if task reached terminal state.""" + mock_subprocess._execution_timeout_seconds = 10.0 + mock_subprocess._task_execution_start_time = 100.0 + mock_subprocess._terminal_state = "success" + + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + mock_subprocess.kill.assert_not_called() + + def test_timeout_sends_sigterm(self, mock_subprocess): + """Test that SIGTERM is sent when timeout is exceeded.""" + mock_subprocess._execution_timeout_seconds = 5.0 + mock_subprocess._task_execution_start_time = 100.0 + mock_subprocess._timeout_sigterm_sent_at = None + + with patch("time.monotonic", return_value=106.0): # 6 seconds elapsed + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + # Should send SIGTERM + mock_subprocess.kill.assert_called_once_with(signal.SIGTERM) + assert mock_subprocess._timeout_sigterm_sent_at == 106.0 + + def test_timeout_escalates_to_sigkill(self, mock_subprocess): + """Test that SIGKILL is sent if SIGTERM doesn't work.""" + mock_subprocess._execution_timeout_seconds = 5.0 + mock_subprocess._task_execution_start_time = 100.0 + mock_subprocess._timeout_sigterm_sent_at = 106.0 + + # 12 seconds after SIGTERM (beyond 5 second grace period) + with patch("time.monotonic", return_value=118.0): + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + # Should send SIGKILL with force=True + mock_subprocess.kill.assert_called_once_with(signal.SIGKILL, force=True) + + def test_timeout_within_grace_period(self, mock_subprocess): + """Test that SIGKILL is not sent during grace period.""" + mock_subprocess._execution_timeout_seconds = 5.0 + mock_subprocess._task_execution_start_time = 100.0 + 
mock_subprocess._timeout_sigterm_sent_at = 106.0 + + # 3 seconds after SIGTERM (within 5 second grace period) + with patch("time.monotonic", return_value=109.0): + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + # Should not send SIGKILL yet + mock_subprocess.kill.assert_not_called() + + def test_timeout_message_serialization(self): + """Test that TaskExecutionTimeout message serializes correctly.""" + timeout_msg = TaskExecutionTimeout(timeout_seconds=15.5) + + # Serialize + json_str = timeout_msg.model_dump_json() + assert "15.5" in json_str + assert "TaskExecutionTimeout" in json_str + + # Deserialize + import json + + data = json.loads(json_str) + assert data["timeout_seconds"] == 15.5 + assert data["type"] == "TaskExecutionTimeout" + + def test_zero_timeout_immediately_triggers(self, mock_subprocess): + """Test that a near-zero (1 ms) timeout triggers as soon as it has elapsed.""" + mock_subprocess._execution_timeout_seconds = 0.001 + mock_subprocess._task_execution_start_time = 100.0 + mock_subprocess._timeout_sigterm_sent_at = None + + with patch("time.monotonic", return_value=100.002): # Slightly past timeout + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + + # Should send SIGTERM as soon as the near-zero timeout has elapsed + mock_subprocess.kill.assert_called_once_with(signal.SIGTERM) + + def test_timeout_uses_monotonic_clock(self, mock_subprocess): + """Test that timeout uses monotonic clock for accuracy.""" + mock_subprocess._execution_timeout_seconds = 10.0 + + # Use different monotonic values + with patch("time.monotonic", return_value=1000.0): + mock_subprocess._task_execution_start_time = time.monotonic() + + # Verify monotonic clock is used + assert mock_subprocess._task_execution_start_time == 1000.0 + + # Simulate timeout + mock_subprocess._timeout_sigterm_sent_at = None + with patch("time.monotonic", return_value=1011.0): # 11 seconds later + ActivitySubprocess._handle_execution_timeout_if_needed(mock_subprocess) + +
mock_subprocess.kill.assert_called_once_with(signal.SIGTERM) + + +class TestExecutionTimeoutIntegration: + """Integration tests for execution timeout with real subprocesses.""" + + @pytest.fixture(autouse=True) + def disable_log_upload(self, spy_agency): + """Disable log upload for tests.""" + spy_agency.spy_on(ActivitySubprocess._upload_logs, call_original=False) + + def test_timeout_kills_long_running_task(self, monkeypatch, client_with_ti_start, captured_logs): + """Test that a task exceeding timeout is killed with SIGTERM.""" + import os + import sys + import time + + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import CommsDecoder, TaskExecutionTimeout + from airflow.sdk.execution_time.supervisor import ActivitySubprocess + + def subprocess_main(): + # Get startup message + comms = CommsDecoder() + comms._get_response() + + # Send timeout configuration (2 seconds) + comms.send(TaskExecutionTimeout(timeout_seconds=2.0)) + + # Sleep for 10 seconds (should be killed after 2) + print("Task started, sleeping for 10 seconds...") + sys.stdout.flush() + time.sleep(10) + print("This should never print - task should be killed") + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=uuid7(), + task_id="timeout_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ), + client=client_with_ti_start, + target=subprocess_main, + ) + + start_time = time.time() + rc = proc.wait() + elapsed = time.time() - start_time + + # Task should be killed, so non-zero exit code + assert rc != 0, "Task should have been killed" + + # Should complete in roughly 2-7 seconds due to heartbeat interval + # (2s timeout + up to 5s heartbeat check delay) + assert elapsed < 8, f"Task took {elapsed}s, expected ~6s (timeout + heartbeat delay) not full 10s" + + # Check logs for timeout message + log_events =
[log.get("event", "") for log in captured_logs] + assert any("timeout" in str(event).lower() for event in log_events), ( + "Expected timeout message in logs" + ) + + def test_timeout_escalates_to_sigkill(self, monkeypatch, client_with_ti_start): + """Test that supervisor sends SIGKILL if SIGTERM doesn't work.""" + import os + import signal + import time + + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import CommsDecoder, TaskExecutionTimeout + from airflow.sdk.execution_time.supervisor import ActivitySubprocess + + def subprocess_main(): + # Get startup message + comms = CommsDecoder() + comms._get_response() + + # Send timeout configuration (1 second) + comms.send(TaskExecutionTimeout(timeout_seconds=1.0)) + + # Ignore SIGTERM to test SIGKILL escalation + signal.signal(signal.SIGTERM, signal.SIG_IGN) + + print("Task started, ignoring SIGTERM...") + time.sleep(20) # Sleep long enough for SIGKILL + print("This should never print") + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=uuid7(), + task_id="sigkill_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ), + client=client_with_ti_start, + target=subprocess_main, + ) + + start_time = time.time() + rc = proc.wait() + elapsed = time.time() - start_time + + # Should be killed with SIGKILL + assert rc != 0 + + # Should complete in ~15-17 seconds + # (1s timeout + 5s heartbeat delay + 5s SIGTERM grace + 5s SIGKILL check delay) + assert elapsed < 18, f"Task took {elapsed}s, expected ~15s for SIGKILL escalation" + + def test_no_timeout_task_completes_normally(self, monkeypatch, client_with_ti_start): + """Test that tasks without timeout complete normally.""" + import os + import time + + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import 
CommsDecoder + from airflow.sdk.execution_time.supervisor import ActivitySubprocess + + def subprocess_main(): + # Get startup message + comms = CommsDecoder() + comms._get_response() + + # Don't send timeout message + print("Task without timeout") + time.sleep(0.5) + print("Task completed successfully") + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=uuid7(), + task_id="no_timeout_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ), + client=client_with_ti_start, + target=subprocess_main, + ) + + rc = proc.wait() + + # Should complete successfully + assert rc == 0, "Task without timeout should complete successfully" + + def test_task_completes_before_timeout(self, monkeypatch, client_with_ti_start): + """Test that tasks completing before timeout are not killed.""" + import os + import time + + from uuid6 import uuid7 + + from airflow.sdk.api.datamodels._generated import TaskInstance + from airflow.sdk.execution_time.comms import CommsDecoder, TaskExecutionTimeout + from airflow.sdk.execution_time.supervisor import ActivitySubprocess + + def subprocess_main(): + # Get startup message + comms = CommsDecoder() + comms._get_response() + + # Send timeout configuration (5 seconds) + comms.send(TaskExecutionTimeout(timeout_seconds=5.0)) + + # Complete quickly (1 second) + print("Task started") + time.sleep(1) + print("Task completed within timeout") + + proc = ActivitySubprocess.start( + dag_rel_path=os.devnull, + bundle_info=FAKE_BUNDLE, + what=TaskInstance( + id=uuid7(), + task_id="fast_task", + dag_id="test_dag", + run_id="test_run", + try_number=1, + dag_version_id=uuid7(), + ), + client=client_with_ti_start, + target=subprocess_main, + ) + + rc = proc.wait() + + # Should complete successfully + assert rc == 0, "Task completing before timeout should succeed"