From 19c8bf91e02a594f2cdaf7fb1fa9741569400355 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 12 Jun 2023 13:55:56 +0200 Subject: [PATCH] Refactor P2P rechunk validation (#7890) --- distributed/shuffle/_rechunk.py | 31 +----------------------- distributed/shuffle/_worker_extension.py | 8 ------ 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index e58d0608f99..738c8034f79 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -1,8 +1,7 @@ from __future__ import annotations -import math from collections import defaultdict -from itertools import compress, product +from itertools import product from typing import TYPE_CHECKING, NamedTuple import dask @@ -73,34 +72,6 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: # Special case for empty array, as the algorithm below does not behave correctly return da.empty(x.shape, chunks=chunks, dtype=x.dtype) - old_chunks = x.chunks - new_chunks = chunks - - def is_unknown(dim: ChunkedAxis) -> bool: - return any(math.isnan(chunk) for chunk in dim) - - old_is_unknown = [is_unknown(dim) for dim in old_chunks] - new_is_unknown = [is_unknown(dim) for dim in new_chunks] - - if old_is_unknown != new_is_unknown or any( - new != old for new, old in compress(zip(old_chunks, new_chunks), old_is_unknown) - ): - raise ValueError( - "Chunks must be unchanging along dimensions with missing values.\n\n" - "A possible solution:\n x.compute_chunk_sizes()" - ) - - old_known = [dim for dim, unknown in zip(old_chunks, old_is_unknown) if not unknown] - new_known = [dim for dim, unknown in zip(new_chunks, new_is_unknown) if not unknown] - - old_sizes = [sum(o) for o in old_known] - new_sizes = [sum(n) for n in new_known] - - if old_sizes != new_sizes: - raise ValueError( - f"Cannot change dimensions from {old_sizes!r} to {new_sizes!r}" - ) - dsk: dict = {} token = tokenize(x, chunks) _barrier_key = barrier_key(ShuffleId(token)) diff --git a/distributed/shuffle/_worker_extension.py b/distributed/shuffle/_worker_extension.py index 28470647d22..5e40762475f 100644 --- a/distributed/shuffle/_worker_extension.py +++ b/distributed/shuffle/_worker_extension.py @@ -314,14 +314,6 @@ def __init__( memory_limiter_comms=memory_limiter_comms, memory_limiter_disk=memory_limiter_disk, ) - from dask.array.core import normalize_chunks - - # We rely on a canonical `np.nan` in `dask.array.rechunk.old_to_new` - # that passes an implicit identity check when testing for list equality. - # This does not work with (de)serialization, so we have to normalize the chunks - # here again to canonicalize `nan`s. - old = normalize_chunks(old) - new = normalize_chunks(new) self.old = old self.new = new partitions_of = defaultdict(list)