Skip to content

Commit

Permalink
Clean up of unpack_remotedata() (#7322)
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk authored Nov 24, 2022
1 parent dd55b03 commit cff33d5
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 52 deletions.
44 changes: 43 additions & 1 deletion distributed/tests/test_utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@

import pytest

from dask.optimization import SubgraphCallable

from distributed.core import ConnectionPool
from distributed.utils_comm import gather_from_workers, pack_data, retry, subs_multiple
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
pack_data,
retry,
subs_multiple,
unpack_remotedata,
)
from distributed.utils_test import BrokenComm, gen_cluster


Expand Down Expand Up @@ -129,3 +138,36 @@ async def f():

assert n_calls == 6
assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0]


def test_unpack_remotedata():
    """Exercise ``unpack_remotedata`` on plain values, wrapped keys and SCs."""

    def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
        """Assert both sets hold only ``WrappedKey`` objects with the same keys.

        The wrapped ``.key`` values are compared rather than the sets
        themselves, so two distinct wrapper objects around the same key are
        treated as matching.
        """
        assert len(keys1) == len(keys2)
        # Validate the union rather than the intersection: the intersection
        # only contains elements present in *both* sets, so it would skip
        # (possibly all of) the elements this check is meant to cover.
        assert all(isinstance(k, WrappedKey) for k in keys1 | keys2)
        assert sorted(k.key for k in keys1) == sorted(k.key for k in keys2)

    assert unpack_remotedata(1) == (1, set())
    assert unpack_remotedata(()) == ((), set())

    res, keys = unpack_remotedata(WrappedKey("mykey"))
    assert res == "mykey"
    assert_eq(keys, {WrappedKey("mykey")})

    # Check unpack of SC that contains a wrapped key
    sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res[0] != sc  # Notice, the first item (the SC) has been changed
    assert res[1:] == ("arg1", "data")
    assert_eq(keys, {WrappedKey("data")})

    # Check unpack of SC when it takes a wrapped key as argument
    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
    assert_eq(keys, set())
116 changes: 65 additions & 51 deletions distributed/utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from functools import partial
from itertools import cycle
from typing import Any

from tlz import concat, drop, groupby, merge

Expand Down Expand Up @@ -162,90 +163,103 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True):
collection_types = (tuple, list, set, frozenset)


def unpack_remotedata(o, byte_keys=False, myset=None):
"""Unpack WrappedKey objects from collection
Returns original collection and set of all found WrappedKey objects
Examples
--------
>>> rd = WrappedKey('mykey')
>>> unpack_remotedata(1)
(1, set())
>>> unpack_remotedata(())
((), set())
>>> unpack_remotedata(rd)
('mykey', {WrappedKey('mykey')})
>>> unpack_remotedata([1, rd])
([1, 'mykey'], {WrappedKey('mykey')})
>>> unpack_remotedata({1: rd})
({1: 'mykey'}, {WrappedKey('mykey')})
>>> unpack_remotedata({1: [rd]})
({1: ['mykey']}, {WrappedKey('mykey')})
Use the ``byte_keys=True`` keyword to force string keys
>>> rd = WrappedKey(('x', 1))
>>> unpack_remotedata(rd, byte_keys=True)
("('x', 1)", {WrappedKey('('x', 1)')})
"""
if myset is None:
myset = set()
out = unpack_remotedata(o, byte_keys, myset)
return out, myset
def _unpack_remotedata_inner(
o: Any, byte_keys: bool, found_keys: set[WrappedKey]
) -> Any:
"""Inner implementation of `unpack_remotedata` that adds found wrapped keys to `found_keys`"""

typ = type(o)

if typ is tuple:
if not o:
return o
if type(o[0]) is SubgraphCallable:
sc = o[0]

# Unpack futures within the arguments of the subgraph callable
futures: set[WrappedKey] = set()
args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:])
found_keys.update(futures)

# Unpack futures within the subgraph callable itself
sc: SubgraphCallable = o[0]
futures = set()
dsk = {
k: unpack_remotedata(v, byte_keys, futures) for k, v in sc.dsk.items()
k: _unpack_remotedata_inner(v, byte_keys, futures)
for k, v in sc.dsk.items()
}
args = tuple(unpack_remotedata(i, byte_keys, futures) for i in o[1:])
if futures:
myset.update(futures)
futures = (
future_keys: tuple = ()
if futures: # If no futures is in the subgraph, we just use `sc` as-is
found_keys.update(futures)
future_keys = (
tuple(stringify(f.key) for f in futures)
if byte_keys
else tuple(f.key for f in futures)
)
inkeys = sc.inkeys + futures
return (
(SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),)
+ args
+ futures
)
else:
return o
inkeys = tuple(sc.inkeys) + future_keys
sc = SubgraphCallable(dsk, sc.outkey, inkeys, sc.name)
return (sc,) + args + future_keys
else:
return tuple(unpack_remotedata(item, byte_keys, myset) for item in o)
return tuple(
_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o
)
elif is_namedtuple_instance(o):
return typ(*[unpack_remotedata(item, byte_keys, myset) for item in o])
return typ(
*[_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o]
)

if typ in collection_types:
if not o:
return o
outs = [unpack_remotedata(item, byte_keys, myset) for item in o]
outs = [_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o]
return typ(outs)
elif typ is dict:
if o:
return {k: unpack_remotedata(v, byte_keys, myset) for k, v in o.items()}
return {
k: _unpack_remotedata_inner(v, byte_keys, found_keys)
for k, v in o.items()
}
else:
return o
elif issubclass(typ, WrappedKey): # TODO use type is Future
k = o.key
if byte_keys:
k = stringify(k)
myset.add(o)
found_keys.add(o)
return k
else:
return o


def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]:
    """Unpack WrappedKey objects from collection

    Returns the unpacked collection together with the set of all WrappedKey
    objects that were found while unpacking it.

    Examples
    --------
    >>> rd = WrappedKey('mykey')
    >>> unpack_remotedata(1)
    (1, set())
    >>> unpack_remotedata(())
    ((), set())
    >>> unpack_remotedata(rd)
    ('mykey', {WrappedKey('mykey')})
    >>> unpack_remotedata([1, rd])
    ([1, 'mykey'], {WrappedKey('mykey')})
    >>> unpack_remotedata({1: rd})
    ({1: 'mykey'}, {WrappedKey('mykey')})
    >>> unpack_remotedata({1: [rd]})
    ({1: ['mykey']}, {WrappedKey('mykey')})

    Use the ``byte_keys=True`` keyword to force string keys

    >>> rd = WrappedKey(('x', 1))
    >>> unpack_remotedata(rd, byte_keys=True)
    ("('x', 1)", {WrappedKey('('x', 1)')})
    """
    # The inner helper fills this set in place while it walks ``o``.
    collected: set[Any] = set()
    unpacked = _unpack_remotedata_inner(o, byte_keys, collected)
    return unpacked, collected


def pack_data(o, d, key_types=object):
"""Merge known data into tuple or dict
Expand Down

0 comments on commit cff33d5

Please sign in to comment.