Skip to content

Commit

Permalink
Don't heartbeat while Worker is closing (#6543)
Browse files Browse the repository at this point in the history
  • Loading branch information
gjoseph92 authored Jun 9, 2022
1 parent 9b8172b commit 879fb89
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 43 deletions.
41 changes: 41 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
import dask.config

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
from distributed.batched import BatchedSend
from distributed.comm.core import connect
from distributed.compatibility import WINDOWS
from distributed.core import Server, Status, rpc
from distributed.metrics import time
from distributed.tests.test_batched import EchoServer
from distributed.utils import mp_context
from distributed.utils_test import (
_LockedCommPool,
Expand All @@ -27,6 +30,7 @@
check_process_leak,
cluster,
dump_cluster_state,
freeze_batched_send,
gen_cluster,
gen_test,
inc,
Expand Down Expand Up @@ -781,3 +785,40 @@ async def test(s):
with pytest.raises(CustomError):
test()
assert test_done


@gen_test()
async def test_freeze_batched_send():
async with EchoServer() as e:
comm = await connect(e.address)
b = BatchedSend(interval=0)
b.start(comm)

b.send("hello")
assert await comm.read() == ("hello",)

with freeze_batched_send(b) as locked_comm:
b.send("foo")
b.send("bar")

# Sent messages are available on the write queue
msg = await locked_comm.write_queue.get()
assert msg == (comm.peer_address, ["foo", "bar"])

# Sent messages will not reach the echo server
await asyncio.sleep(0.01)
assert e.count == 1

# Now we let messages send to the echo server
locked_comm.write_event.set()
assert await comm.read() == ("foo", "bar")
assert e.count == 2

locked_comm.write_event.clear()
b.send("baz")
await asyncio.sleep(0.01)
assert e.count == 2

assert b.comm is comm
assert await comm.read() == ("baz",)
assert e.count == 3
59 changes: 19 additions & 40 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
captured_logger,
dec,
div,
freeze_batched_send,
gen_cluster,
gen_test,
inc,
Expand Down Expand Up @@ -1751,59 +1752,37 @@ async def test_heartbeat_missing_real_cluster(s, a):
# However, `Scheduler.remove_worker` and `Worker.close` both currently leave things
# in degenerate, half-closed states while they're running (and yielding control
# via `await`).
# When https://github.com/dask/distributed/issues/6390 is fixed, this should no
# longer be possible.

# Currently this is easy because of https://github.com/dask/distributed/issues/6354.
# But even with that fixed, it may still be possible, since `Worker.close`
# could take an arbitrarily long time, and things can keep running
# while it's closing.
assumption_msg = "Test assumptions have changed. Race condition may have been fixed; this test may be removable."

class BlockCloseExtension:
def __init__(self) -> None:
self.close_reached = asyncio.Event()
self.unblock_close = asyncio.Event()

async def close(self):
self.close_reached.set()
await self.unblock_close.wait()

# `Worker.close` awaits extensions' `close` methods midway though.
# During this `await`, the Worker is in state `closing`, but the heartbeat
# `PeriodicCallback` is still running. We will intentionally pause
# the worker here to simulate the timing of a heartbeat happing in this
# degenerate state.
a.extensions["block-close"] = block_close = BlockCloseExtension()

with captured_logger(
"distributed.worker", level=logging.WARNING
) as wlogger, captured_logger(
"distributed.scheduler", level=logging.WARNING
) as slogger:
await s.remove_worker(a.address, stimulus_id="foo")
assert not s.workers
with freeze_batched_send(s.stream_comms[a.address]):
await s.remove_worker(a.address, stimulus_id="foo")
assert not s.workers

# Wait until the close signal reaches the worker and it starts shutting down.
await block_close.close_reached.wait()
assert a.status == Status.closing, assumption_msg
assert a.periodic_callbacks["heartbeat"].is_running(), assumption_msg
# The heartbeat PeriodicCallback is still running, so one _could_ fire
# while `Worker.close` has yielded control. We simulate that explicitly.
# The scheduler has removed the worker state, but the close message has
# not reached the worker yet.
assert a.status == Status.running, assumption_msg
assert a.periodic_callbacks["heartbeat"].is_running(), assumption_msg

# Because `hearbeat` will `await self.close`, which is blocking on our
# extension, we have to run it concurrently.
hbt = asyncio.create_task(a.heartbeat())
# The heartbeat PeriodicCallback is still running, so one _could_ fire
# before the `op: close` message reaches the worker. We simulate that explicitly.
await a.heartbeat()

# Worker was already closing, so the second `.close()` will be idempotent.
# Best we can test for is this log message.
while "Scheduler was unaware of this worker" not in wlogger.getvalue():
await asyncio.sleep(0.01)
# The heartbeat receives a `status: missing` from the scheduler, so it
# closes the worker. Heartbeats aren't sent over batched comms, so
# `freeze_batched_send` doesn't affect them.
assert a.status == Status.closed

assert "Received heartbeat from unregistered worker" in slogger.getvalue()
assert not s.workers
assert "Scheduler was unaware of this worker" in wlogger.getvalue()
assert "Received heartbeat from unregistered worker" in slogger.getvalue()

block_close.unblock_close.set()
await hbt
await a.finished()
assert not s.workers


Expand Down
31 changes: 31 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from distributed import Scheduler, system
from distributed import versions as version_module
from distributed.batched import BatchedSend
from distributed.client import Client, _global_clients, default_client
from distributed.comm import Comm
from distributed.comm.tcp import TCP
Expand Down Expand Up @@ -2345,3 +2346,33 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
if jump_start:
w.status = Status.paused
w.status = Status.running


@contextmanager
def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:
"""
Contextmanager blocking writes to a `BatchedSend` from sending over the network.
The returned `LockedComm` object can be used for control flow and inspection via its
``read_event``, ``read_queue``, ``write_event``, and ``write_queue`` attributes.
On exit, any writes that were blocked are un-blocked, and the original comm of the
`BatchedSend` is restored.
"""
assert not bcomm.closed()
assert bcomm.comm
assert not bcomm.comm.closed()
orig_comm = bcomm.comm

write_event = asyncio.Event()
write_queue: asyncio.Queue = asyncio.Queue()

bcomm.comm = locked_comm = LockedComm(
orig_comm, None, None, write_event, write_queue
)

try:
yield locked_comm
finally:
write_event.set()
bcomm.comm = orig_comm
8 changes: 5 additions & 3 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,11 @@ async def close(
logger.info("Not waiting on executor to close")
self.status = Status.closing

# Stop callbacks before giving up control in any `await`.
# We don't want to heartbeat while closing.
for pc in self.periodic_callbacks.values():
pc.stop()

if self._async_instructions:
for task in self._async_instructions:
task.cancel()
Expand Down Expand Up @@ -1561,9 +1566,6 @@ async def close(

await asyncio.gather(*(td for td in teardowns if isawaitable(td)))

for pc in self.periodic_callbacks.values():
pc.stop()

if self._client:
# If this worker is the last one alive, clean up the worker
# initialized clients
Expand Down

0 comments on commit 879fb89

Please sign in to comment.