diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index 404f6d485f9..239b2a0ac93 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -5,8 +5,17 @@ import pytest +from dask.optimization import SubgraphCallable + from distributed.core import ConnectionPool -from distributed.utils_comm import gather_from_workers, pack_data, retry, subs_multiple +from distributed.utils_comm import ( + WrappedKey, + gather_from_workers, + pack_data, + retry, + subs_multiple, + unpack_remotedata, +) from distributed.utils_test import BrokenComm, gen_cluster @@ -129,3 +138,36 @@ async def f(): assert n_calls == 6 assert sleep_calls == [0.0, 1.0, 3.0, 6.0, 6.0] + + +def test_unpack_remotedata(): + def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None: + if len(keys1) != len(keys2): + assert False + if not keys1: + assert True + if not all(isinstance(k, WrappedKey) for k in keys1 & keys2): + assert False + assert sorted([k.key for k in keys1]) == sorted([k.key for k in keys2]) + + assert unpack_remotedata(1) == (1, set()) + assert unpack_remotedata(()) == ((), set()) + + res, keys = unpack_remotedata(WrappedKey("mykey")) + assert res == "mykey" + assert_eq(keys, {WrappedKey("mykey")}) + + # Check unpack of SC that contains a wrapped key + sc = SubgraphCallable({"key": (WrappedKey("data"),)}, outkey="key", inkeys=["arg1"]) + dsk = (sc, "arg1") + res, keys = unpack_remotedata(dsk) + assert res[0] != sc # Notice, the first item (the SC) has been changed + assert res[1:] == ("arg1", "data") + assert_eq(keys, {WrappedKey("data")}) + + # Check unpack of SC when it takes a wrapped key as argument + sc = SubgraphCallable({"key": ("arg1",)}, outkey="key", inkeys=[WrappedKey("arg1")]) + dsk = (sc, "arg1") + res, keys = unpack_remotedata(dsk) + assert res == (sc, "arg1") # Notice, the first item (the SC) has NOT been changed + assert_eq(keys, set()) diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 5aca1a1a53d..d7ebd00bd29 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -6,6 +6,7 @@ from collections import defaultdict from functools import partial from itertools import cycle +from typing import Any from tlz import concat, drop, groupby, merge @@ -162,90 +163,103 @@ async def scatter_to_workers(nthreads, data, rpc=rpc, report=True): collection_types = (tuple, list, set, frozenset) -def unpack_remotedata(o, byte_keys=False, myset=None): - """Unpack WrappedKey objects from collection - - Returns original collection and set of all found WrappedKey objects - - Examples - -------- - >>> rd = WrappedKey('mykey') - >>> unpack_remotedata(1) - (1, set()) - >>> unpack_remotedata(()) - ((), set()) - >>> unpack_remotedata(rd) - ('mykey', {WrappedKey('mykey')}) - >>> unpack_remotedata([1, rd]) - ([1, 'mykey'], {WrappedKey('mykey')}) - >>> unpack_remotedata({1: rd}) - ({1: 'mykey'}, {WrappedKey('mykey')}) - >>> unpack_remotedata({1: [rd]}) - ({1: ['mykey']}, {WrappedKey('mykey')}) - - Use the ``byte_keys=True`` keyword to force string keys - - >>> rd = WrappedKey(('x', 1)) - >>> unpack_remotedata(rd, byte_keys=True) - ("('x', 1)", {WrappedKey('('x', 1)')}) - """ - if myset is None: - myset = set() - out = unpack_remotedata(o, byte_keys, myset) - return out, myset +def _unpack_remotedata_inner( + o: Any, byte_keys: bool, found_keys: set[WrappedKey] +) -> Any: + """Inner implementation of `unpack_remotedata` that adds found wrapped keys to `found_keys`""" typ = type(o) - if typ is tuple: if not o: return o if type(o[0]) is SubgraphCallable: - sc = o[0] + + # Unpack futures within the arguments of the subgraph callable + futures: set[WrappedKey] = set() + args = tuple(_unpack_remotedata_inner(i, byte_keys, futures) for i in o[1:]) + found_keys.update(futures) + + # Unpack futures within the subgraph callable itself + sc: SubgraphCallable = o[0] futures = set() dsk = { - k: unpack_remotedata(v, byte_keys, futures) for k, v in sc.dsk.items() + k: _unpack_remotedata_inner(v, byte_keys, futures) + for k, v in sc.dsk.items() } - args = tuple(unpack_remotedata(i, byte_keys, futures) for i in o[1:]) - if futures: - myset.update(futures) - futures = ( + future_keys: tuple = () + if futures: # If no futures is in the subgraph, we just use `sc` as-is + found_keys.update(futures) + future_keys = ( tuple(stringify(f.key) for f in futures) if byte_keys else tuple(f.key for f in futures) ) - inkeys = sc.inkeys + futures - return ( - (SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),) - + args - + futures - ) - else: - return o + inkeys = tuple(sc.inkeys) + future_keys + sc = SubgraphCallable(dsk, sc.outkey, inkeys, sc.name) + return (sc,) + args + future_keys else: - return tuple(unpack_remotedata(item, byte_keys, myset) for item in o) + return tuple( + _unpack_remotedata_inner(item, byte_keys, found_keys) for item in o + ) elif is_namedtuple_instance(o): - return typ(*[unpack_remotedata(item, byte_keys, myset) for item in o]) + return typ( + *[_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o] + ) if typ in collection_types: if not o: return o - outs = [unpack_remotedata(item, byte_keys, myset) for item in o] + outs = [_unpack_remotedata_inner(item, byte_keys, found_keys) for item in o] return typ(outs) elif typ is dict: if o: - return {k: unpack_remotedata(v, byte_keys, myset) for k, v in o.items()} + return { + k: _unpack_remotedata_inner(v, byte_keys, found_keys) + for k, v in o.items() + } else: return o elif issubclass(typ, WrappedKey): # TODO use type is Future k = o.key if byte_keys: k = stringify(k) - myset.add(o) + found_keys.add(o) return k else: return o +def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]: + """Unpack WrappedKey objects from collection + + Returns original collection and set of all found WrappedKey objects + + Examples + -------- + >>> rd = WrappedKey('mykey') + >>> unpack_remotedata(1) + (1, set()) + >>> unpack_remotedata(()) + ((), set()) + >>> unpack_remotedata(rd) + ('mykey', {WrappedKey('mykey')}) + >>> unpack_remotedata([1, rd]) + ([1, 'mykey'], {WrappedKey('mykey')}) + >>> unpack_remotedata({1: rd}) + ({1: 'mykey'}, {WrappedKey('mykey')}) + >>> unpack_remotedata({1: [rd]}) + ({1: ['mykey']}, {WrappedKey('mykey')}) + + Use the ``byte_keys=True`` keyword to force string keys + + >>> rd = WrappedKey(('x', 1)) + >>> unpack_remotedata(rd, byte_keys=True) + ("('x', 1)", {WrappedKey('('x', 1)')}) + """ + found_keys: set[Any] = set() + return _unpack_remotedata_inner(o, byte_keys, found_keys), found_keys + + def pack_data(o, d, key_types=object): """Merge known data into tuple or dict