Harden preamble of Worker.execute against race conditions (#6878)
crusaderky authored Aug 12, 2022
1 parent 75ef3c9 commit 832b280
Showing 5 changed files with 212 additions and 44 deletions.
65 changes: 65 additions & 0 deletions distributed/tests/test_cancelled_state.py
@@ -8,6 +8,7 @@
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
    BlockedExecute,
    BlockedGatherDep,
    BlockedGetData,
    _LockedCommPool,
@@ -822,3 +823,67 @@ def test_workerstate_resumed_waiting_to_flight(ws):
        GatherDep(worker=ws2, to_gather={"x"}, stimulus_id="s1", total_nbytes=1),
    ]
    assert ws.tasks["x"].state == "flight"


@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"])
@pytest.mark.parametrize("resume_inside_critical_section", [False, True])
@pytest.mark.parametrize("resumed_status", ["executing", "resumed"])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_execute_preamble_early_cancel(
    c, s, b, critical_section, resume_inside_critical_section, resumed_status
):
    """Test multiple race conditions in the preamble of Worker.execute(), which used to
    cause a task to remain permanently in resumed state or to crash the worker through
    `fail_hard` in case of very tight timings when resuming a task.

    See also
    --------
    https://github.com/dask/distributed/issues/6869
    https://github.com/dask/dask/issues/9330
    test_worker.py::test_execute_preamble_abort_retirement
    """
    async with BlockedExecute(s.address, validate=True) as a:
        if critical_section == "execute":
            in_ev = a.in_execute
            block_ev = a.block_execute
            a.block_deserialize_task.set()
        else:
            assert critical_section == "deserialize_task"
            in_ev = a.in_deserialize_task
            block_ev = a.block_deserialize_task
            a.block_execute.set()

        async def resume():
            if resumed_status == "executing":
                x = c.submit(inc, 1, key="x", workers=[a.address])
                await wait_for_state("x", "executing", a)
                return x, 2
            else:
                assert resumed_status == "resumed"
                x = c.submit(inc, 1, key="x", workers=[b.address])
                y = c.submit(inc, x, key="y", workers=[a.address])
                await wait_for_state("x", "resumed", a)
                return y, 3

        x = c.submit(inc, 1, key="x", workers=[a.address])
        await in_ev.wait()

        x.release()
        await wait_for_state("x", "cancelled", a)

        if resume_inside_critical_section:
            fut, expect = await resume()

        # Unblock Worker.execute. At the moment of writing this test, the method
        # would detect the cancelled status and perform an early exit.
        block_ev.set()
        await a.in_execute_exit.wait()

        if not resume_inside_critical_section:
            fut, expect = await resume()

        # Finally let the done_callback of Worker.execute run
        a.block_execute_exit.set()

        # Test that x does not get stuck.
        assert await fut == expect
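The race class exercised by this test can be illustrated outside dask as well. The sketch below is an editorial illustration only (it is not part of this commit and uses no dask APIs): it shows how shared state can change between the moment a coroutine is scheduled with asyncio.create_task and the moment its body first runs, which is exactly the window that the preamble of Worker.execute has to tolerate.

import asyncio


async def main() -> None:
    state = {"x": "executing"}

    async def execute() -> str:
        # By the time this body runs, "x" has already been cancelled.
        if state["x"] == "cancelled":
            return "early-exit"  # old behaviour: bail out in the preamble
        return "ran"  # hardened behaviour: proceed and let the state machine reconcile

    task = asyncio.create_task(execute())  # scheduled, but its body has not started yet
    state["x"] = "cancelled"  # the cancellation wins the race
    print(await task)  # prints "early-exit"


asyncio.run(main())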
55 changes: 55 additions & 0 deletions distributed/tests/test_worker.py
@@ -49,6 +49,7 @@
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils_test import (
    BlockedExecute,
    BlockedGatherDep,
    BlockedGetData,
    TaskStateMetadataPlugin,
@@ -3500,3 +3501,57 @@ def teardown(self, worker):
    async with Worker(s.address) as worker:
        assert await c.submit(inc, 1) == 2
        assert worker.plugins[InitWorkerNewThread.name].setup_status is Status.running


@gen_cluster(
    client=True,
    nthreads=[],
    config={
        # This is just to make Scheduler.retire_worker more reactive to changes
        "distributed.scheduler.active-memory-manager.start": True,
        "distributed.scheduler.active-memory-manager.interval": "50ms",
    },
)
async def test_execute_preamble_abort_retirement(c, s):
    """Test race condition in the preamble of Worker.execute(), which used to cause a
    task to remain permanently in executing state in case of very tight timings when
    exiting the closing_gracefully status.

    See also
    --------
    https://github.com/dask/distributed/issues/6867
    test_cancelled_state.py::test_execute_preamble_early_cancel
    """
    async with BlockedExecute(s.address) as a:
        await c.wait_for_workers(1)
        a.block_deserialize_task.set()  # Uninteresting in this test

        x = await c.scatter({"x": 1}, workers=[a.address])
        y = c.submit(inc, 1, key="y", workers=[a.address])
        await a.in_execute.wait()

        async with BlockedGatherDep(s.address) as b:
            await c.wait_for_workers(2)
            retire_fut = asyncio.create_task(c.retire_workers([a.address]))
            while a.status != Status.closing_gracefully:
                await asyncio.sleep(0.01)
            # The Active Memory Manager will send to b the message
            # {op: acquire-replicas, who_has: {x: [a.address]}}
            await b.in_gather_dep.wait()

            # Run Worker.execute. At the moment of writing this test, the method would
            # detect the closing_gracefully status and perform an early exit.
            a.block_execute.set()
            await a.in_execute_exit.wait()

        # b has shut down. There's nowhere to replicate x to anymore, so retire_workers
        # will give up and reinstate a to running status.
        assert await retire_fut == {}
        while a.status != Status.running:
            await asyncio.sleep(0.01)

        # Finally let the done_callback of Worker.execute run
        a.block_execute_exit.set()

        # Test that y does not get stuck.
        assert await y == 2
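The two busy-wait loops above (`while a.status != ...: await asyncio.sleep(0.01)`) rely on gen_cluster's overall test timeout to avoid hanging. As a purely illustrative aside (an assumption, not part of this diff or of distributed's API), the same pattern can be wrapped in a reusable helper with its own timeout:

import asyncio
from typing import Callable


async def poll_until(
    predicate: Callable[[], bool], *, timeout: float = 5.0, interval: float = 0.01
) -> None:
    # Poll a synchronous predicate until it returns True or the timeout expires.
    async def _poll() -> None:
        while not predicate():
            await asyncio.sleep(interval)

    await asyncio.wait_for(_poll(), timeout)


# Hypothetical usage inside a test body:
#     await poll_until(lambda: a.status == Status.running)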
65 changes: 65 additions & 0 deletions distributed/utils_test.py
@@ -2154,6 +2154,7 @@ async def test1(s, a, b):
    See also
    --------
    BlockedGetData
    BlockedExecute
    """

    def __init__(self, *args, **kwargs):
@@ -2175,6 +2176,7 @@ class BlockedGetData(Worker):
    See also
    --------
    BlockedGatherDep
    BlockedExecute
    """

    def __init__(self, *args, **kwargs):
@@ -2188,6 +2190,69 @@ async def get_data(self, comm, *args, **kwargs):
        return await super().get_data(comm, *args, **kwargs)


class BlockedExecute(Worker):
"""A Worker that sets event `in_execute` the first time it enters the execute
method and then does not proceed, thus leaving the task in executing state
indefinitely, until the test sets `block_execute`.
After that, the worker sets `in_deserialize_task` to simulate the moment when a
large run_spec is being deserialized in a separate thread. The worker will block
again until the test sets `block_deserialize_task`.
Finally, the worker sets `in_execute_exit` when execute() terminates, but before the
worker state has processed its exit callback. The worker will block one last time
until the test sets `block_execute_exit`.
Note
----
In the vast majority of the test cases, it is simpler and more readable to just
submit to a regular Worker a task that blocks on a distributed.Event:
.. code-block:: python
def f(in_task, block_task):
in_task.set()
block_task.wait()
in_task = distributed.Event()
block_task = distributed.Event()
fut = c.submit(f, in_task, block_task)
await in_task.wait()
await block_task.set()
See also
--------
BlockedGatherDep
BlockedGetData
"""

def __init__(self, *args, **kwargs):
self.in_execute = asyncio.Event()
self.block_execute = asyncio.Event()
self.in_deserialize_task = asyncio.Event()
self.block_deserialize_task = asyncio.Event()
self.in_execute_exit = asyncio.Event()
self.block_execute_exit = asyncio.Event()

super().__init__(*args, **kwargs)

async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
self.in_execute.set()
await self.block_execute.wait()
try:
return await super().execute(key, stimulus_id=stimulus_id)
finally:
self.in_execute_exit.set()
await self.block_execute_exit.wait()

async def _maybe_deserialize_task(
self, ts: WorkerTaskState
) -> tuple[Callable, tuple, dict[str, Any]]:
self.in_deserialize_task.set()
await self.block_deserialize_task.wait()
return await super()._maybe_deserialize_task(ts)


@contextmanager
def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]:
"""Prevent any task from transitioning from fetch to flight on the worker while
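To summarize the event protocol that BlockedExecute adds, here is an illustrative driver sequence (an editorial sketch, not part of this diff). It assumes the usual gen_cluster fixtures `c` (client) and `s` (scheduler) and the `inc` helper from distributed.utils_test, and steps a worker through all three checkpoints in order:

async def drive_blocked_execute(c, s):
    async with BlockedExecute(s.address) as a:
        fut = c.submit(inc, 1, key="x", workers=[a.address])
        await a.in_execute.wait()  # execute() has been entered and is now parked
        # ... perturb the cluster here (cancel, resubmit, retire, ...) ...
        a.block_execute.set()  # let the preamble continue
        await a.in_deserialize_task.wait()  # run_spec deserialization checkpoint
        a.block_deserialize_task.set()
        await a.in_execute_exit.wait()  # execute() returned; its callback is still blocked
        a.block_execute_exit.set()  # let the worker state machine process the result
        assert await fut == 2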
41 changes: 21 additions & 20 deletions distributed/worker.py
@@ -108,7 +108,6 @@
from distributed.worker_state_machine import (
    NO_VALUE,
    AcquireReplicasEvent,
    AlreadyCancelledEvent,
    BaseWorker,
    CancelComputeEvent,
    ComputeTaskEvent,
@@ -1847,9 +1846,7 @@ def _(**kwargs):
        return _

    @fail_hard
    def _handle_stimulus_from_task(
        self, task: asyncio.Task[StateMachineEvent | None]
    ) -> None:
    def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None:
        """Override BaseWorker method for added validation

        See also
@@ -1968,15 +1965,22 @@ async def gather_dep(
        total_nbytes: int,
        *,
        stimulus_id: str,
    ) -> StateMachineEvent | None:
    ) -> StateMachineEvent:
        """Implements BaseWorker abstract method

        See also
        --------
        distributed.worker_state_machine.BaseWorker.gather_dep
        """
        if self.status not in WORKER_ANY_RUNNING:
            return None
            # This is only for the sake of coherence of the WorkerState;
            # it should never actually reach the scheduler.
            return GatherDepFailureEvent.from_exception(
                RuntimeError("Worker is shutting down"),
                worker=worker,
                total_nbytes=total_nbytes,
                stimulus_id=f"worker-closing-{time()}",
            )

        try:
            self.state.log.append(
@@ -2044,7 +2048,7 @@ async def gather_dep(
stimulus_id=f"gather-dep-failure-{time()}",
)

async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent:
"""Wait some time, then take a peer worker out of busy state.
Implements BaseWorker abstract method.
@@ -2137,24 +2141,21 @@ async def _maybe_deserialize_task(
        return function, args, kwargs

    @fail_hard
    async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
    async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
        """Execute a task. Implements BaseWorker abstract method.

        See also
        --------
        distributed.worker_state_machine.BaseWorker.execute
        """
        if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
            return None
        ts = self.state.tasks.get(key)
        if not ts:
            return None
        if ts.state == "cancelled":
            logger.debug(
                "Trying to execute task %s which is not in executing state anymore",
                ts,
            )
            return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id)
        if self.status not in WORKER_ANY_RUNNING:
            # This is just for internal coherence of the WorkerState; the reschedule
            # message should not ever reach the Scheduler.
            # It is still OK if it does though.
            return RescheduleEvent(key=key, stimulus_id=f"worker-closing-{time()}")

        # The key *must* be in the worker state thanks to the cancelled state
        ts = self.state.tasks[key]

        try:
            function, args, kwargs = await self._maybe_deserialize_task(ts)
@@ -2169,7 +2170,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
        try:
            if self.state.validate:
                assert not ts.waiting_for_data
                assert ts.state == "executing", ts.state
                assert ts.state in ("executing", "cancelled", "resumed"), ts
                assert ts.run_spec is not None

            args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)
30 changes: 6 additions & 24 deletions distributed/worker_state_machine.py
@@ -917,12 +917,6 @@ class CancelComputeEvent(StateMachineEvent):
    key: str


@dataclass
class AlreadyCancelledEvent(StateMachineEvent):
    __slots__ = ("key",)
    key: str


# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
@dataclass
class RescheduleEvent(StateMachineEvent):
@@ -2931,16 +2925,6 @@ def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs:
        assert not ts.dependents
        return {ts: "released"}, []

    @_handle_event.register
    def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs:
        """Task is already cancelled by the time execute() runs"""
        # key *must* be still in tasks. Releasing it directly is forbidden
        # without going through cancelled
        ts = self.tasks.get(ev.key)
        assert ts, self.story(ev.key)
        ts.done = True
        return {ts: "released"}, []

    @_handle_event.register
    def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs:
        """Task completed successfully"""
@@ -3332,18 +3316,16 @@ def __init__(self, state: WorkerState):
        self.state = state
        self._async_instructions = set()

    def _handle_stimulus_from_task(
        self, task: asyncio.Task[StateMachineEvent | None]
    ) -> None:
    def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None:
        """An asynchronous instruction just completed; process the returned stimulus."""
        self._async_instructions.remove(task)
        try:
            # This *should* never raise any other exceptions
            stim = task.result()
        except asyncio.CancelledError:
            # This should exclusively happen in Worker.close()
            return
        if stim:
            self.handle_stimulus(stim)
        self.handle_stimulus(stim)

    def handle_stimulus(self, *stims: StateMachineEvent) -> None:
        """Forward one or more external stimuli to :meth:`WorkerState.handle_stimulus`
@@ -3432,7 +3414,7 @@ async def gather_dep(
        total_nbytes: int,
        *,
        stimulus_id: str,
    ) -> StateMachineEvent | None:
    ) -> StateMachineEvent:
        """Gather dependencies for a task from a worker who has them

        Parameters
@@ -3449,12 +3431,12 @@
        ...

    @abc.abstractmethod
    async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
    async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
        """Execute a task"""
        ...

    @abc.abstractmethod
    async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
    async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent:
        """Wait some time, then take a peer worker out of busy state"""
        ...

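To make the signature change concrete: the asynchronous instructions of BaseWorker (execute, gather_dep, retry_busy_worker_later) now always return a StateMachineEvent rather than StateMachineEvent | None, which is why _handle_stimulus_from_task above dropped its `if stim:` guard. A minimal sketch of a conforming early exit, mirroring the RescheduleEvent pattern the commit uses in Worker.execute (illustration only, not part of the diff):

from distributed.worker_state_machine import RescheduleEvent, StateMachineEvent


async def execute_stub(key: str, *, stimulus_id: str) -> StateMachineEvent:
    # Even when declining to do any real work, return a concrete event instead of
    # None so that the state machine always receives a stimulus to process.
    return RescheduleEvent(key=key, stimulus_id=stimulus_id)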
