diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 79edb51f6433..328af782df31 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -16,6 +16,8 @@ See `orchestrate_flow_run`, `orchestrate_task_run` """ import logging +import os +import signal import sys from contextlib import AsyncExitStack, asynccontextmanager, nullcontext from functools import partial @@ -47,6 +49,7 @@ NotPausedError, Pause, PausedRun, + TerminationSignal, UpstreamTaskError, ) from prefect.flows import Flow @@ -1648,6 +1651,17 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): This context _must_ reraise the exception to properly exit the run. """ + + def cancel_flow_run(*args): + raise TerminationSignal(signal=signal.SIGTERM) + + original_term_handler = None + try: + original_term_handler = signal.signal(signal.SIGTERM, cancel_flow_run) + except ValueError: + # Signals only work in the main thread + pass + try: yield except (Abort, Pause): @@ -1662,14 +1676,24 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): await client.set_flow_run_state( state=state, flow_run_id=flow_run.id, - force=True, ) engine_logger.debug( f"Reported crashed flow run {flow_run.name!r} successfully!" ) + if isinstance(exc, TerminationSignal): + # Termination signals are swapped out during a flow run to perform + # a graceful shutdown and raise this exception. This `os.kill` call + # ensures that the previous handler, likely the Python default, + # gets called as well. + signal.signal(exc.signal, original_term_handler) + os.kill(os.getpid(), exc.signal) + # Reraise the exception raise exc from None + finally: + if original_term_handler is not None: + signal.signal(signal.SIGTERM, original_term_handler) @asynccontextmanager diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index afc6de7a56d0..74f359693847 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -276,6 +276,21 @@ class Pause(PrefectSignal): """ +class ExternalSignal(BaseException): + """ + Base type for external signal-like exceptions that should never be caught by users. + """ + + +class TerminationSignal(ExternalSignal): + """ + Raised when a flow run receives a termination signal. + """ + + def __init__(self, signal: int): + self.signal = signal + + class PrefectHTTPStatusError(HTTPStatusError): """ Raised when client receives a `Response` that contains an HTTPStatusError. diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index 2fce16f0b32e..31410a85446f 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -41,6 +41,7 @@ class CoreFlowPolicy(BaseOrchestrationPolicy): def priority(): return [ HandleFlowTerminalStateTransitions, + HandleCancellingStateTransitions, PreventRedundantTransitions, HandlePausingFlows, HandleResumingPausedFlows, @@ -786,3 +787,27 @@ async def before_transition( await self.abort_transition( reason=f"The enclosing flow must be running to begin task execution.", ) + + +class HandleCancellingStateTransitions(BaseOrchestrationRule): + """ + Rejects transitions from Cancelling to any terminal state except for Cancelled. + """ + + FROM_STATES = {StateType.CANCELLED, StateType.CANCELLING} + TO_STATES = ALL_ORCHESTRATION_STATES - {StateType.CANCELLED} + + async def before_transition( + self, + initial_state: Optional[states.State], + proposed_state: Optional[states.State], + context: TaskOrchestrationContext, + ) -> None: + await self.reject_transition( + state=None, + reason=( + "Cannot transition flows that are cancelling to a state other " + "than Cancelled." + ), + ) + return diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index 33db69cb1db6..f2f4fe231511 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -15,6 +15,7 @@ CacheInsertion, CacheRetrieval, CopyScheduledTime, + HandleCancellingStateTransitions, HandleFlowTerminalStateTransitions, HandlePausingFlows, HandleResumingPausedFlows, @@ -2541,3 +2542,59 @@ async def test_prevents_tasks_From_running( else: assert ctx.response_status == SetStateStatus.ABORT assert ctx.validated_state.is_pending() + + +class TestHandleCancellingStateTransitions: + @pytest.mark.parametrize( + "proposed_state_type", + sorted( + list(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}) + ), + ) + async def test_rejects_cancelling_to_anything_but_cancelled( + self, + session, + initialize_orchestration, + proposed_state_type, + ): + initial_state_type = states.StateType.CANCELLING + intended_transition = (initial_state_type, proposed_state_type) + + ctx = await initialize_orchestration( + session, + "flow", + *intended_transition, + ) + + async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx: + await ctx.validate_proposed_state() + + assert ctx.response_status == SetStateStatus.REJECT + assert ctx.validated_state_type == states.StateType.CANCELLING + + @pytest.mark.parametrize( + "proposed_state_type", + sorted( + list(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}) + ), + ) + async def test_rejects_cancelled_cancelling_to_anything_but_cancelled( + self, + session, + initialize_orchestration, + proposed_state_type, + ): + initial_state_type = states.StateType.CANCELLING + intended_transition = (initial_state_type, proposed_state_type) + + ctx = await initialize_orchestration( + session, + "flow", + *intended_transition, + ) + + async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx: + await ctx.validate_proposed_state() + + assert ctx.response_status == SetStateStatus.REJECT + assert ctx.validated_state_type == states.StateType.CANCELLING diff --git a/tests/test_engine.py b/tests/test_engine.py index b791bfd8f3c4..30060ed7f73b 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,10 +1,12 @@ import asyncio +import os +import signal import statistics import time from contextlib import contextmanager from functools import partial from typing import List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from uuid import uuid4 import anyio @@ -23,6 +25,7 @@ orchestrate_flow_run, orchestrate_task_run, pause_flow_run, + report_flow_run_crashes, resume_flow_run, retrieve_flow_then_begin_flow_run, ) @@ -34,6 +37,7 @@ Pause, PausedRun, SignatureMismatchError, + TerminationSignal, ) from prefect.futures import PrefectFuture from prefect.orion.schemas.actions import FlowRunCreate @@ -2597,6 +2601,20 @@ def my_flow(): assert i <= 10, "`just_sleep` should not be running after timeout" + async def test_report_flow_run_crashes_handles_sigterm( + self, flow_run, orion_client, monkeypatch + ): + original_handler = Mock() + signal.signal(signal.SIGTERM, original_handler) + + with pytest.raises(TerminationSignal): + async with report_flow_run_crashes(flow_run=flow_run, client=orion_client): + assert signal.getsignal(signal.SIGTERM) != original_handler + os.kill(os.getpid(), signal.SIGTERM) + + original_handler.assert_called_once() + assert signal.getsignal(signal.SIGTERM) == original_handler + class TestTaskRunCrashes: @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit])