From 56b2435e55d70007659f24d91dcd947eea245a64 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 18:47:24 -0600
Subject: [PATCH 1/3] Add `freeze_batched_send` helper

---
 distributed/tests/test_utils_test.py | 41 ++++++++++++++++++++++++++++
 distributed/utils_test.py            | 31 +++++++++++++++++++++
 2 files changed, 72 insertions(+)

diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 62ccc544943..745ae59caa9 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -15,9 +15,12 @@
 import dask.config
 
 from distributed import Client, Nanny, Scheduler, Worker, config, default_client
+from distributed.batched import BatchedSend
+from distributed.comm.core import connect
 from distributed.compatibility import WINDOWS
 from distributed.core import Server, Status, rpc
 from distributed.metrics import time
+from distributed.tests.test_batched import EchoServer
 from distributed.utils import mp_context
 from distributed.utils_test import (
     _LockedCommPool,
@@ -27,6 +30,7 @@
     check_process_leak,
     cluster,
     dump_cluster_state,
+    freeze_batched_send,
     gen_cluster,
     gen_test,
     inc,
@@ -781,3 +785,40 @@ async def test(s):
     with pytest.raises(CustomError):
         test()
     assert test_done
+
+
+@gen_test()
+async def test_freeze_batched_send():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        b = BatchedSend(interval=0)
+        b.start(comm)
+
+        b.send("hello")
+        assert await comm.read() == ("hello",)
+
+        with freeze_batched_send(b) as locked_comm:
+            b.send("foo")
+            b.send("bar")
+
+            # Sent messages are available on the write queue
+            msg = await locked_comm.write_queue.get()
+            assert msg == (comm.peer_address, ["foo", "bar"])
+
+            # Sent messages will not reach the echo server
+            await asyncio.sleep(0.01)
+            assert e.count == 1
+
+            # Now we let messages send to the echo server
+            locked_comm.write_event.set()
+            assert await comm.read() == ("foo", "bar")
+            assert e.count == 2
+
+            locked_comm.write_event.clear()
+            b.send("baz")
+            await asyncio.sleep(0.01)
+            assert e.count == 2
+
+        assert b.comm is comm
+        assert await comm.read() == ("baz",)
+        assert e.count == 3
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index a10fbcd7870..644d3339291 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -37,6 +37,7 @@
 
 from distributed import Scheduler, system
 from distributed import versions as version_module
+from distributed.batched import BatchedSend
 from distributed.client import Client, _global_clients, default_client
 from distributed.comm import Comm
 from distributed.comm.tcp import TCP
@@ -2345,3 +2346,33 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
         if jump_start:
             w.status = Status.paused
             w.status = Status.running
+
+
+@contextmanager
+def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]:
+    """
+    Context manager blocking writes to a `BatchedSend` from going over the network.
+
+    The returned `LockedComm` object can be used for control flow and inspection via its
+    ``read_event``, ``read_queue``, ``write_event``, and ``write_queue`` attributes.
+
+    On exit, any writes that were blocked are unblocked, and the original comm of the
+    `BatchedSend` is restored.
+    """
+    assert not bcomm.closed()
+    assert bcomm.comm
+    assert not bcomm.comm.closed()
+    orig_comm = bcomm.comm
+
+    write_event = asyncio.Event()
+    write_queue: asyncio.Queue = asyncio.Queue()
+
+    bcomm.comm = locked_comm = LockedComm(
+        orig_comm, None, None, write_event, write_queue
+    )
+
+    try:
+        yield locked_comm
+    finally:
+        write_event.set()
+        bcomm.comm = orig_comm
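
A note on the mechanics: `freeze_batched_send` swaps the `BatchedSend`'s underlying
comm for a `LockedComm` whose writes park each message on `write_queue` and then wait
on `write_event` before touching the network; on exit the helper sets `write_event`,
so anything still parked is flushed before the original comm is restored. A minimal,
self-contained sketch of that gating pattern (`GatedWriter` and `send` are
illustrative stand-ins, not the real `LockedComm` API):

    import asyncio


    class GatedWriter:
        """Illustrative stand-in for `LockedComm`: parks writes until an event is set."""

        def __init__(self, send, write_event, write_queue):
            self._send = send  # coroutine function doing the real network write
            self.write_event = write_event
            self.write_queue = write_queue

        async def write(self, msg):
            # Expose the message for inspection, then wait for permission to send.
            await self.write_queue.put(msg)
            await self.write_event.wait()
            await self._send(msg)


    async def main():
        sent = []

        async def send(msg):
            sent.append(msg)

        event, queue = asyncio.Event(), asyncio.Queue()
        gated = GatedWriter(send, event, queue)

        task = asyncio.create_task(gated.write("hello"))
        assert await queue.get() == "hello"  # the message is observable...
        assert not sent  # ...but has not gone "over the network"

        event.set()  # release the gate, as `freeze_batched_send` does on exit
        await task
        assert sent == ["hello"]


    asyncio.run(main())
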
From dd779cd7d57375d8af173bbb3ba2594a6102461d Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 18:50:56 -0600
Subject: [PATCH 2/3] Update `test_heartbeat_missing_real_cluster`

We're going to fix the race condition we were exploiting here, but a
simpler one still exists.
---
 distributed/tests/test_worker.py | 59 ++++++++++----------------------
 1 file changed, 19 insertions(+), 40 deletions(-)

diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 1ace95dc605..e4c1cab17d4 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -56,6 +56,7 @@
     captured_logger,
     dec,
     div,
+    freeze_batched_send,
     gen_cluster,
     gen_test,
     inc,
@@ -1751,59 +1752,37 @@ async def test_heartbeat_missing_real_cluster(s, a):
     # However, `Scheduler.remove_worker` and `Worker.close` both currently leave things
     # in degenerate, half-closed states while they're running (and yielding control
     # via `await`).
+    # When https://github.com/dask/distributed/issues/6390 is fixed, this should no
+    # longer be possible.
 
-    # Currently this is easy because of https://github.com/dask/distributed/issues/6354.
-    # But even with that fixed, it may still be possible, since `Worker.close`
-    # could take an arbitrarily long time, and things can keep running
-    # while it's closing.
     assumption_msg = "Test assumptions have changed. Race condition may have been fixed; this test may be removable."
 
-    class BlockCloseExtension:
-        def __init__(self) -> None:
-            self.close_reached = asyncio.Event()
-            self.unblock_close = asyncio.Event()
-
-        async def close(self):
-            self.close_reached.set()
-            await self.unblock_close.wait()
-
-    # `Worker.close` awaits extensions' `close` methods midway though.
-    # During this `await`, the Worker is in state `closing`, but the heartbeat
-    # `PeriodicCallback` is still running. We will intentionally pause
-    # the worker here to simulate the timing of a heartbeat happing in this
-    # degenerate state.
-    a.extensions["block-close"] = block_close = BlockCloseExtension()
-
     with captured_logger(
         "distributed.worker", level=logging.WARNING
     ) as wlogger, captured_logger(
         "distributed.scheduler", level=logging.WARNING
     ) as slogger:
-        await s.remove_worker(a.address, stimulus_id="foo")
-        assert not s.workers
+        with freeze_batched_send(s.stream_comms[a.address]):
+            await s.remove_worker(a.address, stimulus_id="foo")
+            assert not s.workers
 
-        # Wait until the close signal reaches the worker and it starts shutting down.
-        await block_close.close_reached.wait()
-        assert a.status == Status.closing, assumption_msg
-        assert a.periodic_callbacks["heartbeat"].is_running(), assumption_msg
-        # The heartbeat PeriodicCallback is still running, so one _could_ fire
-        # while `Worker.close` has yielded control. We simulate that explicitly.
+            # The scheduler has removed the worker state, but the close message has
+            # not reached the worker yet.
+            assert a.status == Status.running, assumption_msg
+            assert a.periodic_callbacks["heartbeat"].is_running(), assumption_msg
 
-        # Because `hearbeat` will `await self.close`, which is blocking on our
-        # extension, we have to run it concurrently.
-        hbt = asyncio.create_task(a.heartbeat())
+            # The heartbeat PeriodicCallback is still running, so one _could_ fire
+            # before the `op: close` message reaches the worker. We simulate that explicitly.
+            await a.heartbeat()
 
-        # Worker was already closing, so the second `.close()` will be idempotent.
-        # Best we can test for is this log message.
-        while "Scheduler was unaware of this worker" not in wlogger.getvalue():
-            await asyncio.sleep(0.01)
+            # The heartbeat receives a `status: missing` from the scheduler, so it
+            # closes the worker. Heartbeats aren't sent over batched comms, so
+            # `freeze_batched_send` doesn't affect them.
+            assert a.status == Status.closed
 
-        assert "Received heartbeat from unregistered worker" in slogger.getvalue()
-        assert not s.workers
+        assert "Scheduler was unaware of this worker" in wlogger.getvalue()
+        assert "Received heartbeat from unregistered worker" in slogger.getvalue()
 
-        block_close.unblock_close.set()
-        await hbt
-        await a.finished()
     assert not s.workers
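
The essence of the ordering the updated test pins down: the scheduler forgets the
worker first, the `op: close` message is still in flight over the frozen batched
stream, and an in-flight heartbeat is then answered with `status: missing`, which
makes the worker close itself. A toy model of just that ordering (all names here are
illustrative; in distributed these exchanges travel over comms, and heartbeats use
RPC rather than the batched stream, which is why freezing the latter cannot delay
them):

    # Toy model; not the real Scheduler API.
    class ToyScheduler:
        def __init__(self):
            self.workers = {"tcp://worker:1234"}

        def remove_worker(self, addr):
            # State is gone immediately; the close message to the worker
            # is still in flight (here: frozen in transit).
            self.workers.discard(addr)

        def handle_heartbeat(self, addr):
            # Mirrors the `status: missing` reply to unknown workers.
            return {"status": "OK"} if addr in self.workers else {"status": "missing"}


    s = ToyScheduler()
    s.remove_worker("tcp://worker:1234")
    reply = s.handle_heartbeat("tcp://worker:1234")
    assert reply == {"status": "missing"}  # the worker reacts by closing itself
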
From 20d3198f06dc3c7fe27734630596319a0bc65fe7 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Wed, 8 Jun 2022 19:10:06 -0600
Subject: [PATCH 3/3] Stop PCs first thing in `Worker.close`

---
 distributed/worker.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/distributed/worker.py b/distributed/worker.py
index d7598ad182a..184d1a76f5b 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1524,6 +1524,11 @@ async def close(
             logger.info("Not waiting on executor to close")
         self.status = Status.closing
 
+        # Stop callbacks before giving up control in any `await`.
+        # We don't want to heartbeat while closing.
+        for pc in self.periodic_callbacks.values():
+            pc.stop()
+
         if self._async_instructions:
             for task in self._async_instructions:
                 task.cancel()
@@ -1561,9 +1566,6 @@ async def close(
 
         await asyncio.gather(*(td for td in teardowns if isawaitable(td)))
 
-        for pc in self.periodic_callbacks.values():
-            pc.stop()
-
         if self._client:
             # If this worker is the last one alive, clean up the worker
             # initialized clients
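
Why moving the `pc.stop()` loop to the top matters: a tornado `PeriodicCallback`
keeps firing whenever the closing coroutine yields control, so stopping the
callbacks only after `await`ing teardowns leaves a window in which a heartbeat can
run mid-close. A sketch of that difference (plain asyncio plus tornado; the
structure is illustrative, not `Worker.close` itself):

    import asyncio

    from tornado.ioloop import PeriodicCallback


    async def main():
        fired = []

        # Stopping only after an `await`: callbacks keep firing in the window.
        pc = PeriodicCallback(lambda: fired.append("heartbeat"), callback_time=10)
        pc.start()

        async def close_stopping_late():
            await asyncio.sleep(0.05)  # yields control; heartbeats fire here
            pc.stop()

        await close_stopping_late()
        assert fired  # heartbeats snuck in mid-close

        # Stopping before giving up control: the window is closed.
        fired.clear()
        pc2 = PeriodicCallback(lambda: fired.append("heartbeat"), callback_time=10)
        pc2.start()

        async def close_stopping_first():
            pc2.stop()  # stop before any `await`
            await asyncio.sleep(0.05)

        await close_stopping_first()
        assert not fired


    asyncio.run(main())
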