diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 73dbc448fb2db..2a20e4eec4e0b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -925,6 +925,9 @@ class ActivitySubprocess(WatchedSubprocess):
     _last_successful_heartbeat: float = attrs.field(default=0, init=False)
     _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
 
+    _should_retry: bool = attrs.field(default=False, init=False)
+    """Whether the task should retry or not as decided by the API server."""
+
     # After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we
     # will kill theprocess. This is to handle temporary network issues etc. ensuring that the process
     # does not hang around forever.
@@ -964,6 +967,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st
             # message. But before we do that, we need to tell the server it's started (so it has the chance to
             # tell us "no, stop!" for any reason)
             ti_context = self.client.task_instances.start(ti.id, self.pid, start_date)
+            self._should_retry = ti_context.should_retry
             self._last_successful_heartbeat = time.monotonic()
         except Exception:
             # On any error kill that subprocess!
@@ -1162,6 +1166,13 @@ def final_state(self):
             return self._terminal_state or TaskInstanceState.SUCCESS
         if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED:
             return SERVER_TERMINATED
+
+        # Any negative exit code indicates a signal kill
+        # We consider all signal kills as potentially retryable
+        # since they're often transient issues that could succeed on retry
+        if self._exit_code is not None and self._exit_code < 0 and self._should_retry:
+            return TaskInstanceState.UP_FOR_RETRY
+
         return TaskInstanceState.FAILED
 
     def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 48f229543880a..710ad1d7963bc 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2640,6 +2640,76 @@ def mock_upload_to_remote(process_log, ti):
     assert connection_available["conn_uri"] is not None, "Connection URI was None during upload"
 
 
+class TestSignalRetryLogic:
+    """Test signal based retry logic in ActivitySubprocess."""
+
+    @pytest.mark.parametrize(
+        "signum",
+        [
+            signal.SIGTERM,
+            signal.SIGKILL,
+            signal.SIGABRT,
+            signal.SIGSEGV,
+        ],
+    )
+    def test_signals_with_retry(self, mocker, signum):
+        """Test that signal kills go to up_for_retry when task retries are enabled."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+
+        mock_watched_subprocess._exit_code = -signum
+        mock_watched_subprocess._should_retry = True
+
+        result = mock_watched_subprocess.final_state
+        assert result == TaskInstanceState.UP_FOR_RETRY
+
+    @pytest.mark.parametrize(
+        "signum",
+        [
+            signal.SIGKILL,
+            signal.SIGTERM,
+            signal.SIGABRT,
+            signal.SIGSEGV,
+        ],
+    )
+    def test_signals_without_retry_always_fail(self, mocker, signum):
+        """Test that signals without task retries enabled always fail."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+        mock_watched_subprocess._should_retry = False
+        mock_watched_subprocess._exit_code = -signum
+
+        result = mock_watched_subprocess.final_state
+        assert result == TaskInstanceState.FAILED
+
+    def test_non_signal_exit_code_goes_to_failed(self, mocker):
+        """Test that non signal exit codes go to failed regardless of task retries."""
+        mock_watched_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.Mock(),
+            process=mocker.Mock(),
+            client=mocker.Mock(),
+        )
+        mock_watched_subprocess._exit_code = 1
+        mock_watched_subprocess._should_retry = True
+
+        assert mock_watched_subprocess.final_state == TaskInstanceState.FAILED
+
+
 def test_remote_logging_conn_caches_connection_not_client(monkeypatch):
     """Test that connection caching doesn't retain API client references."""
     import gc