Harden preamble of Worker.execute against race conditions #6878

Merged: 1 commit, Aug 12, 2022
65 changes: 65 additions & 0 deletions distributed/tests/test_cancelled_state.py
@@ -8,6 +8,7 @@
from distributed import Event, Lock, Worker
from distributed.client import wait
from distributed.utils_test import (
BlockedExecute,
BlockedGatherDep,
BlockedGetData,
_LockedCommPool,
@@ -822,3 +823,67 @@ def test_workerstate_resumed_waiting_to_flight(ws):
GatherDep(worker=ws2, to_gather={"x"}, stimulus_id="s1", total_nbytes=1),
]
assert ws.tasks["x"].state == "flight"


@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"])
@pytest.mark.parametrize("resume_inside_critical_section", [False, True])
@pytest.mark.parametrize("resumed_status", ["executing", "resumed"])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_execute_preamble_early_cancel(
c, s, b, critical_section, resume_inside_critical_section, resumed_status
):
"""Test multiple race conditions in the preamble of Worker.execute(), which used to
cause a task to remain permanently in resumed state or to crash the worker through
`fail_hard` in case of very tight timings when resuming a task.

See also
--------
https://github.com/dask/distributed/issues/6869
https://github.com/dask/dask/issues/9330
test_worker.py::test_execute_preamble_abort_retirement
"""
async with BlockedExecute(s.address, validate=True) as a:
if critical_section == "execute":
in_ev = a.in_execute
block_ev = a.block_execute
a.block_deserialize_task.set()
else:
assert critical_section == "deserialize_task"
in_ev = a.in_deserialize_task
block_ev = a.block_deserialize_task
a.block_execute.set()

async def resume():
if resumed_status == "executing":
x = c.submit(inc, 1, key="x", workers=[a.address])
await wait_for_state("x", "executing", a)
return x, 2
else:
assert resumed_status == "resumed"
x = c.submit(inc, 1, key="x", workers=[b.address])
y = c.submit(inc, x, key="y", workers=[a.address])
await wait_for_state("x", "resumed", a)
return y, 3

x = c.submit(inc, 1, key="x", workers=[a.address])
await in_ev.wait()

x.release()
await wait_for_state("x", "cancelled", a)

if resume_inside_critical_section:
fut, expect = await resume()

# Unblock Worker.execute. At the moment of writing this test, the method
# would detect the cancelled status and perform an early exit.
block_ev.set()
await a.in_execute_exit.wait()

if not resume_inside_critical_section:
fut, expect = await resume()

# Finally let the done_callback of Worker.execute run
a.block_execute_exit.set()

# Test that x does not get stuck.
assert await fut == expect
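
The shape of this test, parking a coroutine at a known checkpoint, changing the task's state from the outside, and then releasing it, can be reduced to a plain-asyncio toy model (illustration only, not dask code; all names are made up):

import asyncio


async def execute_like(in_ev: asyncio.Event, block_ev: asyncio.Event, tasks: dict) -> str:
    # Checkpoint: announce that we are inside the "preamble", then park here.
    in_ev.set()
    await block_ev.wait()
    # State may have changed while we were parked, so re-check it after the
    # await instead of trusting whatever we saw before it.
    return tasks["x"]


async def main() -> None:
    in_ev, block_ev = asyncio.Event(), asyncio.Event()
    tasks = {"x": "executing"}
    task = asyncio.create_task(execute_like(in_ev, block_ev, tasks))
    await in_ev.wait()
    tasks["x"] = "cancelled"  # the race: state flips while execute_like is parked
    block_ev.set()
    assert await task == "cancelled"


asyncio.run(main())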
55 changes: 55 additions & 0 deletions distributed/tests/test_worker.py
@@ -49,6 +49,7 @@
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils_test import (
BlockedExecute,
BlockedGatherDep,
BlockedGetData,
TaskStateMetadataPlugin,
@@ -3500,3 +3501,57 @@ def teardown(self, worker):
async with Worker(s.address) as worker:
assert await c.submit(inc, 1) == 2
assert worker.plugins[InitWorkerNewThread.name].setup_status is Status.running


@gen_cluster(
client=True,
nthreads=[],
config={
# This is just to make Scheduler.retire_worker more reactive to changes
"distributed.scheduler.active-memory-manager.start": True,
"distributed.scheduler.active-memory-manager.interval": "50ms",
},
)
async def test_execute_preamble_abort_retirement(c, s):
"""Test race condition in the preamble of Worker.execute(), which used to cause a
task to remain permanently in executing state in case of very tight timings when
exiting the closing_gracefully status.

See also
--------
https://github.com/dask/distributed/issues/6867
test_cancelled_state.py::test_execute_preamble_early_cancel
"""
async with BlockedExecute(s.address) as a:
await c.wait_for_workers(1)
a.block_deserialize_task.set() # Uninteresting in this test

x = await c.scatter({"x": 1}, workers=[a.address])
y = c.submit(inc, 1, key="y", workers=[a.address])
await a.in_execute.wait()

async with BlockedGatherDep(s.address) as b:
await c.wait_for_workers(2)
retire_fut = asyncio.create_task(c.retire_workers([a.address]))
while a.status != Status.closing_gracefully:
await asyncio.sleep(0.01)
# The Active Memory Manager will send to b the message
# {op: acquire-replicas, who_has: {x: [a.address]}}
await b.in_gather_dep.wait()

# Run Worker.execute. At the moment of writing this test, the method would
# detect the closing_gracefully status and perform an early exit.
a.block_execute.set()
await a.in_execute_exit.wait()

# b has shut down. There's nowhere to replicate x to anymore, so retire_workers
# will give up and reinstate a to running status.
assert await retire_fut == {}
while a.status != Status.running:
await asyncio.sleep(0.01)

# Finally let the done_callback of Worker.execute run
a.block_execute_exit.set()

# Test that y does not get stuck.
assert await y == 2
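
The two config overrides above enable the Active Memory Manager and shorten its poll interval, so that retire_workers reacts within the test's timings. Outside of gen_cluster the same keys could be set with dask.config; a sketch mirroring the values used above:

import dask

dask.config.set(
    {
        "distributed.scheduler.active-memory-manager.start": True,
        "distributed.scheduler.active-memory-manager.interval": "50ms",
    }
)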
65 changes: 65 additions & 0 deletions distributed/utils_test.py
@@ -2154,6 +2154,7 @@ async def test1(s, a, b):
See also
--------
BlockedGetData
BlockedExecute
"""

def __init__(self, *args, **kwargs):
@@ -2175,6 +2176,7 @@ class BlockedGetData(Worker):
See also
--------
BlockedGatherDep
BlockedExecute
"""

def __init__(self, *args, **kwargs):
@@ -2188,6 +2190,69 @@ async def get_data(self, comm, *args, **kwargs):
return await super().get_data(comm, *args, **kwargs)


class BlockedExecute(Worker):
"""A Worker that sets event `in_execute` the first time it enters the execute
method and then does not proceed, thus leaving the task in executing state
indefinitely, until the test sets `block_execute`.

After that, the worker sets `in_deserialize_task` to simulate the moment when a
large run_spec is being deserialized in a separate thread. The worker will block
again until the test sets `block_deserialize_task`.

Finally, the worker sets `in_execute_exit` when execute() terminates, but before the
worker state has processed its exit callback. The worker will block one last time
until the test sets `block_execute_exit`.

Note
----
In the vast majority of the test cases, it is simpler and more readable to just
submit to a regular Worker a task that blocks on a distributed.Event:

.. code-block:: python

def f(in_task, block_task):
in_task.set()
block_task.wait()

in_task = distributed.Event()
block_task = distributed.Event()
fut = c.submit(f, in_task, block_task)
await in_task.wait()
await block_task.set()

See also
--------
BlockedGatherDep
BlockedGetData
"""

def __init__(self, *args, **kwargs):
self.in_execute = asyncio.Event()
self.block_execute = asyncio.Event()
self.in_deserialize_task = asyncio.Event()
self.block_deserialize_task = asyncio.Event()
self.in_execute_exit = asyncio.Event()
self.block_execute_exit = asyncio.Event()

super().__init__(*args, **kwargs)

async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
self.in_execute.set()
await self.block_execute.wait()
try:
return await super().execute(key, stimulus_id=stimulus_id)
finally:
self.in_execute_exit.set()
await self.block_execute_exit.wait()

async def _maybe_deserialize_task(
self, ts: WorkerTaskState
) -> tuple[Callable, tuple, dict[str, Any]]:
self.in_deserialize_task.set()
await self.block_deserialize_task.wait()
return await super()._maybe_deserialize_task(ts)


@contextmanager
def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]:
"""Prevent any task from transitioning from fetch to flight on the worker while
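
For reference, a hypothetical sketch (not part of this PR) of how a test can walk a single task through all three BlockedExecute checkpoints in order; the test name, key and expected value are made up:

from distributed.utils_test import BlockedExecute, gen_cluster, inc


@gen_cluster(client=True, nthreads=[])
async def test_blocked_execute_walkthrough(c, s):
    async with BlockedExecute(s.address) as a:
        await c.wait_for_workers(1)
        fut = c.submit(inc, 1, key="x", workers=[a.address])

        await a.in_execute.wait()           # the task entered Worker.execute()
        a.block_execute.set()               # let the preamble continue

        await a.in_deserialize_task.wait()  # about to deserialize the run_spec
        a.block_deserialize_task.set()      # let deserialization continue

        await a.in_execute_exit.wait()      # execute() returned; exit callback pending
        a.block_execute_exit.set()          # let the done_callback process the result

        assert await fut == 2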
41 changes: 21 additions & 20 deletions distributed/worker.py
@@ -108,7 +108,6 @@
from distributed.worker_state_machine import (
NO_VALUE,
AcquireReplicasEvent,
AlreadyCancelledEvent,
BaseWorker,
CancelComputeEvent,
ComputeTaskEvent,
@@ -1847,9 +1846,7 @@ def _(**kwargs):
return _

@fail_hard
def _handle_stimulus_from_task(
self, task: asyncio.Task[StateMachineEvent | None]
) -> None:
def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None:
"""Override BaseWorker method for added validation

See also
@@ -1968,15 +1965,22 @@ async def gather_dep(
total_nbytes: int,
*,
stimulus_id: str,
) -> StateMachineEvent | None:
) -> StateMachineEvent:
"""Implements BaseWorker abstract method

See also
--------
distributed.worker_state_machine.BaseWorker.gather_dep
"""
if self.status not in WORKER_ANY_RUNNING:
return None
# This is only for the sake of coherence of the WorkerState;
# it should never actually reach the scheduler.
return GatherDepFailureEvent.from_exception(
RuntimeError("Worker is shutting down"),
worker=worker,
total_nbytes=total_nbytes,
stimulus_id=f"worker-closing-{time()}",
)

try:
self.state.log.append(
@@ -2044,7 +2048,7 @@ async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
stimulus_id=f"gather-dep-failure-{time()}",
)

async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent:
"""Wait some time, then take a peer worker out of busy state.
Implements BaseWorker abstract method.

@@ -2137,24 +2141,21 @@ async def _maybe_deserialize_task(
return function, args, kwargs

@fail_hard
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
"""Execute a task. Implements BaseWorker abstract method.

See also
--------
distributed.worker_state_machine.BaseWorker.execute
"""
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
return None
ts = self.state.tasks.get(key)
if not ts:
return None
if ts.state == "cancelled":
logger.debug(
"Trying to execute task %s which is not in executing state anymore",
ts,
)
return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id)
Review comment by @crusaderky (Collaborator, Author), Aug 11, 2022:
The two new tests remain green if, instead of removing this block completely, you just replace the exit event with

return RescheduleEvent(key=key, stimulus_id=f"already-cancelled-{time()}")

They also remain green if you copy-paste the same block after the call to _maybe_deserialize_task.
In both cases, the reschedule event reaches the scheduler and behaves as expected.

if self.status not in WORKER_ANY_RUNNING:
Review comment by @crusaderky (Collaborator, Author): align execute and gather_dep

# This is just for internal coherence of the WorkerState; the reschedule
# message should not ever reach the Scheduler.
# It is still OK if it does though.
return RescheduleEvent(key=key, stimulus_id=f"worker-closing-{time()}")
Review comment by @crusaderky (Collaborator, Author), Aug 11, 2022:
The two new tests remain green if you revert the first line of this block:

if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
    return RescheduleEvent(key=key, stimulus_id=f"worker-closing-{time()}")

In this case, the reschedule event actually reaches the scheduler.

Review comment by a Member:
I'm concerned that this causes the scheduler to reschedule a task twice. First, because the worker left, and second, because the reschedule event arrives.
Why is the previous version of returning None not sufficient?

Edit:

if worker and ts.processing_on.address != worker:
    return

The scheduler verifies the reschedule event. +1
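
The guard quoted above means that a stale reschedule request is simply ignored unless it comes from the worker currently processing the task. A toy model of that check (hypothetical, not the actual Scheduler code):

def should_reschedule(processing_on: dict, key: str, worker: str) -> bool:
    # Ignore reschedule requests from workers that no longer hold the task.
    return processing_on.get(key) == worker


assert should_reschedule({"x": "tcp://a:1234"}, "x", "tcp://a:1234")
assert not should_reschedule({"x": "tcp://a:1234"}, "x", "tcp://b:5678")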


# The key *must* be in the worker state thanks to the cancelled state
ts = self.state.tasks[key]

try:
function, args, kwargs = await self._maybe_deserialize_task(ts)
@@ -2169,7 +2170,7 @@
try:
if self.state.validate:
assert not ts.waiting_for_data
assert ts.state == "executing", ts.state
assert ts.state in ("executing", "cancelled", "resumed"), ts
assert ts.run_spec is not None

args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs)
30 changes: 6 additions & 24 deletions distributed/worker_state_machine.py
@@ -917,12 +917,6 @@ class CancelComputeEvent(StateMachineEvent):
key: str


@dataclass
class AlreadyCancelledEvent(StateMachineEvent):
__slots__ = ("key",)
key: str


# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
@dataclass
class RescheduleEvent(StateMachineEvent):
@@ -2931,16 +2925,6 @@ def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs:
assert not ts.dependents
return {ts: "released"}, []

@_handle_event.register
def _handle_already_cancelled(self, ev: AlreadyCancelledEvent) -> RecsInstrs:
"""Task is already cancelled by the time execute() runs"""
# key *must* be still in tasks. Releasing it directly is forbidden
# without going through cancelled
ts = self.tasks.get(ev.key)
assert ts, self.story(ev.key)
ts.done = True
return {ts: "released"}, []

@_handle_event.register
def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs:
"""Task completed successfully"""
@@ -3332,18 +3316,16 @@ def __init__(self, state: WorkerState):
self.state = state
self._async_instructions = set()

def _handle_stimulus_from_task(
self, task: asyncio.Task[StateMachineEvent | None]
) -> None:
def _handle_stimulus_from_task(self, task: asyncio.Task[StateMachineEvent]) -> None:
"""An asynchronous instruction just completed; process the returned stimulus."""
self._async_instructions.remove(task)
try:
# This *should* never raise any other exceptions
stim = task.result()
except asyncio.CancelledError:
# This should exclusively happen in Worker.close()
return
if stim:
self.handle_stimulus(stim)
self.handle_stimulus(stim)

def handle_stimulus(self, *stims: StateMachineEvent) -> None:
"""Forward one or more external stimuli to :meth:`WorkerState.handle_stimulus`
@@ -3432,7 +3414,7 @@ async def gather_dep(
total_nbytes: int,
*,
stimulus_id: str,
) -> StateMachineEvent | None:
) -> StateMachineEvent:
"""Gather dependencies for a task from a worker who has them

Parameters
@@ -3449,12 +3431,12 @@
...

@abc.abstractmethod
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
"""Execute a task"""
...

@abc.abstractmethod
async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent:
"""Wait some time, then take a peer worker out of busy state"""
...
