Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up of unpack_remotedata() #7322

Merged
merged 3 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion distributed/tests/test_utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@

import pytest

from dask.optimization import SubgraphCallable

from distributed.core import ConnectionPool
from distributed.utils_comm import gather_from_workers, pack_data, retry, subs_multiple
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
pack_data,
retry,
subs_multiple,
unpack_remotedata,
)
from distributed.utils_test import BrokenComm, gen_cluster


Expand Down Expand Up @@ -129,3 +138,36 @@ async def f():

assert n_calls == 6
assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0]


def test_unpack_remotedata():
    """Check `unpack_remotedata` on scalars, wrapped keys and subgraph callables."""

    def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None:
        """Assert both sets contain only WrappedKeys that wrap the same keys.

        The sets are compared through the ``key`` attribute of their elements
        rather than by set equality, so two distinct WrappedKey instances that
        wrap the same key are treated as equivalent.
        """
        assert len(keys1) == len(keys2)
        # Type-check every element of both sets. Checking only the
        # intersection would skip any element that is not shared.
        assert all(isinstance(k, WrappedKey) for k in keys1 | keys2)
        assert sorted(k.key for k in keys1) == sorted(k.key for k in keys2)

    # Objects without wrapped keys come back unchanged with no keys found
    assert unpack_remotedata(1) == (1, set())
    assert unpack_remotedata(()) == ((), set())

    # A bare wrapped key is replaced by its key and reported as found
    res, keys = unpack_remotedata(WrappedKey("mykey"))
    assert res == "mykey"
    assert_eq(keys, {WrappedKey("mykey")})

    # Check unpack of SC that contains a wrapped key
    sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res[0] != sc  # Notice, the first item (the SC) has been changed
    assert res[1:] == ("arg1", "data")
    assert_eq(keys, {WrappedKey("data")})

    # Check unpack of SC when it takes a wrapped key as argument
    sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")])
    dsk = (sc, "arg1")
    res, keys = unpack_remotedata(dsk)
    assert res == (sc, "arg1")  # Notice, the first item (the SC) has NOT been changed
    assert_eq(keys, set())
116 changes: 65 additions & 51 deletions distributed/utils_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from functools import partial
from itertools import cycle
from typing import Any

from tlz import concat, drop, groupby, merge

Expand Down Expand Up @@ -162,90 +163,103 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True):
collection_types = (tuple, list, set, frozenset)


def _unpack_remotedata_inner(
    o: Any, byte_keys: bool, found_keys: set[WrappedKey]
) -> Any:
    """Inner implementation of `unpack_remotedata`.

    Walks ``o`` recursively and returns a copy in which every WrappedKey
    instance is replaced by its ``key``. Each WrappedKey encountered is added
    to ``found_keys``, which is mutated in place.

    Parameters
    ----------
    o
        Object to unpack. Tuples, named tuples, lists, sets, frozensets and
        dict values are traversed; any other object is returned unchanged.
    byte_keys
        If true, keys are passed through ``stringify`` before being returned.
    found_keys
        Accumulator that receives every WrappedKey found in ``o``.

    Returns
    -------
    The unpacked counterpart of ``o``.
    """
    typ = type(o)

    if typ is tuple:
        if not o:
            return o
        if type(o[0]) is SubgraphCallable:
            # Unpack futures within the arguments of the subgraph callable
            futures: set[WrappedKey] = set()
            args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:])
            found_keys.update(futures)

            # Unpack futures within the subgraph callable itself
            sc: SubgraphCallable = o[0]
            futures = set()
            dsk = {
                k: _unpack_remotedata_inner(v, byte_keys, futures)
                for k, v in sc.dsk.items()
            }
            future_keys: tuple = ()
            if futures:  # If there are no futures in the subgraph, `sc` is used as-is
                found_keys.update(futures)
                # The same tuple is appended to both the callable's inkeys and
                # the trailing call arguments, so the two always line up.
                future_keys = (
                    tuple(stringify(f.key) for f in futures)
                    if byte_keys
                    else tuple(f.key for f in futures)
                )
                inkeys = tuple(sc.inkeys) + future_keys
                sc = SubgraphCallable(dsk, sc.outkey, inkeys, sc.name)
            return (sc,) + args + future_keys
        else:
            return tuple(
                _unpack_remotedata_inner(item, byte_keys, found_keys) for item in o
            )
    elif is_namedtuple_instance(o):
        # Rebuild the named tuple from its unpacked fields
        return typ(
            *[_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o]
        )

    if typ in collection_types:
        if not o:
            return o
        outs = [_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o]
        return typ(outs)
    elif typ is dict:
        if o:
            # Only the values are unpacked; dict keys are kept as they are
            return {
                k: _unpack_remotedata_inner(v, byte_keys, found_keys)
                for k, v in o.items()
            }
        else:
            return o
    elif issubclass(typ, WrappedKey):  # TODO use type is Future
        k = o.key
        if byte_keys:
            k = stringify(k)
        found_keys.add(o)
        return k
    else:
        return o


def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]:
    """Replace WrappedKey objects in a collection by their keys

    Returns the unpacked collection together with the set of all WrappedKey
    objects that were found while unpacking.

    Examples
    --------
    >>> rd = WrappedKey('mykey')
    >>> unpack_remotedata(1)
    (1, set())
    >>> unpack_remotedata(())
    ((), set())
    >>> unpack_remotedata(rd)
    ('mykey', {WrappedKey('mykey')})
    >>> unpack_remotedata([1, rd])
    ([1, 'mykey'], {WrappedKey('mykey')})
    >>> unpack_remotedata({1: rd})
    ({1: 'mykey'}, {WrappedKey('mykey')})
    >>> unpack_remotedata({1: [rd]})
    ({1: ['mykey']}, {WrappedKey('mykey')})

    Use the ``byte_keys=True`` keyword to force string keys

    >>> rd = WrappedKey(('x', 1))
    >>> unpack_remotedata(rd, byte_keys=True)
    ("('x', 1)", {WrappedKey('('x', 1)')})
    """
    # The inner helper fills this set in place while it walks ``o``
    wrapped: set[Any] = set()
    unpacked = _unpack_remotedata_inner(o, byte_keys, wrapped)
    return unpacked, wrapped


def pack_data(o, d, key_types=object):
"""Merge known data into tuple or dict

Expand Down