From 49e5e178b979a60d503c5b8035042e2f25e5de78 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 17 Sep 2025 16:38:31 +0530 Subject: [PATCH 1/4] Respect task retries for signal killed tasks --- .../airflow/sdk/execution_time/supervisor.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 797400c85083d..a4c9d5465c841 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -910,6 +910,9 @@ class ActivitySubprocess(WatchedSubprocess): _last_successful_heartbeat: float = attrs.field(default=0, init=False) _last_heartbeat_attempt: float = attrs.field(default=0, init=False) + _should_retry: bool = attrs.field(default=False, init=False) + """Whether the task should retry or not as decided by the API server.""" + # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we # will kill theprocess. This is to handle temporary network issues etc. ensuring that the process # does not hang around forever. @@ -949,6 +952,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st # message. But before we do that, we need to tell the server it's started (so it has the chance to # tell us "no, stop!" for any reason) ti_context = self.client.task_instances.start(ti.id, self.pid, start_date) + self._should_retry = ti_context.should_retry self._last_successful_heartbeat = time.monotonic() except Exception: # On any error kill that subprocess! @@ -1147,8 +1151,22 @@ def final_state(self): return self._terminal_state or TaskInstanceState.SUCCESS if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: return SERVER_TERMINATED + + if self._exit_code != 0 and self._is_signal_retryable() and self._should_retry: + return TaskInstanceState.UP_FOR_RETRY + return TaskInstanceState.FAILED + def _is_signal_retryable(self) -> bool: + """Check if the exit code signal can be retried.""" + if self._exit_code is None: + return False + + if self._exit_code == -signal.SIGKILL or self._exit_code == -signal.SIGTERM: + return True + + return False + def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): if isinstance(msg, MaskSecret): log.debug("Received message from task runner (body omitted)", msg=type(msg)) From 7eca5a6bc773485bdb6ee27969d3512cf00131e0 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 17 Sep 2025 16:55:39 +0530 Subject: [PATCH 2/4] adding unit tests --- .../execution_time/test_supervisor.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 94ad8d919e7e1..938dba307ce41 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2425,3 +2425,74 @@ def mock_upload_to_remote(process_log, ti): f"Connection {expected_env} was not available during upload_to_remote call" ) assert connection_available["conn_uri"] is not None, "Connection URI was None during upload" + + +class TestSignalRetryLogic: + """Test signal based retry logic in ActivitySubprocess.""" + + @pytest.mark.parametrize( + "signal, expected_state", + [ + (signal.SIGKILL, TaskInstanceState.UP_FOR_RETRY), + (signal.SIGTERM, TaskInstanceState.UP_FOR_RETRY), + # SIGABRT and SIGSEGV should not retry + (signal.SIGABRT, TaskInstanceState.FAILED), + (signal.SIGSEGV, TaskInstanceState.FAILED), + ], + ) + def test_signals_with_retry(self, mocker, signal, expected_state): + """Test that signals with task retries.""" + mock_watched_subprocess = ActivitySubprocess( + process_log=mocker.MagicMock(), + id=TI_ID, + pid=12345, + stdin=mocker.Mock(), + process=mocker.Mock(), + client=mocker.Mock(), + ) + + mock_watched_subprocess._exit_code = -signal + mock_watched_subprocess._should_retry = True + + result = mock_watched_subprocess.final_state + assert result == expected_state + + @pytest.mark.parametrize( + "signal", + [ + signal.SIGKILL, + signal.SIGTERM, + signal.SIGABRT, + signal.SIGSEGV, + ], + ) + def test_signals_without_retry_always_fail(self, mocker, signal): + """Test that signals without task retries enabled always fail.""" + mock_watched_subprocess = ActivitySubprocess( + process_log=mocker.MagicMock(), + id=TI_ID, + pid=12345, + stdin=mocker.Mock(), + process=mocker.Mock(), + client=mocker.Mock(), + ) + mock_watched_subprocess._should_retry = False + mock_watched_subprocess._exit_code = -signal + + result = mock_watched_subprocess.final_state + assert result == TaskInstanceState.FAILED + + def test_non_signal_exit_code_goes_to_failed(self, mocker): + """Test that non signal exit codes go to failed regardless of task retries.""" + mock_watched_subprocess = ActivitySubprocess( + process_log=mocker.MagicMock(), + id=TI_ID, + pid=12345, + stdin=mocker.Mock(), + process=mocker.Mock(), + client=mocker.Mock(), + ) + mock_watched_subprocess._exit_code = 1 + mock_watched_subprocess._should_retry = True + + assert mock_watched_subprocess.final_state == TaskInstanceState.FAILED From d8390a86fbb43e676359a5996e4dd8f695a5c363 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 24 Sep 2025 14:10:20 +0530 Subject: [PATCH 3/4] handling comments from ash --- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index a4c9d5465c841..06af5a1d838db 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1152,20 +1152,14 @@ def final_state(self): if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: return SERVER_TERMINATED - if self._exit_code != 0 and self._is_signal_retryable() and self._should_retry: + if self._is_signal_retryable() and self._should_retry: return TaskInstanceState.UP_FOR_RETRY return TaskInstanceState.FAILED def _is_signal_retryable(self) -> bool: """Check if the exit code signal can be retried.""" - if self._exit_code is None: - return False - - if self._exit_code == -signal.SIGKILL or self._exit_code == -signal.SIGTERM: - return True - - return False + return self._exit_code in (-signal.SIGKILL, -signal.SIGTERM, -signal.SIGSEGV) def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): if isinstance(msg, MaskSecret): From ac514f2e37157f8c1c505bab4d3c66c56c579457 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Mon, 29 Sep 2025 13:31:10 +0530 Subject: [PATCH 4/4] retry any signal --- .../src/airflow/sdk/execution_time/supervisor.py | 9 ++++----- .../task_sdk/execution_time/test_supervisor.py | 14 +++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 6d7525c2d5394..2df0a00d6dbb8 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1156,15 +1156,14 @@ def final_state(self): if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED: return SERVER_TERMINATED - if self._is_signal_retryable() and self._should_retry: + # Any negative exit code indicates a signal kill + # We consider all signal kills as potentially retryable + # since they're often transient issues that could succeed on retry + if self._exit_code < 0 and self._should_retry: return TaskInstanceState.UP_FOR_RETRY return TaskInstanceState.FAILED - def _is_signal_retryable(self) -> bool: - """Check if the exit code signal can be retried.""" - return self._exit_code in (-signal.SIGKILL, -signal.SIGTERM, -signal.SIGSEGV) - def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int): if isinstance(msg, MaskSecret): log.debug("Received message from task runner (body omitted)", msg=type(msg)) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index aba491f2f776c..b29e13edc3706 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2433,15 +2433,15 @@ class TestSignalRetryLogic: """Test signal based retry logic in ActivitySubprocess.""" @pytest.mark.parametrize( - "signal, expected_state", + "signal", [ - (signal.SIGKILL, TaskInstanceState.UP_FOR_RETRY), - (signal.SIGTERM, TaskInstanceState.UP_FOR_RETRY), - (signal.SIGSEGV, TaskInstanceState.UP_FOR_RETRY), - (signal.SIGABRT, TaskInstanceState.FAILED), + signal.SIGTERM, + signal.SIGKILL, + signal.SIGABRT, + signal.SIGSEGV, ], ) - def test_signals_with_retry(self, mocker, signal, expected_state): + def test_signals_with_retry(self, mocker, signal): """Test that signals with task retries.""" mock_watched_subprocess = ActivitySubprocess( process_log=mocker.MagicMock(), @@ -2456,7 +2456,7 @@ def test_signals_with_retry(self, mocker, signal, expected_state): mock_watched_subprocess._should_retry = True result = mock_watched_subprocess.final_state - assert result == expected_state + assert result == TaskInstanceState.UP_FOR_RETRY @pytest.mark.parametrize( "signal",