Skip to content

Commit

Permalink
Ensure multicomm does not deadlock
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 28, 2022
1 parent 8e8556d commit 62c018b
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 14 deletions.
44 changes: 31 additions & 13 deletions distributed/shuffle/_multi_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import weakref
from collections import defaultdict
from collections.abc import Iterator
from typing import Awaitable, Callable, Sequence
from typing import Any, Awaitable, Callable, Sequence

from tornado.ioloop import IOLoop

Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self.total_size = 0
self.total_moved = 0
self._wait_on_memory = threading.Condition()
self._futures: set[asyncio.Task] = set()
self._tasks: set[asyncio.Task] = set()
self._done = False
self.diagnostics: dict[str, float] = defaultdict(float)
self._loop = loop
Expand Down Expand Up @@ -119,6 +119,8 @@ def put(self, data: dict[str, Sequence[bytes]]) -> None:
while MultiComm.total_size > MultiComm.memory_limit:
with self.time("waiting-on-memory"):
with self._wait_on_memory:
if self._exception:
raise self._exception
self._wait_on_memory.wait(1) # Block until memory calms down

async def communicate(self) -> None:
Expand All @@ -138,6 +140,7 @@ async def communicate(self) -> None:

while not self._done:
with self.time("idle"):
await self._maybe_raise_exception()
if not self.shards:
await asyncio.sleep(0.1)
continue
Expand Down Expand Up @@ -168,8 +171,8 @@ async def communicate(self) -> None:
assert shards
task = asyncio.create_task(self._process(address, shards, size))
del shards
self._futures.add(task)
task.add_done_callback(self._futures.discard)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)

async def _process(self, address: str, shards: list, size: int) -> None:
"""Send one message off to a neighboring worker"""
Expand All @@ -178,8 +181,6 @@ async def _process(self, address: str, shards: list, size: int) -> None:
# Consider boosting total_size a bit here to account for duplication

try:
# while (time.time() // 5 % 4) == 0:
# await asyncio.sleep(0.1)
start = time.time()
try:
with self.time("send"):
Expand All @@ -201,7 +202,14 @@ async def _process(self, address: str, shards: list, size: int) -> None:
MultiComm.total_size -= size
with self._wait_on_memory:
self._wait_on_memory.notify()
await self.queue.put(None)
self.queue.put_nowait(None)

async def _maybe_raise_exception(self) -> None:
if self._exception:
assert self._done
await self._communicate_task
await asyncio.gather(*self._tasks)
raise self._exception

async def flush(self) -> None:
"""
Expand All @@ -211,24 +219,34 @@ async def flush(self) -> None:
put
"""
if self._exception:
await self._communicate_task
await asyncio.gather(*self._futures)
raise self._exception
await self._maybe_raise_exception()

while self.shards:
await self._maybe_raise_exception()
await asyncio.sleep(0.05)
await asyncio.gather(*self._futures)
self._futures.clear()

await asyncio.gather(*self._tasks)
await self._maybe_raise_exception()

if self.total_size:
raise RuntimeError("Received additional input after flushing.")
self._done = True
await self._communicate_task

async def __aenter__(self) -> MultiComm:
return self

async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None:
await self.close()

async def close(self) -> None:
if not self._done:
await self.flush()
self._done = True
await self._communicate_task
await asyncio.gather(*self._futures)
await asyncio.gather(*self._tasks)
self.shards.clear()
self.sizes.clear()

@contextlib.contextmanager
def time(self, name: str) -> Iterator[None]:
Expand Down
1 change: 1 addition & 0 deletions distributed/shuffle/_multi_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ async def communicate(self) -> None:

while not self._done:
with self.time("idle"):
await self._maybe_raise_exception()
if not self.shards:
await asyncio.sleep(0.1)
continue
Expand Down
70 changes: 70 additions & 0 deletions distributed/shuffle/tests/test_multi_comm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import math
from collections import defaultdict

import pytest
Expand Down Expand Up @@ -74,3 +76,71 @@ async def send(address, shards):
await flush_task

assert [b"2" not in shard for shard in d["x"]]


def gen_bytes(percentage: float) -> bytes:
num_bytes = int(math.floor(percentage * MultiComm.memory_limit))
return b"0" * num_bytes


@pytest.mark.parametrize("explicit_flush", [True, False])
@gen_test()
async def test_concurrent_puts(explicit_flush):
d = defaultdict(list)

async def send(address, shards):
d[address].extend(shards)

frac = 0.1
nshards = 10
nputs = 20
payload = {x: [gen_bytes(frac)] for x in range(nshards)}
with concurrent.futures.ThreadPoolExecutor(
2, thread_name_prefix="test IOLoop"
) as tpe:
async with MultiComm(send=send, loop=IOLoop.current()) as mc:
loop = asyncio.get_running_loop()
futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]

await asyncio.gather(*futs)
if explicit_flush:
await mc.flush()

assert not mc.shards
assert not mc.sizes

assert not mc.shards
assert not mc.sizes
assert len(d) == 10
assert sum(map(len, d[0])) == len(gen_bytes(frac)) * nputs


@gen_test()
async def test_concurrent_puts_error():
d = defaultdict(list)

counter = 0

async def send(address, shards):
nonlocal counter
counter += 1
if counter == 5:
raise OSError("error during send")
d[address].extend(shards)

frac = 0.1
nshards = 10
nputs = 20
payload = {x: [gen_bytes(frac)] for x in range(nshards)}
with concurrent.futures.ThreadPoolExecutor(
2, thread_name_prefix="test IOLoop"
) as tpe:
async with MultiComm(send=send, loop=IOLoop.current()) as mc:
loop = asyncio.get_running_loop()
futs = [loop.run_in_executor(tpe, mc.put, payload) for _ in range(nputs)]

with pytest.raises(OSError, match="error during send"):
await asyncio.gather(*futs)

assert not mc.shards
assert not mc.sizes
2 changes: 1 addition & 1 deletion distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ async def test_bad_disk(c, s, a, b):
# clean_scheduler(s)


@pytest.mark.xfail
@pytest.mark.skip
@pytest.mark.slow
@gen_cluster(client=True)
async def test_crashed_worker(c, s, a, b):
Expand Down

0 comments on commit 62c018b

Please sign in to comment.