From 3d547b07eafb5c354d04e2b2a288d69134a07ce7 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 26 Oct 2022 15:15:41 +0200 Subject: [PATCH] fix comms test --- distributed/shuffle/_multi_comm.py | 17 +++++----- distributed/shuffle/_shuffle_extension.py | 2 -- distributed/shuffle/tests/test_multi_comm.py | 34 ++++++++++++++++++-- 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/distributed/shuffle/_multi_comm.py b/distributed/shuffle/_multi_comm.py index 75455c67907..b508442cff5 100644 --- a/distributed/shuffle/_multi_comm.py +++ b/distributed/shuffle/_multi_comm.py @@ -169,6 +169,7 @@ async def communicate(self) -> None: task = asyncio.create_task(self._process(address, shards, size)) del shards self._futures.add(task) + task.add_done_callback(self._futures.discard) async def _process(self, address: str, shards: list, size: int) -> None: """Send one message off to a neighboring worker""" @@ -204,7 +205,10 @@ async def _process(self, address: str, shards: list, size: int) -> None: async def flush(self) -> None: """ - We don't expect any more data, wait until everything is flushed through + We don't expect any more data, wait until everything is flushed through. + + Not thread safe. Caller must ensure this is not called concurrently with + put """ if self._exception: await self._communicate_task @@ -213,23 +217,18 @@ async def flush(self) -> None: while self.shards: await asyncio.sleep(0.05) - # FIXME: This needs lock protection to guarantee that shards are indeed - # empty and all associated futures are in self._futures but _process is - # locking as well. Either we'll need a RLock or a second lock await asyncio.gather(*self._futures) self._futures.clear() - assert not self.total_size + if self.total_size: + raise RuntimeError("Received additional input after flushing.") self._done = True await self._communicate_task async def close(self) -> None: - # TODO: Should we flush here? - # TODO: Should this raise an exception if there is one? 
- # TODO: Should finished tasks remove themselves from futures, s.t. we only raise once. We do not raise multiple times, do we? - await self.flush() self._done = True await self._communicate_task + await asyncio.gather(*self._futures) @contextlib.contextmanager def time(self, name: str) -> Iterator[None]: diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py index 72763107b31..e1eb51a3c5e 100644 --- a/distributed/shuffle/_shuffle_extension.py +++ b/distributed/shuffle/_shuffle_extension.py @@ -592,8 +592,6 @@ def get( mapping[part] = worker output_workers.add(worker) self.scheduler.set_restrictions({ts.key: {worker}}) - # ts.worker_restrictions = {worker} # TODO: once cython is - # gone self.worker_for[id] = mapping self.schemas[id] = schema diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py index 90919511701..f53fb8a6d65 100644 --- a/distributed/shuffle/tests/test_multi_comm.py +++ b/distributed/shuffle/tests/test_multi_comm.py @@ -4,6 +4,7 @@ from collections import defaultdict import pytest +from tornado.ioloop import IOLoop from distributed.shuffle._multi_comm import MultiComm from distributed.utils_test import gen_test @@ -16,7 +17,7 @@ async def test_basic(tmp_path): async def send(address, shards): d[address].extend(shards) - mc = MultiComm(send=send) + mc = MultiComm(send=send, loop=IOLoop.current()) mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) @@ -33,7 +34,7 @@ async def test_exceptions(tmp_path): async def send(address, shards): raise Exception(123) - mc = MultiComm(send=send) + mc = MultiComm(send=send, loop=IOLoop.current()) mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]}) while not mc._exception: @@ -44,3 +45,32 @@ async def send(address, shards): with pytest.raises(Exception, match="123"): await mc.flush() + + await mc.close() + + +@gen_test() +async def test_slow_send(tmpdir): + block_send = 
asyncio.Event() + block_send.set() + sending_first = asyncio.Event() + d = defaultdict(list) + + async def send(address, shards): + await block_send.wait() + d[address].extend(shards) + sending_first.set() + + mc = MultiComm(send=send, loop=IOLoop.current()) + mc.max_connections = 1 + mc.put({"x": [b"0"], "y": [b"1"]}) + mc.put({"x": [b"0"], "y": [b"1"]}) + flush_task = asyncio.create_task(mc.flush()) + await sending_first.wait() + block_send.clear() + + with pytest.raises(RuntimeError): + mc.put({"x": [b"2"], "y": [b"2"]}) + await flush_task + + assert all(b"2" not in shard for shard in d["x"])