Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separation of concerns between Shuffle and WorkerShuffle #7195

Closed
wants to merge 13 commits into from
Prev Previous commit
Next Next commit
Ensure multicomm does not deadlock
fjetter committed Oct 28, 2022
commit 62c018b4f63a4ec446eebc40931973510e2d8b69
44 changes: 31 additions & 13 deletions distributed/shuffle/_multi_comm.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
import weakref
from collections import defaultdict
from collections.abc import Iterator
from typing import Awaitable, Callable, Sequence
from typing import Any, Awaitable, Callable, Sequence

from tornado.ioloop import IOLoop

@@ -74,7 +74,7 @@ def __init__(
self.total_size = 0
self.total_moved = 0
self._wait_on_memory = threading.Condition()
self._futures: set[asyncio.Task] = set()
self._tasks: set[asyncio.Task] = set()
self._done = False
self.diagnostics: dict[str, float] = defaultdict(float)
self._loop = loop
@@ -119,6 +119,8 @@ def put(self, data: dict[str, Sequence[bytes]]) -> None:
while MultiComm.total_size > MultiComm.memory_limit:
with self.time("waiting-on-memory"):
with self._wait_on_memory:
if self._exception:
raise self._exception
self._wait_on_memory.wait(1) # Block until memory calms down

async def communicate(self) -> None:
@@ -138,6 +140,7 @@ async def communicate(self) -> None:

while not self._done:
with self.time("idle"):
await self._maybe_raise_exception()
if not self.shards:
await asyncio.sleep(0.1)
continue
@@ -168,8 +171,8 @@ async def communicate(self) -> None:
assert shards
task = asyncio.create_task(self._process(address, shards, size))
del shards
self._futures.add(task)
task.add_done_callback(self._futures.discard)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)

async def _process(self, address: str, shards: list, size: int) -> None:
"""Send one message off to a neighboring worker"""
@@ -178,8 +181,6 @@ async def _process(self, address: str, shards: list, size: int) -> None:
# Consider boosting total_size a bit here to account for duplication

try:
# while (time.time() // 5 % 4) == 0:
# await asyncio.sleep(0.1)
start = time.time()
try:
with self.time("send"):
@@ -201,7 +202,14 @@ async def _process(self, address: str, shards: list, size: int) -> None:
MultiComm.total_size -= size
with self._wait_on_memory:
self._wait_on_memory.notify()
await self.queue.put(None)
self.queue.put_nowait(None)

async def _maybe_raise_exception(self) -> None:
if self._exception:
assert self._done
await self._communicate_task
await asyncio.gather(*self._tasks)
raise self._exception

async def flush(self) -> None:
"""
@@ -211,24 +219,34 @@ async def flush(self) -> None:
put
"""
if self._exception:
await self._communicate_task
await asyncio.gather(*self._futures)
raise self._exception
await self._maybe_raise_exception()

while self.shards:
await self._maybe_raise_exception()
await asyncio.sleep(0.05)
await asyncio.gather(*self._futures)
self._futures.clear()

await asyncio.gather(*self._tasks)
await self._maybe_raise_exception()

if self.total_size:
raise RuntimeError("Received additional input after flushing.")
self._done = True
await self._communicate_task

async def __aenter__(self) -> MultiComm:
return self

async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None:
await self.close()

async def close(self) -> None:
if not self._done:
await self.flush()
self._done = True
await self._communicate_task
await asyncio.gather(*self._futures)
await asyncio.gather(*self._tasks)
self.shards.clear()
self.sizes.clear()

@contextlib.contextmanager
def time(self, name: str) -> Iterator[None]:
1 change: 1 addition & 0 deletions distributed/shuffle/_multi_file.py
Original file line number Diff line number Diff line change
@@ -167,6 +167,7 @@ async def communicate(self) -> None:

while not self._done:
with self.time("idle"):
await self._maybe_raise_exception()
if not self.shards:
await asyncio.sleep(0.1)
continue
70 changes: 70 additions & 0 deletions distributed/shuffle/tests/test_multi_comm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import math
from collections import defaultdict

import pytest
@@ -74,3 +76,71 @@ async def send(address, shards):
await flush_task

assert [b"2" not in shard for shard in d["x"]]


def gen_bytes(percentage: float) -> bytes:
num_bytes = int(math.floor(percentage * MultiComm.memory_limit))
return b"0" * num_bytes


@pytest.mark.parametrize("explicit_flush", [True, False])
@gen_test()
async def test_concurrent_puts(explicit_flush):
d = defaultdict(list)

async def send(address, shards):
d[address].extend(shards)

frac = 0.1
nshards = 10
nputs = 20
payload = {x: [gen_bytes(frac)] for x in range(nshards)}
with concurrent.futures.ThreadPoolExecutor(
2, thread_name_prefix="test IOLoop"
) as tpe:
async with MultiComm(send=send, loop=IOLoop.current()) as mc:
loop = asyncio.get_running_loop()
futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]

await asyncio.gather(*futs)
if explicit_flush:
await mc.flush()

assert not mc.shards
assert not mc.sizes

assert not mc.shards
assert not mc.sizes
assert len(d) == 10
assert sum(map(len, d[0])) == len(gen_bytes(frac)) * nputs


@gen_test()
async def test_concurrent_puts_error():
d = defaultdict(list)

counter = 0

async def send(address, shards):
nonlocal counter
counter += 1
if counter == 5:
raise OSError("error during send")
d[address].extend(shards)

frac = 0.1
nshards = 10
nputs = 20
payload = {x: [gen_bytes(frac)] for x in range(nshards)}
with concurrent.futures.ThreadPoolExecutor(
2, thread_name_prefix="test IOLoop"
) as tpe:
async with MultiComm(send=send, loop=IOLoop.current()) as mc:
loop = asyncio.get_running_loop()
futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]

with pytest.raises(OSError, match="error during send"):
await asyncio.gather(*futs)

assert not mc.shards
assert not mc.sizes
2 changes: 1 addition & 1 deletion distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
@@ -120,7 +120,7 @@ async def test_bad_disk(c, s, a, b):
# clean_scheduler(s)


@pytest.mark.xfail
@pytest.mark.skip
@pytest.mark.slow
@gen_cluster(client=True)
async def test_crashed_worker(c, s, a, b):