Skip to content

Commit

Permalink
Improve engine shutdown handling of SIGTERM (#8127)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Pickett <chris.pickett@prefect.io>
  • Loading branch information
zanieb and bunchesofdonald committed Jan 25, 2023
1 parent ecaea25 commit 2805642
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 2 deletions.
26 changes: 25 additions & 1 deletion src/prefect/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
See `orchestrate_flow_run`, `orchestrate_task_run`
"""
import logging
import os
import signal
import sys
from contextlib import AsyncExitStack, asynccontextmanager, nullcontext
from functools import partial
Expand Down Expand Up @@ -47,6 +49,7 @@
NotPausedError,
Pause,
PausedRun,
TerminationSignal,
UpstreamTaskError,
)
from prefect.flows import Flow
Expand Down Expand Up @@ -1606,6 +1609,17 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient):
This context _must_ reraise the exception to properly exit the run.
"""

def cancel_flow_run(*args):
raise TerminationSignal(signal=signal.SIGTERM)

original_term_handler = None
try:
original_term_handler = signal.signal(signal.SIGTERM, cancel_flow_run)
except ValueError:
# Signals only work in the main thread
pass

try:
yield
except (Abort, Pause):
Expand All @@ -1620,14 +1634,24 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient):
await client.set_flow_run_state(
state=state,
flow_run_id=flow_run.id,
force=True,
)
engine_logger.debug(
f"Reported crashed flow run {flow_run.name!r} successfully!"
)

if isinstance(exc, TerminationSignal):
# Termination signals are swapped out during a flow run to perform
# a graceful shutdown and raise this exception. This `os.kill` call
# ensures that the previous handler, likely the Python default,
# gets called as well.
signal.signal(exc.signal, original_term_handler)
os.kill(os.getpid(), exc.signal)

# Reraise the exception
raise exc from None
finally:
if original_term_handler is not None:
signal.signal(signal.SIGTERM, original_term_handler)


