From 53f2d74ae58023a6488d5d67292375cb57a88f19 Mon Sep 17 00:00:00 2001 From: Chris Pickett Date: Fri, 9 Dec 2022 09:34:59 -0500 Subject: [PATCH 1/4] Shutdown engine gracefully on SIGTERM --- src/prefect/engine.py | 31 ++++++++++++++++++++++++++++--- src/prefect/exceptions.py | 12 ++++++++++++ src/prefect/states.py | 13 ++++++++++--- src/prefect/task_runners.py | 6 +++--- 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 5595cdcec11a..acbd05db5463 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -16,6 +16,7 @@ See `orchestrate_flow_run`, `orchestrate_task_run` """ import logging +import signal import sys from contextlib import AsyncExitStack, asynccontextmanager, nullcontext from functools import partial @@ -41,6 +42,7 @@ from prefect.deployments import load_flow_from_flow_run from prefect.exceptions import ( Abort, + Cancel, FlowPauseTimeout, MappingLengthMismatch, MappingMissingIterable, @@ -72,8 +74,8 @@ Pending, Running, State, - exception_to_crashed_state, exception_to_failed_state, + exception_to_final_state, get_state_exception, return_value_to_state, ) @@ -1566,13 +1568,19 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): This context _must_ reraise the exception to properly exit the run. """ + + def cancel_flow_run(*args): + raise Cancel(cause="SIGTERM") + + original_sigterm_handler = signal.signal(signal.SIGTERM, cancel_flow_run) + try: yield except (Abort, Pause): # Do not capture internal signals as crashes raise except BaseException as exc: - state = await exception_to_crashed_state(exc) + state = await exception_to_final_state(exc) logger = flow_run_logger(flow_run) with anyio.CancelScope(shield=True): logger.error(f"Crash detected! {state.message}") @@ -1588,6 +1596,8 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): # Reraise the exception raise exc from None + finally: + signal.signal(signal.SIGTERM, original_sigterm_handler) @asynccontextmanager @@ -1603,8 +1613,11 @@ async def report_task_run_crashes(task_run: TaskRun, client: OrionClient): except (Abort, Pause): # Do not capture internal signals as crashes raise + except Cancel: + # Do not capture cancellations as crashes + raise except BaseException as exc: - state = await exception_to_crashed_state(exc) + state = await exception_to_final_state(exc) logger = task_run_logger(task_run) with anyio.CancelScope(shield=True): logger.error(f"Crash detected! {state.message}") @@ -1907,6 +1920,18 @@ def should_log_prints(flow_or_task: Union[Flow, Task]) -> bool: f"Engine execution of flow run '{flow_run_id}' is paused: {exc}" ) exit(0) + except Cancel as exc: + engine_logger.info( + f"Engine execution of flow run '{flow_run_id}' cancelled by orchestrator: {exc}" + ) + if exc.cause == "SIGTERM": + # The default SIGTERM handler is swapped out during a flow run to + # raise this `Cancel` exception. This `os.kill` call ensures that + # the previous handler (likely the Python default) gets called as + # well. + os.kill(os.getpid(), signal.SIGTERM) + else: + exit(0) except Exception: engine_logger.error( f"Engine execution of flow run '{flow_run_id}' exited with unexpected " diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index afc6de7a56d0..a6a3ccd8788e 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -276,6 +276,18 @@ class Pause(PrefectSignal): """ +class Cancel(PrefectSignal): + """ + Raised when a flow run is cancelled via a SIGTERM + + Indicates that the run should exit immediately. + """ + + def __init__(self, *args, cause="unknown"): + super().__init__(args) + self.cause = cause + + class PrefectHTTPStatusError(HTTPStatusError): """ Raised when client receives a `Response` that contains an HTTPStatusError. diff --git a/src/prefect/states.py b/src/prefect/states.py index 50d90bf0e707..1eafc6f33a10 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -17,6 +17,7 @@ result_from_state_with_data_document, ) from prefect.exceptions import ( + Cancel, CancelledRun, CrashedRun, FailedRun, @@ -124,7 +125,7 @@ def format_exception(exc: BaseException, tb: TracebackType = None) -> str: return formatted -async def exception_to_crashed_state( +async def exception_to_final_state( exc: BaseException, result_factory: Optional[ResultFactory] = None, ) -> State: @@ -133,9 +134,15 @@ async def exception_to_crashed_state( 'Crash' exception with a 'Crashed' state. """ state_message = None + state_cls = Crashed - if isinstance(exc, anyio.get_cancelled_exc_class()): + if isinstance(exc, Cancel): + state_message = "Execution was cancelled by a termination signal." + state_cls = Cancelled + + elif isinstance(exc, anyio.get_cancelled_exc_class()): state_message = "Execution was cancelled by the runtime environment." + state_cls = Cancelled elif isinstance(exc, KeyboardInterrupt): state_message = "Execution was aborted by an interrupt signal." @@ -163,7 +170,7 @@ async def exception_to_crashed_state( # from the API data = exc - return Crashed(message=state_message, data=data) + return state_cls(message=state_message, data=data) async def exception_to_failed_state( diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 3625925f0fb2..5b7d702b4197 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -76,7 +76,7 @@ from prefect.logging import get_logger from prefect.orion.schemas.states import State -from prefect.states import exception_to_crashed_state +from prefect.states import exception_to_final_state from prefect.utilities.collections import AutoEnum T = TypeVar("T", bound="BaseTaskRunner") @@ -206,7 +206,7 @@ async def submit( try: result = await call() except BaseException as exc: - result = await exception_to_crashed_state(exc) + result = await exception_to_final_state(exc) self._results[key] = result @@ -294,7 +294,7 @@ async def _run_and_store_result( try: self._results[key] = await call() except BaseException as exc: - self._results[key] = await exception_to_crashed_state(exc) + self._results[key] = await exception_to_final_state(exc) self._result_events[key].set() From e6426f6c2c744a3dd0d14c9e2f23a82645ceb5c9 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Tue, 10 Jan 2023 10:54:50 -0600 Subject: [PATCH 2/4] Add tests --- tests/test_engine.py | 26 +++++++++++++++++++++++++- tests/test_states.py | 12 +++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index 909d05e8ba3f..940968c02ca8 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,10 +1,11 @@ import asyncio +import signal import statistics import time from contextlib import contextmanager from functools import partial from typing import List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from uuid import uuid4 import anyio @@ -23,11 +24,13 @@ orchestrate_flow_run, orchestrate_task_run, pause_flow_run, + report_flow_run_crashes, resume_flow_run, retrieve_flow_then_begin_flow_run, ) from prefect.exceptions import ( Abort, + Cancel, CrashedRun, FailedRun, ParameterTypeError, @@ -1428,6 +1431,27 @@ def my_flow(): assert i <= 10, "`just_sleep` should not be running after timeout" + async def test_report_flow_run_crashes_handles_sigterm( + self, flow_run, orion_client, monkeypatch + ): + original_handler = lambda args: args + signal_receiver = Mock(return_value=original_handler) + monkeypatch.setattr("signal.signal", signal_receiver) + + async with report_flow_run_crashes(flow_run=flow_run, client=orion_client): + assert signal_receiver.call_args_list[0][0][0] == signal.SIGTERM + signal_handler = signal_receiver.call_args_list[0][0][1] + + # Call the signal handler and expect that it raises a `Cancel` + # exception. + with pytest.raises(Cancel): + signal_handler() + + # The original handler should be restorted when the context manager + # exits. + assert signal_receiver.call_args_list[1][0][0] == signal.SIGTERM + assert signal_receiver.call_args_list[1][0][1] == original_handler + class TestTaskRunCrashes: @pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit]) diff --git a/tests/test_states.py b/tests/test_states.py index 8cf6ee91a497..ce0a8cfbe13b 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -3,7 +3,8 @@ import pytest from prefect import flow -from prefect.exceptions import CancelledRun, CrashedRun, FailedRun +from prefect.exceptions import Cancel, CancelledRun, CrashedRun, FailedRun +from prefect.orion.schemas.states import StateType from prefect.results import LiteralResult, PersistedResult, ResultFactory from prefect.states import ( Cancelled, @@ -14,6 +15,7 @@ Running, State, StateGroup, + exception_to_final_state, is_state, is_state_iterable, raise_state_exception, @@ -335,3 +337,11 @@ def test_counts_message_some_non_final(self): assert "'FAILED'=1" in counts_message assert "'CRASHED'=1" in counts_message assert "'RUNNING'=2" in counts_message + + +class TestExceptionToFinalState: + async def test_cancel_exception(self): + state = await exception_to_final_state(Cancel()) + assert state.type == StateType.CANCELLED + assert state.message == "Execution was cancelled by a termination signal." + assert isinstance(await state.result(raise_on_failure=False), Cancel) From d9a0a5ea83f3f7a6bb64a9876cbe09935d8e9e59 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Tue, 10 Jan 2023 10:55:07 -0600 Subject: [PATCH 3/4] Refactor to use a `TerminationSignal` exception --- src/prefect/engine.py | 42 +++++++------- src/prefect/exceptions.py | 15 +++-- .../orion/orchestration/core_policy.py | 25 ++++++++ src/prefect/states.py | 13 +---- src/prefect/task_runners.py | 6 +- tests/orion/orchestration/test_core_policy.py | 57 +++++++++++++++++++ tests/test_engine.py | 8 +-- tests/test_states.py | 12 +--- 8 files changed, 123 insertions(+), 55 deletions(-) diff --git a/src/prefect/engine.py b/src/prefect/engine.py index acbd05db5463..5aefca8d06c1 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -42,13 +42,13 @@ from prefect.deployments import load_flow_from_flow_run from prefect.exceptions import ( Abort, - Cancel, FlowPauseTimeout, MappingLengthMismatch, MappingMissingIterable, NotPausedError, Pause, PausedRun, + TerminationSignal, UpstreamTaskError, ) from prefect.flows import Flow @@ -74,8 +74,8 @@ Pending, Running, State, + exception_to_crashed_state, exception_to_failed_state, - exception_to_final_state, get_state_exception, return_value_to_state, ) @@ -1570,9 +1570,14 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): """ def cancel_flow_run(*args): - raise Cancel(cause="SIGTERM") + raise TerminationSignal(signal=signal.SIGTERM) - original_sigterm_handler = signal.signal(signal.SIGTERM, cancel_flow_run) + original_sigterm_handler = None + try: + original_sigterm_handler = signal.signal(signal.SIGTERM, cancel_flow_run) + except ValueError: + # Signals only work in the main thread + pass try: yield @@ -1580,7 +1585,7 @@ def cancel_flow_run(*args): # Do not capture internal signals as crashes raise except BaseException as exc: - state = await exception_to_final_state(exc) + state = await exception_to_crashed_state(exc) logger = flow_run_logger(flow_run) with anyio.CancelScope(shield=True): logger.error(f"Crash detected! {state.message}") @@ -1588,7 +1593,6 @@ def cancel_flow_run(*args): await client.set_flow_run_state( state=state, flow_run_id=flow_run.id, - force=True, ) engine_logger.debug( f"Reported crashed flow run {flow_run.name!r} successfully!" @@ -1597,7 +1601,8 @@ def cancel_flow_run(*args): # Reraise the exception raise exc from None finally: - signal.signal(signal.SIGTERM, original_sigterm_handler) + if original_sigterm_handler is not None: + signal.signal(signal.SIGTERM, original_sigterm_handler) @asynccontextmanager @@ -1613,11 +1618,8 @@ async def report_task_run_crashes(task_run: TaskRun, client: OrionClient): except (Abort, Pause): # Do not capture internal signals as crashes raise - except Cancel: - # Do not capture cancellations as crashes - raise except BaseException as exc: - state = await exception_to_final_state(exc) + state = await exception_to_crashed_state(exc) logger = task_run_logger(task_run) with anyio.CancelScope(shield=True): logger.error(f"Crash detected! {state.message}") @@ -1920,18 +1922,16 @@ def should_log_prints(flow_or_task: Union[Flow, Task]) -> bool: f"Engine execution of flow run '{flow_run_id}' is paused: {exc}" ) exit(0) - except Cancel as exc: + except TerminationSignal as exc: engine_logger.info( - f"Engine execution of flow run '{flow_run_id}' cancelled by orchestrator: {exc}" + f"Engine execution of flow run '{flow_run_id}' cancelled by external signal: {exc}" ) - if exc.cause == "SIGTERM": - # The default SIGTERM handler is swapped out during a flow run to - # raise this `Cancel` exception. This `os.kill` call ensures that - # the previous handler (likely the Python default) gets called as - # well. - os.kill(os.getpid(), signal.SIGTERM) - else: - exit(0) + + # Termination signals are swapped out during a flow run to perform a + # graceful shutdown and raise this exception. This `os.kill` call + # ensures that the previous handler, likely the Python default, gets + # called as well. + os.kill(os.getpid(), exc.signal) except Exception: engine_logger.error( f"Engine execution of flow run '{flow_run_id}' exited with unexpected " diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index a6a3ccd8788e..74f359693847 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -276,16 +276,19 @@ class Pause(PrefectSignal): """ -class Cancel(PrefectSignal): +class ExternalSignal(BaseException): + """ + Base type for external signal-like exceptions that should never be caught by users. """ - Raised when a flow run is cancelled via a SIGTERM - Indicates that the run should exit immediately. + +class TerminationSignal(ExternalSignal): + """ + Raised when a flow run receives a termination signal. """ - def __init__(self, *args, cause="unknown"): - super().__init__(args) - self.cause = cause + def __init__(self, signal: int): + self.signal = signal class PrefectHTTPStatusError(HTTPStatusError): diff --git a/src/prefect/orion/orchestration/core_policy.py b/src/prefect/orion/orchestration/core_policy.py index f6a97fc6c438..b80f4bc92f18 100644 --- a/src/prefect/orion/orchestration/core_policy.py +++ b/src/prefect/orion/orchestration/core_policy.py @@ -41,6 +41,7 @@ class CoreFlowPolicy(BaseOrchestrationPolicy): def priority(): return [ HandleFlowTerminalStateTransitions, + HandleCancellingStateTransitions, PreventRedundantTransitions, HandlePausingFlows, HandleResumingPausedFlows, @@ -786,3 +787,27 @@ async def before_transition( await self.abort_transition( reason=f"The enclosing flow must be running to begin task execution.", ) + + +class HandleCancellingStateTransitions(BaseOrchestrationRule): + """ + Rejects transitions from Cancelling to any terminal state except for Cancelled. + """ + + FROM_STATES = {StateType.CANCELLED, StateType.CANCELLING} + TO_STATES = ALL_ORCHESTRATION_STATES - {StateType.CANCELLED} + + async def before_transition( + self, + initial_state: Optional[states.State], + proposed_state: Optional[states.State], + context: TaskOrchestrationContext, + ) -> None: + await self.reject_transition( + state=None, + reason=( + "Cannot transition flows that are cancelling to a state other " + "than Cancelled." + ), + ) + return diff --git a/src/prefect/states.py b/src/prefect/states.py index 1eafc6f33a10..50d90bf0e707 100644 --- a/src/prefect/states.py +++ b/src/prefect/states.py @@ -17,7 +17,6 @@ result_from_state_with_data_document, ) from prefect.exceptions import ( - Cancel, CancelledRun, CrashedRun, FailedRun, @@ -125,7 +124,7 @@ def format_exception(exc: BaseException, tb: TracebackType = None) -> str: return formatted -async def exception_to_final_state( +async def exception_to_crashed_state( exc: BaseException, result_factory: Optional[ResultFactory] = None, ) -> State: @@ -134,15 +133,9 @@ async def exception_to_final_state( 'Crash' exception with a 'Crashed' state. """ state_message = None - state_cls = Crashed - if isinstance(exc, Cancel): - state_message = "Execution was cancelled by a termination signal." - state_cls = Cancelled - - elif isinstance(exc, anyio.get_cancelled_exc_class()): + if isinstance(exc, anyio.get_cancelled_exc_class()): state_message = "Execution was cancelled by the runtime environment." - state_cls = Cancelled elif isinstance(exc, KeyboardInterrupt): state_message = "Execution was aborted by an interrupt signal." @@ -170,7 +163,7 @@ async def exception_to_final_state( # from the API data = exc - return state_cls(message=state_message, data=data) + return Crashed(message=state_message, data=data) async def exception_to_failed_state( diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 5b7d702b4197..3625925f0fb2 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -76,7 +76,7 @@ from prefect.logging import get_logger from prefect.orion.schemas.states import State -from prefect.states import exception_to_final_state +from prefect.states import exception_to_crashed_state from prefect.utilities.collections import AutoEnum T = TypeVar("T", bound="BaseTaskRunner") @@ -206,7 +206,7 @@ async def submit( try: result = await call() except BaseException as exc: - result = await exception_to_final_state(exc) + result = await exception_to_crashed_state(exc) self._results[key] = result @@ -294,7 +294,7 @@ async def _run_and_store_result( try: self._results[key] = await call() except BaseException as exc: - self._results[key] = await exception_to_final_state(exc) + self._results[key] = await exception_to_crashed_state(exc) self._result_events[key].set() diff --git a/tests/orion/orchestration/test_core_policy.py b/tests/orion/orchestration/test_core_policy.py index 33db69cb1db6..f2f4fe231511 100644 --- a/tests/orion/orchestration/test_core_policy.py +++ b/tests/orion/orchestration/test_core_policy.py @@ -15,6 +15,7 @@ CacheInsertion, CacheRetrieval, CopyScheduledTime, + HandleCancellingStateTransitions, HandleFlowTerminalStateTransitions, HandlePausingFlows, HandleResumingPausedFlows, @@ -2541,3 +2542,59 @@ async def test_prevents_tasks_From_running( else: assert ctx.response_status == SetStateStatus.ABORT assert ctx.validated_state.is_pending() + + +class TestHandleCancellingStateTransitions: + @pytest.mark.parametrize( + "proposed_state_type", + sorted( + list(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}) + ), + ) + async def test_rejects_cancelling_to_anything_but_cancelled( + self, + session, + initialize_orchestration, + proposed_state_type, + ): + initial_state_type = states.StateType.CANCELLING + intended_transition = (initial_state_type, proposed_state_type) + + ctx = await initialize_orchestration( + session, + "flow", + *intended_transition, + ) + + async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx: + await ctx.validate_proposed_state() + + assert ctx.response_status == SetStateStatus.REJECT + assert ctx.validated_state_type == states.StateType.CANCELLING + + @pytest.mark.parametrize( + "proposed_state_type", + sorted( + list(set(ALL_ORCHESTRATION_STATES) - {states.StateType.CANCELLED, None}) + ), + ) + async def test_rejects_cancelled_cancelling_to_anything_but_cancelled( + self, + session, + initialize_orchestration, + proposed_state_type, + ): + initial_state_type = states.StateType.CANCELLING + intended_transition = (initial_state_type, proposed_state_type) + + ctx = await initialize_orchestration( + session, + "flow", + *intended_transition, + ) + + async with HandleCancellingStateTransitions(ctx, *intended_transition) as ctx: + await ctx.validate_proposed_state() + + assert ctx.response_status == SetStateStatus.REJECT + assert ctx.validated_state_type == states.StateType.CANCELLING diff --git a/tests/test_engine.py b/tests/test_engine.py index 940968c02ca8..a69e4e1b1a0e 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -30,13 +30,13 @@ ) from prefect.exceptions import ( Abort, - Cancel, CrashedRun, FailedRun, ParameterTypeError, Pause, PausedRun, SignatureMismatchError, + TerminationSignal, ) from prefect.futures import PrefectFuture from prefect.orion.schemas.actions import FlowRunCreate @@ -1442,9 +1442,9 @@ async def test_report_flow_run_crashes_handles_sigterm( assert signal_receiver.call_args_list[0][0][0] == signal.SIGTERM signal_handler = signal_receiver.call_args_list[0][0][1] - # Call the signal handler and expect that it raises a `Cancel` - # exception. - with pytest.raises(Cancel): + # Call the signal handler and expect that it raises a + # `TerminationSignal` exception. + with pytest.raises(TerminationSignal): signal_handler() # The original handler should be restorted when the context manager diff --git a/tests/test_states.py b/tests/test_states.py index ce0a8cfbe13b..8cf6ee91a497 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -3,8 +3,7 @@ import pytest from prefect import flow -from prefect.exceptions import Cancel, CancelledRun, CrashedRun, FailedRun -from prefect.orion.schemas.states import StateType +from prefect.exceptions import CancelledRun, CrashedRun, FailedRun from prefect.results import LiteralResult, PersistedResult, ResultFactory from prefect.states import ( Cancelled, @@ -15,7 +14,6 @@ Running, State, StateGroup, - exception_to_final_state, is_state, is_state_iterable, raise_state_exception, @@ -337,11 +335,3 @@ def test_counts_message_some_non_final(self): assert "'FAILED'=1" in counts_message assert "'CRASHED'=1" in counts_message assert "'RUNNING'=2" in counts_message - - -class TestExceptionToFinalState: - async def test_cancel_exception(self): - state = await exception_to_final_state(Cancel()) - assert state.type == StateType.CANCELLED - assert state.message == "Execution was cancelled by a termination signal." - assert isinstance(await state.result(raise_on_failure=False), Cancel) From 12ea3f44e086aa9e2ccee54b55ef894944f651b0 Mon Sep 17 00:00:00 2001 From: Michael Adkins Date: Tue, 10 Jan 2023 10:55:34 -0600 Subject: [PATCH 4/4] Incorporate PR feedback --- src/prefect/engine.py | 27 +++++++++++++-------------- tests/test_engine.py | 28 +++++++++++----------------- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/src/prefect/engine.py b/src/prefect/engine.py index 5aefca8d06c1..d9efe4719bb0 100644 --- a/src/prefect/engine.py +++ b/src/prefect/engine.py @@ -16,6 +16,7 @@ See `orchestrate_flow_run`, `orchestrate_task_run` """ import logging +import os import signal import sys from contextlib import AsyncExitStack, asynccontextmanager, nullcontext @@ -1572,9 +1573,9 @@ async def report_flow_run_crashes(flow_run: FlowRun, client: OrionClient): def cancel_flow_run(*args): raise TerminationSignal(signal=signal.SIGTERM) - original_sigterm_handler = None + original_term_handler = None try: - original_sigterm_handler = signal.signal(signal.SIGTERM, cancel_flow_run) + original_term_handler = signal.signal(signal.SIGTERM, cancel_flow_run) except ValueError: # Signals only work in the main thread pass @@ -1598,11 +1599,19 @@ def cancel_flow_run(*args): f"Reported crashed flow run {flow_run.name!r} successfully!" ) + if isinstance(exc, TerminationSignal): + # Termination signals are swapped out during a flow run to perform + # a graceful shutdown and raise this exception. This `os.kill` call + # ensures that the previous handler, likely the Python default, + # gets called as well. + signal.signal(exc.signal, original_term_handler) + os.kill(os.getpid(), exc.signal) + # Reraise the exception raise exc from None finally: - if original_sigterm_handler is not None: - signal.signal(signal.SIGTERM, original_sigterm_handler) + if original_term_handler is not None: + signal.signal(signal.SIGTERM, original_term_handler) @asynccontextmanager @@ -1922,16 +1931,6 @@ def should_log_prints(flow_or_task: Union[Flow, Task]) -> bool: f"Engine execution of flow run '{flow_run_id}' is paused: {exc}" ) exit(0) - except TerminationSignal as exc: - engine_logger.info( - f"Engine execution of flow run '{flow_run_id}' cancelled by external signal: {exc}" - ) - - # Termination signals are swapped out during a flow run to perform a - # graceful shutdown and raise this exception. This `os.kill` call - # ensures that the previous handler, likely the Python default, gets - # called as well. - os.kill(os.getpid(), exc.signal) except Exception: engine_logger.error( f"Engine execution of flow run '{flow_run_id}' exited with unexpected " diff --git a/tests/test_engine.py b/tests/test_engine.py index a69e4e1b1a0e..882a23d7c811 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,4 +1,5 @@ import asyncio +import os import signal import statistics import time @@ -1434,23 +1435,16 @@ def my_flow(): async def test_report_flow_run_crashes_handles_sigterm( self, flow_run, orion_client, monkeypatch ): - original_handler = lambda args: args - signal_receiver = Mock(return_value=original_handler) - monkeypatch.setattr("signal.signal", signal_receiver) - - async with report_flow_run_crashes(flow_run=flow_run, client=orion_client): - assert signal_receiver.call_args_list[0][0][0] == signal.SIGTERM - signal_handler = signal_receiver.call_args_list[0][0][1] - - # Call the signal handler and expect that it raises a - # `TerminationSignal` exception. - with pytest.raises(TerminationSignal): - signal_handler() - - # The original handler should be restorted when the context manager - # exits. - assert signal_receiver.call_args_list[1][0][0] == signal.SIGTERM - assert signal_receiver.call_args_list[1][0][1] == original_handler + original_handler = Mock() + signal.signal(signal.SIGTERM, original_handler) + + with pytest.raises(TerminationSignal): + async with report_flow_run_crashes(flow_run=flow_run, client=orion_client): + assert signal.getsignal(signal.SIGTERM) != original_handler + os.kill(os.getpid(), signal.SIGTERM) + + original_handler.assert_called_once() + assert signal.getsignal(signal.SIGTERM) == original_handler class TestTaskRunCrashes: