An elegant way to guarantee single chunk along dim #2103

@crusaderky

Description

Algorithms wrapped by xarray.apply_ufunc(dask='parallelized'), and in general most algorithms that aren't embarrassingly parallel and don't have a sophisticated dask implementation that supports multiple chunks, cannot have multiple chunks on their core dimensions.

I have lost count of how many times I prefixed my invocations of apply_ufunc on a DataArray with the same blurb, over and over again:

if x.chunks:
    x = x.chunk({dim: x.shape[x.dims.index(dim)]})

The reason why it looks so awful is that DataArray.shape, DataArray.dims, Variable.shape and Variable.dims are positional.
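For illustration, the positional lookup can be reproduced with plain tuples standing in for DataArray.dims and DataArray.shape (the dim names and sizes here are made up):

```python
# Plain tuples standing in for DataArray.dims / DataArray.shape
dims = ("time", "lat", "lon")
shape = (365, 180, 360)

# Looking up a dimension's size requires going through its positional index
size = shape[dims.index("lat")]
print(size)  # 180
```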

I can see a few possible solutions to the problem:

Design 1

Change DataArray.chunk etc. to accept a special chunk size, e.g. -1, which means "whatever the size of that dim is". The above would become:

if x.chunks:
    x = x.chunk({dim: -1})

which is much more bearable.
One could argue that the implementation belongs in dask.array.rechunk; on the other hand, in dask it would feel silly, because already today you can do it very concisely:

x = x.rechunk({axis: x.shape[axis]})

I'm not overly fond of this solution as it would be rather obscure for anybody who isn't super familiar with the API documentation.
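As a sketch of Design 1's semantics (a hypothetical helper, not part of the xarray or dask API): a chunk size of -1 would simply be normalized to the full extent of that dim before rechunking:

```python
def normalize_chunks(chunks, dims, shape):
    """Replace a chunk size of -1 with the full extent of that dim.

    Hypothetical helper sketching Design 1; dims and shape mimic
    DataArray.dims and DataArray.shape.
    """
    sizes = dict(zip(dims, shape))
    return {dim: sizes[dim] if size == -1 else size
            for dim, size in chunks.items()}

print(normalize_chunks({"time": -1, "lat": 90}, ("time", "lat"), (365, 180)))
# {'time': 365, 'lat': 90}
```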

Design 2

Add properties to DataArray and Variable, ddims and dshape (happy to hear suggestions for better names), which would return dims and shape as an OrderedDict, just like Dataset.dims already does.

The above would become:

if x.chunks:
    x = x.chunk({dim: x.dshape[dim]})
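The proposed dshape can already be emulated by zipping the two positional tuples; a standalone sketch (the proposal would make it a property of DataArray and Variable, with ddims being the keys of the same mapping):

```python
from collections import OrderedDict

def dshape(dims, shape):
    """Sketch of the proposed property: dims and shape combined into a
    single name -> size mapping, analogous to Dataset.dims."""
    return OrderedDict(zip(dims, shape))

d = dshape(("time", "lat"), (365, 180))
print(d["lat"])  # 180
```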

Design 3

Change dask.array.rechunk to accept numpy.inf / math.inf as a chunk size. This makes sense, as the function already accepts chunk sizes that are larger than the shape; however, it's currently limited to int.
This is probably my personal favourite, and trivial to implement too.

The above would become:

if x.chunks:
    x = x.chunk({dim: np.inf})
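Design 3 amounts to clamping the requested chunk size, including inf, to the dim's extent, which rechunk already does for oversized ints. A sketch of the clamping rule (hypothetical helper, not dask API):

```python
import math

def clamp_chunk(size, dim_len):
    """Clamp a requested chunk size to the dim's extent.

    rechunk already accepts ints larger than the shape; Design 3 would
    extend the same rule to math.inf / numpy.inf.
    """
    if math.isinf(size):
        return dim_len
    return min(int(size), dim_len)

print(clamp_chunk(math.inf, 365))  # 365
print(clamp_chunk(1000, 365))      # 365
print(clamp_chunk(90, 365))        # 90
```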

Design 4

Introduce a convenience method on DataArray, Dataset, and Variable: ensure_single_chunk(*dims).
Below is a prototype:

import xarray


def ensure_single_chunk(a, *dims):
    """If a has a dask backend and two or more chunks along any of dims,
    rechunk it so that every listed dim becomes single-chunked.

    This is typically a prerequisite for computing any algorithm along a dim
    that is not embarrassingly parallel (short of sophisticated
    implementations such as those found in the dask module).

    :param a:
        any xarray object
    :param str dims:
        one or more dims of a to rechunk
    :returns:
        copy of a, where all listed dims are guaranteed to be on a single
        dask chunk. If a has a numpy backend, return a shallow copy of it.
    """
    if isinstance(a, xarray.Dataset):
        dims = set(dims)
        unknown_dims = dims - a.dims.keys()
        if unknown_dims:
            raise ValueError("dim(s) %s not found" % ",".join(unknown_dims))
        a = a.copy(deep=False)
        for k, v in a.variables.items():
            if v.chunks:
                a[k] = ensure_single_chunk(v, *(set(v.dims) & dims))
        return a

    if not isinstance(a, (xarray.DataArray, xarray.Variable)):
        raise TypeError('a must be a DataArray, Dataset, or Variable')

    if not a.chunks:
        # numpy backend
        return a.copy(deep=False)

    return a.chunk({
        dim: a.shape[a.dims.index(dim)]
        for dim in dims
    })
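In the Dataset branch above, each variable is rechunked only along the requested dims it actually has. The intersection step can be checked in isolation with plain tuples standing in for the variables of a hypothetical Dataset:

```python
requested = {"time", "lat"}

# dims tuples standing in for the variables of a hypothetical Dataset
variables = {
    "temperature": ("time", "lat", "lon"),
    "station_name": ("station",),
}

for name, v_dims in variables.items():
    # same expression as set(v.dims) & dims in the prototype
    print(name, sorted(set(v_dims) & requested))
# temperature ['lat', 'time']
# station_name []
```

A variable with none of the requested dims yields an empty set, so the recursive call rechunks nothing for it.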