@asynccontextmanager
Expand Down
15 changes: 15 additions & 0 deletions src/prefect/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,21 @@ class Pause(PrefectSignal):
"""


class ExternalSignal(BaseException):
    """
    Root type for exceptions that stand in for signals sent from outside the run.

    It derives from `BaseException` rather than `Exception`, so a plain
    `except Exception` handler does not intercept it; user code should never
    catch these.
    """


class TerminationSignal(ExternalSignal):
    """
    Raised when a flow run receives a termination signal.

    The number of the received signal is available as the `signal` attribute.
    """

    def __init__(self, signal: int):
        # Hand the signal number to `BaseException` so it lands in `args` even
        # when `signal` is passed by keyword. Without this the exception has
        # empty `args`, renders as a blank string, and cannot be rebuilt by
        # `pickle`/`copy`, which re-call the constructor with `args`.
        super().__init__(signal)
        self.signal = signal


class PrefectHTTPStatusError(HTTPStatusError):
"""
Raised when client receives a `Response` that contains an HTTPStatusError.
Expand Down
25 changes: 25 additions & 0 deletions src/prefect/orion/orchestration/core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class CoreFlowPolicy(BaseOrchestrationPolicy):
def priority():
return [
HandleFlowTerminalStateTransitions,
HandleCancellingStateTransitions,
PreventRedundantTransitions,
HandlePausingFlows,
HandleResumingPausedFlows,
Expand Down Expand Up @@ -786,3 +787,27 @@ async def before_transition(
await self.abort_transition(
reason=f"The enclosing flow must be running to begin task execution.",
)


class HandleCancellingStateTransitions(BaseOrchestrationRule):
    """
    Rejects attempts to move a cancelling run to anything other than Cancelled.

    `FROM_STATES` lists both `CANCELLING` and `CANCELLED`, and `TO_STATES` is
    every orchestration state except `CANCELLED`; whenever the rule fires it
    rejects the proposed transition.
    """

    FROM_STATES = {StateType.CANCELLING, StateType.CANCELLED}
    TO_STATES = ALL_ORCHESTRATION_STATES - {StateType.CANCELLED}

    async def before_transition(
        self,
        initial_state: Optional[states.State],
        proposed_state: Optional[states.State],
        # NOTE(review): annotated with the task context type although the rule
        # is registered in `CoreFlowPolicy` -- confirm the intended type.
        context: TaskOrchestrationContext,
    ) -> None:
        rejection_reason = (
            "Cannot transition flows that are cancelling to a state other than "
            "Cancelled."
        )
        # No replacement state is offered with the rejection.
        await self.reject_transition(state=None, reason=rejection_reason)
57 changes: 57 additions & 0 deletions tests/orion/orchestration/test_core_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CacheInsertion,
CacheRetrieval,
CopyScheduledTime,
HandleCancellingStateTransitions,
HandleFlowTerminalStateTransitions,
HandlePausingFlows,
HandleResumingPausedFlows,
Expand Down Expand Up @@ -2541,3 +2542,59 @@ async def test_prevents_tasks_From_running(
else:
assert ctx.response_status == SetStateStatus.ABORT
assert ctx.validated_state.is_pending()


class TestHandleCancellingStateTransitions:
    """
    Checks that `HandleCancellingStateTransitions` rejects every proposed state
    other than Cancelled, for each of the rule's two source states.
    """

    @pytest.mark.parametrize(
        "proposed_state_type",
        sorted(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}),
    )
    async def test_rejects_cancelling_to_anything_but_cancelled(
        self,
        session,
        initialize_orchestration,
        proposed_state_type,
    ):
        initial_state_type = states.StateType.CANCELLING
        intended_transition = (initial_state_type, proposed_state_type)

        ctx = await initialize_orchestration(
            session,
            "flow",
            *intended_transition,
        )

        async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx:
            await ctx.validate_proposed_state()

        assert ctx.response_status == SetStateStatus.REJECT
        # A rejection without a replacement state leaves the run where it was
        assert ctx.validated_state_type == states.StateType.CANCELLING

    @pytest.mark.parametrize(
        "proposed_state_type",
        sorted(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}),
    )
    async def test_rejects_cancelled_cancelling_to_anything_but_cancelled(
        self,
        session,
        initialize_orchestration,
        proposed_state_type,
    ):
        # Start from CANCELLED, the rule's other source state; starting from
        # CANCELLING here would only repeat the test above.
        initial_state_type = states.StateType.CANCELLED
        intended_transition = (initial_state_type, proposed_state_type)

        ctx = await initialize_orchestration(
            session,
            "flow",
            *intended_transition,
        )

        async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx:
            await ctx.validate_proposed_state()

        assert ctx.response_status == SetStateStatus.REJECT
        # NOTE(review): expects the rejected run to keep its initial state, as
        # in the CANCELLING case above -- confirm against the fixture.
        assert ctx.validated_state_type == states.StateType.CANCELLED
20 changes: 19 additions & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import os
import signal
import statistics
import time
from contextlib import contextmanager
from functools import partial
from typing import List
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock
from uuid import uuid4

import anyio
Expand All @@ -23,6 +25,7 @@
orchestrate_flow_run,
orchestrate_task_run,
pause_flow_run,
report_flow_run_crashes,
resume_flow_run,
retrieve_flow_then_begin_flow_run,
)
Expand All @@ -34,6 +37,7 @@
Pause,
PausedRun,
SignatureMismatchError,
TerminationSignal,
)
from prefect.futures import PrefectFuture
from prefect.orion.schemas.actions import FlowRunCreate
Expand Down Expand Up @@ -1428,6 +1432,20 @@ def my_flow():

assert i <= 10, "`just_sleep` should not be running after timeout"

async def test_report_flow_run_crashes_handles_sigterm(
    self, flow_run, orion_client, monkeypatch
):
    """
    `report_flow_run_crashes` should swap in its own SIGTERM handler while
    active, raise `TerminationSignal` when SIGTERM arrives, invoke the handler
    that was installed beforehand, and put that handler back on exit.
    """
    original_handler = Mock()
    # Signal handlers are process-wide: remember what was installed before so
    # the mock does not stay registered for the tests that run afterwards.
    previous_handler = signal.signal(signal.SIGTERM, original_handler)

    try:
        with pytest.raises(TerminationSignal):
            async with report_flow_run_crashes(
                flow_run=flow_run, client=orion_client
            ):
                assert signal.getsignal(signal.SIGTERM) != original_handler
                os.kill(os.getpid(), signal.SIGTERM)

        original_handler.assert_called_once()
        assert signal.getsignal(signal.SIGTERM) == original_handler
    finally:
        # `signal.signal` returns None when the previous handler was not
        # installed from Python; it cannot be re-registered, so fall back to
        # the default disposition in that case.
        signal.signal(
            signal.SIGTERM,
            signal.SIG_DFL if previous_handler is None else previous_handler,
        )


class TestTaskRunCrashes:
@pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit])
Expand Down

0 comments on commit 2805642

Please sign in to comment.