diff --git a/flox/core.py b/flox/core.py
index eb6ed13b..a8d543fd 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -639,7 +639,9 @@ def rechunk_for_cohorts(
     return array.rechunk({axis: newchunks})
 
 
-def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) -> DaskArray:
+def rechunk_for_blockwise(
+    array: DaskArray, axis: T_Axis, labels: np.ndarray, *, force: bool = True
+) -> DaskArray:
     """
     Rechunks array so that group boundaries line up with
     chunk boundaries, allowing embarrassingly parallel group reductions.
@@ -672,11 +674,17 @@ def rechunk_for_blockwise(array: DaskArray, axis: T_Axis, labels: np.ndarray) ->
         return array
 
     Δn = abs(len(newchunks) - len(chunks))
-    if (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD) and (
-        abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
+    if force or (
+        (Δn / len(chunks) < BLOCKWISE_RECHUNK_NUM_CHUNKS_THRESHOLD)
+        and (
+            abs(max(newchunks) - max(chunks)) / max(chunks) < BLOCKWISE_RECHUNK_CHUNK_SIZE_THRESHOLD
+        )
     ):
+        logger.debug("Rechunking to enable blockwise.")
         # Less than 25% change in number of chunks, let's do it
         return array.rechunk({axis: newchunks})
+    else:
+        return array
 
 
 def reindex_(
@@ -2496,7 +2504,7 @@ def groupby_reduce(
     ):
         # Let's try rechunking for sorted 1D by.
         (single_axis,) = axis_
-        array = rechunk_for_blockwise(array, single_axis, by_)
+        array = rechunk_for_blockwise(array, single_axis, by_, force=False)
 
     if _is_first_last_reduction(func):
         if has_dask and nax != 1:
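
A rough usage sketch of the new force keyword (not part of the diff; the array shape, chunking, labels, and variable names below are made up for illustration, and it assumes dask is installed):

import dask.array as da
import numpy as np

from flox.core import rechunk_for_blockwise

# Ten elements whose chunking does not line up with the three sorted groups
# of sizes 4, 3, 3.
array = da.ones((10,), chunks=(3, 3, 4))
labels = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])

# force=True (the default) skips the size-change thresholds and rechunks
# whenever chunk boundaries do not already match group boundaries.
forced = rechunk_for_blockwise(array, 0, labels)

# force=False, as now passed inside groupby_reduce, only rechunks when the
# change stays within the BLOCKWISE_RECHUNK_* thresholds; otherwise the
# original array is returned unchanged.
maybe = rechunk_for_blockwise(array, 0, labels, force=False)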