Skip to content

Commit

Permalink
Shuffle SRP
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Oct 26, 2022
1 parent 6ae57f6 commit d91556a
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 89 deletions.
23 changes: 14 additions & 9 deletions distributed/shuffle/_multi_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from collections.abc import Iterator
from typing import Awaitable, Callable, Sequence

from tornado.ioloop import IOLoop

from dask.utils import parse_bytes

from distributed.utils import log_errors
Expand Down Expand Up @@ -56,14 +58,15 @@ class MultiComm:
memory_limit = parse_bytes("100 MiB")
max_connections = 10
_queues: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, asyncio.Queue[None]
IOLoop, asyncio.Queue[None]
] = weakref.WeakKeyDictionary()
total_size = 0
lock = threading.Lock()

def __init__(
self,
send: Callable[[str, list[bytes]], Awaitable[None]],
loop: IOLoop,
):
self.send = send
self.shards: dict[str, list[bytes]] = defaultdict(list)
Expand All @@ -74,8 +77,7 @@ def __init__(
self._futures: set[asyncio.Task] = set()
self._done = False
self.diagnostics: dict[str, float] = defaultdict(float)
self._loop = asyncio.get_event_loop()

self._loop = loop
self._communicate_task = asyncio.create_task(self.communicate())
self._exception: Exception | None = None

Expand Down Expand Up @@ -211,20 +213,23 @@ async def flush(self) -> None:

while self.shards:
await asyncio.sleep(0.05)

# FIXME: This needs lock protection to guarantee that shards are indeed
# empty and that all associated futures are in self._futures. However,
# _process acquires the lock as well, so we will need either an RLock or
# a second lock.
await asyncio.gather(*self._futures)
self._futures.clear()

assert not self.total_size

self._done = True
await self._communicate_task

async def close(self) -> None:
try:
await self.flush()
except Exception:
pass
# TODO: Should we flush here?
# TODO: Should this raise an exception if there is one?
# TODO: Should finished tasks remove themselves from self._futures so that we only raise once? We do not raise multiple times, do we?
await self.flush()
self._done = True
await self._communicate_task

@contextlib.contextmanager
def time(self, name: str) -> Iterator[None]:
Expand Down
5 changes: 4 additions & 1 deletion distributed/shuffle/_multi_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING, Any, BinaryIO

from tornado.ioloop import IOLoop

from dask.sizeof import sizeof
from dask.utils import parse_bytes

Expand Down Expand Up @@ -69,6 +71,7 @@ class MultiFile:
def __init__(
self,
directory: str,
loop: IOLoop,
dump: Callable[[Any, BinaryIO], None] = pickle.dump,
load: Callable[[BinaryIO], Any] = pickle.load,
sizeof: Callable[[list[pa.Table]], int] = sizeof,
Expand All @@ -95,7 +98,7 @@ def __init__(
self.diagnostics = defaultdict(float)

self._communicate_future = asyncio.create_task(self.communicate())
self._loop = asyncio.get_event_loop()
self._loop = loop
self._exception = None

@property
Expand Down
Loading

0 comments on commit d91556a

Please sign in to comment.