From 38bf36550164b36c14a9ca6024e5c805eadc152b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 11 Aug 2022 17:46:16 +0100 Subject: [PATCH] Harden preamble of Worker.execute against race conditions --- distributed/tests/test_cancelled_state.py | 65 +++++++++++++++++++++++ distributed/tests/test_worker.py | 55 +++++++++++++++++++ distributed/utils_test.py | 65 +++++++++++++++++++++++ distributed/worker.py | 41 +++++++------- distributed/worker_state_machine.py | 30 +++-------- 5 files changed, 212 insertions(+), 44 deletions(-) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a029e29b9e8..9080c1d77f7 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -8,6 +8,7 @@ from distributed import Event, Lock, Worker from distributed.client import wait from distributed.utils_test import ( + BlockedExecute, BlockedGatherDep, BlockedGetData, _LockedCommPool, @@ -822,3 +823,67 @@ def test_workerstate_resumed_waiting_to_flight(ws): GatherDep(worker=ws2, to_gather={"x"}, stimulus_id="s1", total_nbytes=1), ] assert ws.tasks["x"].state == "flight" + + +@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"]) +@pytest.mark.parametrize("resume_inside_critical_section", [False, True]) +@pytest.mark.parametrize("resumed_status", ["executing", "resumed"]) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_execute_preamble_early_cancel( + c, s, b, critical_section, resume_inside_critical_section, resumed_status +): + """Test multiple race conditions in the preamble of Worker.execute(), which used to + cause a task to remain permanently in resumed state or to crash the worker through + `fail_hard` in case of very tight timings when resuming a task. + + See also + -------- + https://github.com/dask/distributed/issues/6869 + https://github.com/dask/dask/issues/9330 + test_worker.py::test_execute_preamble_abort_retirement + """ + async with BlockedExecute(s.address, validate=True) as a: + if critical_section == "execute": + in_ev = a.in_execute + block_ev = a.block_execute + a.block_deserialize_task.set() + else: + assert critical_section == "deserialize_task" + in_ev = a.in_deserialize_task + block_ev = a.block_deserialize_task + a.block_execute.set() + + async def resume(): + if resumed_status == "executing": + x = c.submit(inc, 1, key="x", workers=[a.address]) + await wait_for_state("x", "executing", a) + return x, 2 + else: + assert resumed_status == "resumed" + x = c.submit(inc, 1, key="x", workers=[b.address]) + y = c.submit(inc, x, key="y", workers=[a.address]) + await wait_for_state("x", "resumed", a) + return y, 3 + + x = c.submit(inc, 1, key="x", workers=[a.address]) + await in_ev.wait() + + x.release() + await wait_for_state("x", "cancelled", a) + + if resume_inside_critical_section: + fut, expect = await resume() + + # Unblock Worker.execute. At the moment of writing this test, the method + # would detect the cancelled status and perform an early exit. + block_ev.set() + await a.in_execute_exit.wait() + + if not resume_inside_critical_section: + fut, expect = await resume() + + # Finally let the done_callback of Worker.execute run + a.block_execute_exit.set() + + # Test that x does not get stuck. 
+ assert await fut == expect diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3d7d2c0105b..7e47792ba41 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -49,6 +49,7 @@ from distributed.protocol import pickle from distributed.scheduler import Scheduler from distributed.utils_test import ( + BlockedExecute, BlockedGatherDep, BlockedGetData, TaskStateMetadataPlugin, @@ -3500,3 +3501,57 @@ def teardown(self, worker): async with Worker(s.address) as worker: assert await c.submit(inc, 1) == 2 assert worker.plugins[InitWorkerNewThread.name].setup_status is Status.running + + +@gen_cluster( + client=True, + nthreads=[], + config={ + # This is just to make Scheduler.retire_worker more reactive to changes + "distributed.scheduler.active-memory-manager.start": True, + "distributed.scheduler.active-memory-manager.interval": "50ms", + }, +) +async def test_execute_preamble_abort_retirement(c, s): + """Test race condition in the preamble of Worker.execute(), which used to cause a + task to remain permanently in executing state in case of very tight timings when + exiting the closing_gracefully status. + + See also + -------- + https://github.com/dask/distributed/issues/6867 + test_cancelled_state.py::test_execute_preamble_early_cancel + """ + async with BlockedExecute(s.address) as a: + await c.wait_for_workers(1) + a.block_deserialize_task.set() # Uninteresting in this test + + x = await c.scatter({"x": 1}, workers=[a.address]) + y = c.submit(inc, 1, key="y", workers=[a.address]) + await a.in_execute.wait() + + async with BlockedGatherDep(s.address) as b: + await c.wait_for_workers(2) + retire_fut = asyncio.create_task(c.retire_workers([a.address])) + while a.status != Status.closing_gracefully: + await asyncio.sleep(0.01) + # The Active Memory Manager will send to b the message + # {op: acquire-replicas, who_has: {x: [a.address]}} + await b.in_gather_dep.wait() + + # Run Worker.execute. At the moment of writing this test, the method would + # detect the closing_gracefully status and perform an early exit. + a.block_execute.set() + await a.in_execute_exit.wait() + + # b has shut down. There's nowhere to replicate x to anymore, so retire_workers + # will give up and reinstate a to running status. + assert await retire_fut == {} + while a.status != Status.running: + await asyncio.sleep(0.01) + + # Finally let the done_callback of Worker.execute run + a.block_execute_exit.set() + + # Test that y does not get stuck. + assert await y == 2 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 5fb8c8d832c..d7e0f3e9196 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2154,6 +2154,7 @@ async def test1(s, a, b): See also -------- BlockedGetData + BlockedExecute """ def __init__(self, *args, **kwargs): @@ -2175,6 +2176,7 @@ class BlockedGetData(Worker): See also -------- BlockedGatherDep + BlockedExecute """ def __init__(self, *args, **kwargs): @@ -2188,6 +2190,69 @@ async def get_data(self, comm, *args, **kwargs): return await super().get_data(comm, *args, **kwargs) +class BlockedExecute(Worker): + """A Worker that sets event `in_execute` the first time it enters the execute + method and then does not proceed, thus leaving the task in executing state + indefinitely, until the test sets `block_execute`. + + After that, the worker sets `in_deserialize_task` to simulate the moment when a + large run_spec is being deserialized in a separate thread. 
The worker will block + again until the test sets `block_deserialize_task`. + + Finally, the worker sets `in_execute_exit` when execute() terminates, but before the + worker state has processed its exit callback. The worker will block one last time + until the test sets `block_execute_exit`. + + Note + ---- + In the vast majority of the test cases, it is simpler and more readable to just + submit to a regular Worker a task that blocks on a distributed.Event: + + .. code-block:: python + + def f(in_task, block_task): + in_task.set() + block_task.wait() + + in_task = distributed.Event() + block_task = distributed.Event() + fut = c.submit(f, in_task, block_task) + await in_task.wait() + await block_task.set() + + See also + -------- + BlockedGatherDep + BlockedGetData + """ + + def __init__(self, *args, **kwargs): + self.in_execute = asyncio.Event() + self.block_execute = asyncio.Event() + self.in_deserialize_task = asyncio.Event() + self.block_deserialize_task = asyncio.Event() + self.in_execute_exit = asyncio.Event() + self.block_execute_exit = asyncio.Event() + + super().__init__(*args, **kwargs) + + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: + self.in_execute.set() + await self.block_execute.wait() + try: + return await super().execute(key, stimulus_id=stimulus_id) + finally: + self.in_execute_exit.set() + await self.block_execute_exit.wait() + + async def _maybe_deserialize_task( + self, ts: WorkerTaskState + ) -> tuple[Callable, tuple, dict[str, Any]]: + self.in_deserialize_task.set() + await self.block_deserialize_task.wait() + return await super()._maybe_deserialize_task(ts) + + @contextmanager def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]: """Prevent any task from transitioning from fetch to flight on the worker while diff --git a/distributed/worker.py b/distributed/worker.py index 59576cd59ec..86e1fba5316 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -108,7 +108,6 @@ from distributed.worker_state_machine import ( NO_VALUE, AcquireReplicasEvent, - AlreadyCancelledEvent, BaseWorker, CancelComputeEvent, ComputeTaskEvent, @@ -1847,9 +1846,7 @@ def _(**kwargs): return _ @fail_hard - def _handle_stimulus_from_task( - self, task: asyncio.Task[StateMachineEvent | None] - ) -> None: + def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None: """Override BaseWorker method for added validation See also @@ -1968,7 +1965,7 @@ async def gather_dep( total_nbytes: int, *, stimulus_id: str, - ) -> StateMachineEvent | None: + ) -> StateMachineEvent: """Implements BaseWorker abstract method See also @@ -1976,7 +1973,14 @@ async def gather_dep( distributed.worker_state_machine.BaseWorker.gather_dep """ if self.status not in WORKER_ANY_RUNNING: - return None + # This is only for the sake of coherence of the WorkerState; + # it should never actually reach the scheduler. + return GatherDepFailureEvent.from_exception( + RuntimeError("Worker is shutting down"), + worker=worker, + total_nbytes=total_nbytes, + stimulus_id=f"worker-closing-{time()}", + ) try: self.state.log.append( @@ -2044,7 +2048,7 @@ async def gather_dep( stimulus_id=f"gather-dep-failure-{time()}", ) - async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: + async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent: """Wait some time, then take a peer worker out of busy state. Implements BaseWorker abstract method. 
@@ -2137,24 +2141,21 @@ async def _maybe_deserialize_task( return function, args, kwargs @fail_hard - async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: """Execute a task. Implements BaseWorker abstract method. See also -------- distributed.worker_state_machine.BaseWorker.execute """ - if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: - return None - ts = self.state.tasks.get(key) - if not ts: - return None - if ts.state == "cancelled": - logger.debug( - "Trying to execute task %s which is not in executing state anymore", - ts, - ) - return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id) + if self.status not in WORKER_ANY_RUNNING: + # This is just for internal coherence of the WorkerState; the reschedule + # message should not ever reach the Scheduler. + # It is still OK if it does though. + return RescheduleEvent(key=key, stimulus_id=f"worker-closing-{time()}") + + # The key *must* be in the worker state thanks to the cancelled state + ts = self.state.tasks[key] try: function, args, kwargs = await self._maybe_deserialize_task(ts) @@ -2169,7 +2170,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No try: if self.state.validate: assert not ts.waiting_for_data - assert ts.state == "executing", ts.state + assert ts.state in ("executing", "cancelled", "resumed"), ts assert ts.run_spec is not None args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 6eac8215839..094b90a5498 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -917,12 +917,6 @@ class CancelComputeEvent(StateMachineEvent): key: str -@dataclass -class AlreadyCancelledEvent(StateMachineEvent): - __slots__ = ("key",) - key: str - - # Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception @dataclass class RescheduleEvent(StateMachineEvent): @@ -2931,16 +2925,6 @@ def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: assert not ts.dependents return {ts: "released"}, [] - @_handle_event.register - def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs: - """Task is already cancelled by the time execute() runs""" - # key *must* be still in tasks. 
Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(ev.key) - assert ts, self.story(ev.key) - ts.done = True - return {ts: "released"}, [] - @_handle_event.register def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: """Task completed successfully""" @@ -3332,18 +3316,16 @@ def __init__(self, state: WorkerState): self.state = state self._async_instructions = set() - def _handle_stimulus_from_task( - self, task: asyncio.Task[StateMachineEvent | None] - ) -> None: + def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None: """An asynchronous instruction just completed; process the returned stimulus.""" self._async_instructions.remove(task) try: # This *should* never raise any other exceptions stim = task.result() except asyncio.CancelledError: + # This should exclusively happen in Worker.close() return - if stim: - self.handle_stimulus(stim) + self.handle_stimulus(stim) def handle_stimulus(self, *stims: StateMachineEvent) -> None: """Forward one or more external stimuli to :meth:`WorkerState.handle_stimulus` @@ -3432,7 +3414,7 @@ async def gather_dep( total_nbytes: int, *, stimulus_id: str, - ) -> StateMachineEvent | None: + ) -> StateMachineEvent: """Gather dependencies for a task from a worker who has them Parameters @@ -3449,12 +3431,12 @@ async def gather_dep( ... @abc.abstractmethod - async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: """Execute a task""" ... @abc.abstractmethod - async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None: + async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent: """Wait some time, then take a peer worker out of busy state""" ...
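
For reference, a minimal sketch (not part of the patch itself) of how the `BlockedExecute` helper added in `distributed/utils_test.py` is intended to be driven from a test. The test name `test_demo_blocked_execute` is illustrative only; the canonical usages are `test_execute_preamble_early_cancel` and `test_execute_preamble_abort_retirement` above.

    # Illustrative only; mirrors the pattern used by the tests added in this patch.
    from distributed.utils_test import BlockedExecute, gen_cluster, inc

    @gen_cluster(client=True, nthreads=[])
    async def test_demo_blocked_execute(c, s):
        async with BlockedExecute(s.address) as a:
            await c.wait_for_workers(1)
            x = c.submit(inc, 1, key="x", workers=[a.address])

            # The task has entered Worker.execute() and is parked in its preamble
            await a.in_execute.wait()
            a.block_execute.set()

            # execute() has reached the point where the run_spec would be deserialized
            await a.in_deserialize_task.wait()
            a.block_deserialize_task.set()

            # execute() has returned, but its done-callback (which feeds the resulting
            # event back into the WorkerState) has not run yet
            await a.in_execute_exit.wait()
            a.block_execute_exit.set()

            assert await x == 2

The three pause points map onto the race windows exercised by the new tests: a task may be cancelled, or the worker may leave the running status, while execute() is still in its preamble, while the run_spec is being deserialized, or after execute() has returned but before its done-callback has updated the worker state.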