Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,9 @@ class ActivitySubprocess(WatchedSubprocess):
_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)

_should_retry: bool = attrs.field(default=False, init=False)
"""Whether the task should retry or not as decided by the API server."""

# After the failure of a heartbeat, we'll increment this counter. If it reaches `MAX_FAILED_HEARTBEATS`, we
# will kill theprocess. This is to handle temporary network issues etc. ensuring that the process
# does not hang around forever.
Expand Down Expand Up @@ -964,6 +967,7 @@ def _on_child_started(self, ti: TaskInstance, dag_rel_path: str | os.PathLike[st
# message. But before we do that, we need to tell the server it's started (so it has the chance to
# tell us "no, stop!" for any reason)
ti_context = self.client.task_instances.start(ti.id, self.pid, start_date)
self._should_retry = ti_context.should_retry
self._last_successful_heartbeat = time.monotonic()
except Exception:
# On any error kill that subprocess!
Expand Down Expand Up @@ -1162,6 +1166,13 @@ def final_state(self):
return self._terminal_state or TaskInstanceState.SUCCESS
if self._exit_code != 0 and self._terminal_state == SERVER_TERMINATED:
return SERVER_TERMINATED

# Any negative exit code indicates a signal kill
# We consider all signal kills as potentially retryable
# since they're often transient issues that could succeed on retry
if self._exit_code < 0 and self._should_retry:
return TaskInstanceState.UP_FOR_RETRY

return TaskInstanceState.FAILED

def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: int):
Expand Down
70 changes: 70 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2640,6 +2640,76 @@ def mock_upload_to_remote(process_log, ti):
assert connection_available["conn_uri"] is not None, "Connection URI was None during upload"


class TestSignalRetryLogic:
"""Test signal based retry logic in ActivitySubprocess."""

@pytest.mark.parametrize(
"signal",
[
signal.SIGTERM,
signal.SIGKILL,
signal.SIGABRT,
signal.SIGSEGV,
],
)
def test_signals_with_retry(self, mocker, signal):
"""Test that signals with task retries."""
mock_watched_subprocess = ActivitySubprocess(
process_log=mocker.MagicMock(),
id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
client=mocker.Mock(),
)

mock_watched_subprocess._exit_code = -signal
mock_watched_subprocess._should_retry = True

result = mock_watched_subprocess.final_state
assert result == TaskInstanceState.UP_FOR_RETRY

@pytest.mark.parametrize(
"signal",
[
signal.SIGKILL,
signal.SIGTERM,
signal.SIGABRT,
signal.SIGSEGV,
],
)
def test_signals_without_retry_always_fail(self, mocker, signal):
"""Test that signals without task retries enabled always fail."""
mock_watched_subprocess = ActivitySubprocess(
process_log=mocker.MagicMock(),
id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
client=mocker.Mock(),
)
mock_watched_subprocess._should_retry = False
mock_watched_subprocess._exit_code = -signal

result = mock_watched_subprocess.final_state
assert result == TaskInstanceState.FAILED

def test_non_signal_exit_code_goes_to_failed(self, mocker):
"""Test that non signal exit codes go to failed regardless of task retries."""
mock_watched_subprocess = ActivitySubprocess(
process_log=mocker.MagicMock(),
id=TI_ID,
pid=12345,
stdin=mocker.Mock(),
process=mocker.Mock(),
client=mocker.Mock(),
)
mock_watched_subprocess._exit_code = 1
mock_watched_subprocess._should_retry = True

assert mock_watched_subprocess.final_state == TaskInstanceState.FAILED


def test_remote_logging_conn_caches_connection_not_client(monkeypatch):
"""Test that connection caching doesn't retain API client references."""
import gc
Expand Down