Separation of concerns between Shuffle and WorkerShuffle #7195

Closed
wants to merge 13 commits
3 changes: 0 additions & 3 deletions distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -1126,9 +1126,6 @@ async def test_shuffling(c, s, a, b):
         ss.update()
         await asyncio.sleep(0.1)
         assert time() < start + 5
-    # FIXME: If this is still running while the test is running, this raises
-    # awkward CancelledErrors
-    await df2
 
 
 @gen_cluster(client=True, scheduler_kwargs={"dashboard": True}, timeout=60)
78 changes: 54 additions & 24 deletions distributed/shuffle/_multi_comm.py
@@ -7,7 +7,9 @@
 import weakref
 from collections import defaultdict
 from collections.abc import Iterator
-from typing import Awaitable, Callable, Sequence
+from typing import Any, Awaitable, Callable, Sequence
+
+from tornado.ioloop import IOLoop
 
 from dask.utils import parse_bytes

@@ -56,27 +58,26 @@ class MultiComm:
     memory_limit = parse_bytes("100 MiB")
     max_connections = 10
     _queues: weakref.WeakKeyDictionary[
-        asyncio.AbstractEventLoop, asyncio.Queue[None]
+        IOLoop, asyncio.Queue[None]
     ] = weakref.WeakKeyDictionary()
     total_size = 0
     lock = threading.Lock()
 
     def __init__(
         self,
         send: Callable[[str, list[bytes]], Awaitable[None]],
-        loop: asyncio.AbstractEventLoop | None = None,
+        loop: IOLoop,
     ):
         self.send = send
         self.shards: dict[str, list[bytes]] = defaultdict(list)
         self.sizes: dict[str, int] = defaultdict(int)
         self.total_size = 0
         self.total_moved = 0
-        self.thread_condition = threading.Condition()
-        self._futures: set[asyncio.Task] = set()
+        self._wait_on_memory = threading.Condition()
+        self._tasks: set[asyncio.Task] = set()
         self._done = False
         self.diagnostics: dict[str, float] = defaultdict(float)
-        self._loop = loop or asyncio.get_event_loop()
-
+        self._loop = loop
         self._communicate_task = asyncio.create_task(self.communicate())
         self._exception: Exception | None = None

@@ -102,6 +103,8 @@ def put(self, data: dict[str, Sequence[bytes]]) -> None:
"""
if self._exception:
raise self._exception
if self._done:
raise RuntimeError("Putting data on already closed comm.")
with self.lock:
for address, shards in data.items():
size = sum(map(len, shards))
@@ -115,8 +118,10 @@ def put(self, data: dict[str, Sequence[bytes]]) -> None:

         while MultiComm.total_size > MultiComm.memory_limit:
             with self.time("waiting-on-memory"):
-                with self.thread_condition:
-                    self.thread_condition.wait(1)  # Block until memory calms down
+                with self._wait_on_memory:
+                    if self._exception:
+                        raise self._exception
+                    self._wait_on_memory.wait(1)  # Block until memory calms down
 
     async def communicate(self) -> None:
         """
@@ -135,6 +140,7 @@ async def communicate(self) -> None:

         while not self._done:
             with self.time("idle"):
+                await self._maybe_raise_exception()
                 if not self.shards:
                     await asyncio.sleep(0.1)
                     continue
@@ -163,26 +169,26 @@

                 assert set(self.sizes) == set(self.shards)
                 assert shards
-                task = asyncio.create_task(self.process(address, shards, size))
+                task = asyncio.create_task(self._process(address, shards, size))
                 del shards
-                self._futures.add(task)
+                self._tasks.add(task)
+                task.add_done_callback(self._tasks.discard)
 
-    async def process(self, address: str, shards: list, size: int) -> None:
+    async def _process(self, address: str, shards: list, size: int) -> None:
         """Send one message off to a neighboring worker"""
         with log_errors():
 
             # Consider boosting total_size a bit here to account for duplication
 
-            # while (time.time() // 5 % 4) == 0:
-            #     await asyncio.sleep(0.1)
-            start = time.time()
-            try:
+            try:
+                start = time.time()
                 with self.time("send"):
                     await self.send(address, [b"".join(shards)])
             except Exception as e:
                 self._exception = e
+                self._done = True
                 raise
             stop = time.time()
             self.diagnostics["avg_size"] = (
                 0.95 * self.diagnostics["avg_size"] + 0.05 * size
@@ -194,29 +200,53 @@ async def process(self, address: str, shards: list, size: int) -> None:
             with self.lock:
                 self.total_size -= size
                 MultiComm.total_size -= size
-            with self.thread_condition:
-                self.thread_condition.notify()
-            await self.queue.put(None)
+            with self._wait_on_memory:
+                self._wait_on_memory.notify()
+            self.queue.put_nowait(None)
 
+    async def _maybe_raise_exception(self) -> None:
+        if self._exception:
+            assert self._done
+            await self._communicate_task
+            await asyncio.gather(*self._tasks)
+            raise self._exception
+
     async def flush(self) -> None:
         """
-        We don't expect any more data, wait until everything is flushed through
+        We don't expect any more data, wait until everything is flushed through.
+
+        Not thread safe. Caller must ensure this is not called concurrently with
+        put
         """
-        if self._exception:
-            await self._communicate_task
-            await asyncio.gather(*self._futures)
-            raise self._exception
+        await self._maybe_raise_exception()
 
         while self.shards:
+            await self._maybe_raise_exception()
             await asyncio.sleep(0.05)
 
-        await asyncio.gather(*self._futures)
-        self._futures.clear()
+        await asyncio.gather(*self._tasks)
+        await self._maybe_raise_exception()
 
-        assert not self.total_size
+        if self.total_size:
+            raise RuntimeError("Received additional input after flushing.")
         self._done = True
         await self._communicate_task
 
+    async def __aenter__(self) -> MultiComm:
+        return self
+
+    async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None:
+        await self.close()
+
+    async def close(self) -> None:
+        if not self._done:
+            await self.flush()
+        self._done = True
+        await self._communicate_task
+        await asyncio.gather(*self._tasks)
+        self.shards.clear()
+        self.sizes.clear()
 
     @contextlib.contextmanager
     def time(self, name: str) -> Iterator[None]:
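Note (illustrative, not part of this diff): after this change MultiComm no longer discovers an event loop on its own — the loop argument is a required tornado IOLoop — and cleanup is explicit through the new close()/async-context-manager protocol, with send failures stored and re-raised on the next put() or flush(). Below is a minimal sketch of the resulting lifecycle, assuming the usual background communicate behavior; the in-memory send callback and the main() driver are invented for the example.

import asyncio

from tornado.ioloop import IOLoop

from distributed.shuffle._multi_comm import MultiComm


async def main() -> None:
    sent: list[bytes] = []

    async def send(address: str, shards: list[bytes]) -> None:
        # Stand-in for the real worker comm; if this raised, MultiComm would
        # store the error and re-raise it from the next put() or flush().
        sent.extend(shards)

    # __aexit__ -> close() flushes pending shards, awaits the background
    # communicate task, and gathers any in-flight send tasks.
    async with MultiComm(send, loop=IOLoop.current()) as mc:
        mc.put({"worker-1": [b"shard-a", b"shard-b"]})
        await mc.flush()

    # _process() concatenates all shards headed to one address into a single
    # message before sending.
    assert sent == [b"shard-ashard-b"]


if __name__ == "__main__":
    asyncio.run(main())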
91 changes: 56 additions & 35 deletions distributed/shuffle/_multi_file.py
@@ -12,6 +12,8 @@
 from collections.abc import Callable, Iterator
 from typing import TYPE_CHECKING, Any, BinaryIO
 
+from tornado.ioloop import IOLoop
+
 from dask.sizeof import sizeof
 from dask.utils import parse_bytes

@@ -62,17 +64,17 @@ class MultiFile:

     shards: defaultdict[str, list]
     sizes: defaultdict[str, int]
-    _futures: set[asyncio.Future]
+    _tasks: set[asyncio.Future]
     diagnostics: defaultdict[str, float]
     _exception: Exception | None
 
     def __init__(
         self,
         directory: str,
+        loop: IOLoop,
         dump: Callable[[Any, BinaryIO], None] = pickle.dump,
         load: Callable[[BinaryIO], Any] = pickle.load,
         sizeof: Callable[[list[pa.Table]], int] = sizeof,
-        loop: object = None,
     ):
         self.directory = pathlib.Path(directory)
         if not os.path.exists(self.directory):
@@ -86,21 +88,21 @@ def __init__(
         self.total_size = 0
         self.total_received = 0
 
-        self.condition = asyncio.Condition()
+        self._wait_on_memory = asyncio.Condition()
 
         self.bytes_written = 0
         self.bytes_read = 0
 
         self._done = False
-        self._futures = set()
+        self._tasks = set()
         self.diagnostics = defaultdict(float)
 
         self._communicate_future = asyncio.create_task(self.communicate())
-        self._loop = loop or asyncio.get_event_loop()
+        self._loop = loop
         self._exception = None
 
     @property
-    def queue(self) -> asyncio.Queue[None]:
+    def _queue(self) -> asyncio.Queue[None]:
         try:
             return MultiFile._queues[self._loop]
         except KeyError:
@@ -121,7 +123,7 @@ async def put(self, data: dict[str, list[pa.Table]]) -> None:
         be written to that destination
         """
-        if self._exception:
-            raise self._exception
+        await self._maybe_raise_exception()
 
         this_size = 0
         for id, shards in data.items():
@@ -137,11 +139,12 @@ async def put(self, data: dict[str, list[pa.Table]]) -> None:

         while MultiFile.total_size > MultiFile.memory_limit:
             with self.time("waiting-on-memory"):
-                async with self.condition:
+                await self._maybe_raise_exception()
+                async with self._wait_on_memory:
 
                     try:
                         await asyncio.wait_for(
-                            self.condition.wait(), 1
+                            self._wait_on_memory.wait(), 1
                         )  # Block until memory calms down
                     except asyncio.TimeoutError:
                         continue
@@ -164,21 +167,28 @@ async def communicate(self) -> None:

         while not self._done:
             with self.time("idle"):
+                await self._maybe_raise_exception()
                 if not self.shards:
                     await asyncio.sleep(0.1)
                     continue
 
-                await self.queue.get()
-
+                await self._queue.get()
+                # Shards must only be mutated below
+                assert self.shards, "MultiFile.shards was mutated unexpectedly"
                 id = max(self.sizes, key=self.sizes.__getitem__)
                 shards = self.shards.pop(id)
                 size = self.sizes.pop(id)
 
-                future = asyncio.create_task(self.process(id, shards, size))
+                task = asyncio.create_task(self.process(id, shards, size))
                 del shards
-                self._futures.add(future)
-                async with self.condition:
-                    self.condition.notify()
+                self._tasks.add(task)
+
+                def _reset_count(task: asyncio.Task) -> None:
+                    self._tasks.discard(task)
+                    self._queue.put_nowait(None)
+
+                task.add_done_callback(_reset_count)
+                async with self._wait_on_memory:
+                    self._wait_on_memory.notify()
 
     async def process(self, id: str, shards: list[pa.Table], size: int) -> None:
         """Write one buffer to file
@@ -221,14 +231,15 @@ async def process(self, id: str, shards: list[pa.Table], size: int) -> None:
             self.bytes_written += size
             self.total_size -= size
             MultiFile.total_size -= size
-            async with self.condition:
-                self.condition.notify()
-            await self.queue.put(None)
+            async with self._wait_on_memory:
+                self._wait_on_memory.notify()
 
-    def read(self, id: int) -> pa.Table:
+    def read(self, id: int | str) -> pa.Table:
         """Read a complete file back into memory"""
         if self._exception:
             raise self._exception
+        if not self._done:
+            raise RuntimeError("Tried to read from file before done.")
         parts = []
 
         try:
@@ -253,35 +264,45 @@ def read(self, id: int) -> pa.Table:
             else:
                 raise KeyError(id)
 
-    async def flush(self) -> None:
-        """Wait until all writes are finished"""
+    async def _maybe_raise_exception(self) -> None:
         if self._exception:
+            assert self._done
             await self._communicate_future
-            await asyncio.gather(*self._futures)
+            await asyncio.gather(*self._tasks)
             raise self._exception
 
+    async def flush(self) -> None:
+        """Wait until all writes are finished"""
+
         while self.shards:
+            await self._maybe_raise_exception()
+            # If an exception arises while we're sleeping here we deadlock
             await asyncio.sleep(0.05)
 
-        await asyncio.gather(*self._futures)
-        if all(future.done() for future in self._futures):
-            self._futures.clear()
+        await asyncio.gather(*self._tasks)
+        await self._maybe_raise_exception()
 
         assert not self.total_size
 
         self._done = True
 
         await self._communicate_future
 
-    def close(self) -> None:
-        self._done = True
-        with contextlib.suppress(FileNotFoundError):
-            shutil.rmtree(self.directory)
-
-    def __enter__(self) -> MultiFile:
-        return self
-
-    def __exit__(self, exc: Any, typ: Any, traceback: Any) -> None:
-        self.close()
+    async def close(self) -> None:
+        try:
+            # XXX If there is an exception this will raise again during
+            # teardown. I don't think this is what we want to. Likely raising
+            # the exception on flushing is not ideal
+            if not self._done:
+                await self.flush()
+        finally:
+            with contextlib.suppress(FileNotFoundError):
+                shutil.rmtree(self.directory)
 
+    async def __aenter__(self) -> MultiFile:
+        return self
+
+    async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None:
+        await self.close()
 
     @contextlib.contextmanager
     def time(self, name: str) -> Iterator[None]:
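Note (illustrative, not part of this diff): MultiFile now funnels background-writer failures through _maybe_raise_exception(), which put(), flush(), and communicate() all call, so an error raised inside a process() task surfaces to the caller instead of deadlocking the while-self.shards loop. Below is a minimal sketch of that error path, assuming process() records its failure in _exception/_done as the visible hunks rely on; failing_dump and the spill directory are invented for the example.

import asyncio
import tempfile
from typing import Any, BinaryIO

from tornado.ioloop import IOLoop

from distributed.shuffle._multi_file import MultiFile


async def main() -> None:
    def failing_dump(obj: Any, file: BinaryIO) -> None:
        # Simulate a disk failure inside the background process() task.
        raise OSError("disk full")

    with tempfile.TemporaryDirectory() as tmpdir:
        mf = MultiFile(
            directory=f"{tmpdir}/spill", loop=IOLoop.current(), dump=failing_dump
        )
        await mf.put({"x": [b"some-bytes"]})
        try:
            await mf.flush()  # re-raises the OSError stored by the writer task
        except OSError:
            pass
        try:
            await mf.close()  # removes the spill directory in its finally block
        except OSError:
            # The XXX comment in close() notes the stored error can re-raise
            # again during teardown.
            pass


if __name__ == "__main__":
    asyncio.run(main())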