Skip to content

Commit

Permalink
fix comms test
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 26, 2022
1 parent ffb3d08 commit 3d547b0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
17 changes: 8 additions & 9 deletions distributed/shuffle/_multi_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async def communicate(self) -> None:
task = asyncio.create_task(self._process(address, shards, size))
del shards
self._futures.add(task)
task.add_done_callback(self._futures.discard)

async def _process(self, address: str, shards: list, size: int) -> None:
"""Send one message off to a neighboring worker"""
Expand Down Expand Up @@ -204,7 +205,10 @@ async def _process(self, address: str, shards: list, size: int) -> None:

async def flush(self) -> None:
"""
We don't expect any more data, wait until everything is flushed through
We don't expect any more data, wait until everything is flushed through.
Not thread safe. Caller must ensure this is not called concurrently with
put
"""
if self._exception:
await self._communicate_task
Expand All @@ -213,23 +217,18 @@ async def flush(self) -> None:

while self.shards:
await asyncio.sleep(0.05)
# FIXME: This needs lock protection to guarantee that shards are indeed
# empty and all associated futures are in self._futures but _process is
# locking as well. Either we'll need a RLock or a second lock
await asyncio.gather(*self._futures)
self._futures.clear()

assert not self.total_size
if self.total_size:
raise RuntimeError("Received additional input after flushing.")
self._done = True
await self._communicate_task

async def close(self) -> None:
# TODO: Should we flush here?
# TODO: Should this raise an exception if there is one?
# TODO: Should finished tasks remove themselves from futures, s.t. we only raise once. We do not raise multiple times, do we?
await self.flush()
self._done = True
await self._communicate_task
await asyncio.gather(*self._futures)

@contextlib.contextmanager
def time(self, name: str) -> Iterator[None]:
Expand Down
2 changes: 0 additions & 2 deletions distributed/shuffle/_shuffle_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,6 @@ def get(
mapping[part] = worker
output_workers.add(worker)
self.scheduler.set_restrictions({ts.key: {worker}})
# ts.worker_restrictions = {worker} # TODO: once cython is
# gone

self.worker_for[id] = mapping
self.schemas[id] = schema
Expand Down
34 changes: 32 additions & 2 deletions distributed/shuffle/tests/test_multi_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict

import pytest
from tornado.ioloop import IOLoop

from distributed.shuffle._multi_comm import MultiComm
from distributed.utils_test import gen_test
Expand All @@ -16,7 +17,7 @@ async def test_basic(tmp_path):
async def send(address, shards):
d[address].extend(shards)

mc = MultiComm(send=send)
mc = MultiComm(send=send, loop=IOLoop.current())
mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]})
mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]})

Expand All @@ -33,7 +34,7 @@ async def test_exceptions(tmp_path):
async def send(address, shards):
raise Exception(123)

mc = MultiComm(send=send)
mc = MultiComm(send=send, loop=IOLoop.current())
mc.put({"x": [b"0" * 1000], "y": [b"1" * 500]})

while not mc._exception:
Expand All @@ -44,3 +45,32 @@ async def send(address, shards):

with pytest.raises(Exception, match="123"):
await mc.flush()

await mc.close()


@gen_test()
async def test_slow_send(tmpdir):
block_send = asyncio.Event()
block_send.set()
sending_first = asyncio.Event()
d = defaultdict(list)

async def send(address, shards):
await block_send.wait()
d[address].extend(shards)
sending_first.set()

mc = MultiComm(send=send, loop=IOLoop.current())
mc.max_connections = 1
mc.put({"x": [b"0"], "y": [b"1"]})
mc.put({"x": [b"0"], "y": [b"1"]})
flush_task = asyncio.create_task(mc.flush())
await sending_first.wait()
block_send.clear()

with pytest.raises(RuntimeError):
mc.put({"x": [b"2"], "y": [b"2"]})
await flush_task

assert [b"2" not in shard for shard in d["x"]]

0 comments on commit 3d547b0

Please sign in to comment.