diff --git a/distributed/shuffle/_multi_comm.py b/distributed/shuffle/_multi_comm.py
index b508442cff5..6047e5a8717 100644
--- a/distributed/shuffle/_multi_comm.py
+++ b/distributed/shuffle/_multi_comm.py
@@ -7,7 +7,7 @@
 import weakref
 from collections import defaultdict
 from collections.abc import Iterator
-from typing import Awaitable, Callable, Sequence
+from typing import Any, Awaitable, Callable, Sequence
 
 from tornado.ioloop import IOLoop
 
@@ -74,7 +74,7 @@ def __init__(
         self.total_size = 0
         self.total_moved = 0
         self._wait_on_memory = threading.Condition()
-        self._futures: set[asyncio.Task] = set()
+        self._tasks: set[asyncio.Task] = set()
         self._done = False
         self.diagnostics: dict[str, float] = defaultdict(float)
         self._loop = loop
@@ -119,6 +119,8 @@ def put(self, data: dict[str, Sequence[bytes]]) -> None:
         while MultiComm.total_size > MultiComm.memory_limit:
             with self.time("waiting-on-memory"):
                 with self._wait_on_memory:
+                    if self._exception:
+                        raise self._exception
                     self._wait_on_memory.wait(1)  # Block until memory calms down
 
     async def communicate(self) -> None:
@@ -138,6 +140,7 @@ async def communicate(self) -> None:
 
         while not self._done:
             with self.time("idle"):
+                await self._maybe_raise_exception()
                 if not self.shards:
                     await asyncio.sleep(0.1)
                     continue
@@ -168,8 +171,8 @@ async def communicate(self) -> None:
                 assert shards
                 task = asyncio.create_task(self._process(address, shards, size))
                 del shards
-                self._futures.add(task)
-                task.add_done_callback(self._futures.discard)
+                self._tasks.add(task)
+                task.add_done_callback(self._tasks.discard)
 
     async def _process(self, address: str, shards: list, size: int) -> None:
         """Send one message off to a neighboring worker"""
@@ -178,8 +181,6 @@ async def _process(self, address: str, shards: list, size: int) -> None:
 
         # Consider boosting total_size a bit here to account for duplication
         try:
-            # while (time.time() // 5 % 4) == 0:
-            #     await asyncio.sleep(0.1)
             start = time.time()
             try:
                 with self.time("send"):
@@ -201,7 +202,14 @@ async def _process(self, address: str, shards: list, size: int) -> None:
             MultiComm.total_size -= size
             with self._wait_on_memory:
                 self._wait_on_memory.notify()
-            await self.queue.put(None)
+            self.queue.put_nowait(None)
+
+    async def _maybe_raise_exception(self) -> None:
+        if self._exception:
+            assert self._done
+            await self._communicate_task
+            await asyncio.gather(*self._tasks)
+            raise self._exception
 
     async def flush(self) -> None:
         """
@@ -211,24 +219,34 @@ async def flush(self) -> None:
         put
         """
         if self._exception:
-            await self._communicate_task
-            await asyncio.gather(*self._futures)
-            raise self._exception
+            await self._maybe_raise_exception()
 
         while self.shards:
+            await self._maybe_raise_exception()
             await asyncio.sleep(0.05)
-        await asyncio.gather(*self._futures)
-        self._futures.clear()
+
+        await asyncio.gather(*self._tasks)
+        await self._maybe_raise_exception()
 
         if self.total_size:
             raise RuntimeError("Received additional input after flushing.")
         self._done = True
         await self._communicate_task
 
+    async def __aenter__(self) -> MultiComm:
+        return self
+
+    async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None:
+        await self.close()
+
     async def close(self) -> None:
+        if not self._done:
+            await self.flush()
         self._done = True
         await self._communicate_task
-        await asyncio.gather(*self._futures)
+        await asyncio.gather(*self._tasks)
+        self.shards.clear()
+        self.sizes.clear()
 
     @contextlib.contextmanager
     def time(self, name: str) -> Iterator[None]:
diff --git a/distributed/shuffle/_multi_file.py b/distributed/shuffle/_multi_file.py
index 9fa3c27d300..a92e87280fe 100644
--- a/distributed/shuffle/_multi_file.py
+++ b/distributed/shuffle/_multi_file.py
@@ -167,6 +167,7 @@ async def communicate(self) -> None:
 
         while not self._done:
             with self.time("idle"):
+                await self._maybe_raise_exception()
                 if not self.shards:
                     await asyncio.sleep(0.1)
                     continue
diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py
index f53fb8a6d65..dc5c18f3f64 100644
--- a/distributed/shuffle/tests/test_multi_comm.py
+++ b/distributed/shuffle/tests/test_multi_comm.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import asyncio
+import concurrent.futures
+import math
 from collections import defaultdict
 
 import pytest
@@ -74,3 +76,71 @@ async def send(address, shards):
 
     await flush_task
     assert [b"2" not in shard for shard in d["x"]]
+
+
+def gen_bytes(percentage: float) -> bytes:
+    num_bytes = int(math.floor(percentage * MultiComm.memory_limit))
+    return b"0" * num_bytes
+
+
+@pytest.mark.parametrize("explicit_flush", [True, False])
+@gen_test()
+async def test_concurrent_puts(explicit_flush):
+    d = defaultdict(list)
+
+    async def send(address, shards):
+        d[address].extend(shards)
+
+    frac = 0.1
+    nshards = 10
+    nputs = 20
+    payload = {x: [gen_bytes(frac)] for x in range(nshards)}
+    with concurrent.futures.ThreadPoolExecutor(
+        2, thread_name_prefix="test IOLoop"
+    ) as tpe:
+        async with MultiComm(send=send, loop=IOLoop.current()) as mc:
+            loop = asyncio.get_running_loop()
+            futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]
+
+            await asyncio.gather(*futs)
+            if explicit_flush:
+                await mc.flush()
+
+                assert not mc.shards
+                assert not mc.sizes
+
+    assert not mc.shards
+    assert not mc.sizes
+    assert len(d) == 10
+    assert sum(map(len, d[0])) == len(gen_bytes(frac)) * nputs
+
+
+@gen_test()
+async def test_concurrent_puts_error():
+    d = defaultdict(list)
+
+    counter = 0
+
+    async def send(address, shards):
+        nonlocal counter
+        counter += 1
+        if counter == 5:
+            raise OSError("error during send")
+        d[address].extend(shards)
+
+    frac = 0.1
+    nshards = 10
+    nputs = 20
+    payload = {x: [gen_bytes(frac)] for x in range(nshards)}
+    with concurrent.futures.ThreadPoolExecutor(
+        2, thread_name_prefix="test IOLoop"
+    ) as tpe:
+        async with MultiComm(send=send, loop=IOLoop.current()) as mc:
+            loop = asyncio.get_running_loop()
+            futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]
+
+            with pytest.raises(OSError, match="error during send"):
+                await asyncio.gather(*futs)
+
+    assert not mc.shards
+    assert not mc.sizes
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index db94600b7e3..383431e2834 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -120,7 +120,7 @@ async def test_bad_disk(c, s, a, b):
     # clean_scheduler(s)
 
 
-@pytest.mark.xfail
+@pytest.mark.skip
 @pytest.mark.slow
 @gen_cluster(client=True)
 async def test_crashed_worker(c, s, a, b):
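
Usage sketch (not part of the patch): a minimal illustration of how the async context-manager protocol added above is meant to be driven, mirroring the new tests. The send coroutine, the "worker-1" address, and the payload are made-up stand-ins; an exception raised inside send would re-raise from flush(), or from close() when the async with block exits, via _maybe_raise_exception.

import asyncio

from tornado.ioloop import IOLoop

from distributed.shuffle._multi_comm import MultiComm


async def main() -> None:
    received: list[bytes] = []

    async def send(address: str, shards: list) -> None:
        # Stand-in for the real comm layer: record what would have been sent.
        received.extend(shards)

    # __aenter__ hands back the MultiComm; __aexit__ calls close(), which
    # flushes any buffered shards before tearing the instance down.
    async with MultiComm(send=send, loop=IOLoop.current()) as mc:
        loop = asyncio.get_running_loop()
        # put() may block on the memory limit, so call it off the event loop,
        # as the tests above do with a ThreadPoolExecutor.
        await loop.run_in_executor(None, mc.put, {"worker-1": [b"shard-bytes"]})
        await mc.flush()  # an error raised inside send() would surface here

    assert b"shard-bytes" in received


if __name__ == "__main__":
    asyncio.run(main())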