Skip to content

Commit

Permalink
Add xr.unify_chunks() top level method (#5445)
Browse files Browse the repository at this point in the history
Co-authored-by: crusaderky <crusaderky@gmail.com>
  • Loading branch information
malmans2 and crusaderky authored Jun 16, 2021
1 parent 4e61a26 commit fe87162
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 36 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Top-level functions
map_blocks
show_versions
set_options
unify_chunks

Dataset
=======
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ v0.18.3 (unreleased)

New Features
~~~~~~~~~~~~
- New top-level function :py:func:`unify_chunks`.
By `Mattia Almansi <https://github.com/malmans2>`_.
- Allow assigning values to a subset of a dataset using positional or label-based
indexing (:issue:`3015`, :pull:`5362`).
By `Matthias Göbel <https://github.com/matzegoebel>`_.
Expand Down
3 changes: 2 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -74,6 +74,7 @@
"save_mfdataset",
"set_options",
"show_versions",
"unify_chunks",
"where",
"zeros_like",
# Classes
Expand Down
64 changes: 64 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Functions for applying functions that act on arrays to xarray's labeled data.
"""
from __future__ import annotations

import functools
import itertools
import operator
Expand All @@ -19,6 +21,7 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -34,8 +37,11 @@

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
Expand Down Expand Up @@ -1721,3 +1727,61 @@ def _calc_idxminmax(
res.attrs = indx.attrs

return res


def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
"""
Given any number of Dataset and/or DataArray objects, returns
new objects with unified chunk size along all chunked dimensions.
Returns
-------
unified (DataArray or Dataset) – Tuple of objects with the same type as
*objects with consistent chunk sizes for all dask-array variables
See Also
--------
dask.array.core.unify_chunks
"""
from .dataarray import DataArray

# Convert all objects to datasets
datasets = [
obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
for obj in objects
]

# Get argumets to pass into dask.array.core.unify_chunks
unify_chunks_args = []
sizes: dict[Hashable, int] = {}
for ds in datasets:
for v in ds._variables.values():
if v.chunks is not None:
# Check that sizes match across different datasets
for dim, size in v.sizes.items():
try:
if sizes[dim] != size:
raise ValueError(
f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
)
except KeyError:
sizes[dim] = size
unify_chunks_args += [v._data, v._dims]

# No dask arrays: Return inputs
if not unify_chunks_args:
return objects

# Run dask.array.core.unify_chunks
from dask.array.core import unify_chunks

_, dask_data = unify_chunks(*unify_chunks_args)
dask_data_iter = iter(dask_data)
out = []
for obj, ds in zip(objects, datasets):
for k, v in ds._variables.items():
if v.chunks is not None:
ds._variables[k] = v.copy(data=next(dask_data_iter))
out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)

return tuple(out)
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from .arithmetic import DataArrayArithmetic
from .common import AbstractArray, DataWithCoords
from .computation import unify_chunks
from .coordinates import (
DataArrayCoordinates,
assert_coordinate_consistent,
Expand Down Expand Up @@ -3686,8 +3687,8 @@ def unify_chunks(self) -> "DataArray":
--------
dask.array.core.unify_chunks
"""
ds = self._to_temp_dataset().unify_chunks()
return self._from_temp_dataset(ds)

return unify_chunks(self)[0]

def map_blocks(
self,
Expand Down
33 changes: 2 additions & 31 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
from .arithmetic import DatasetArithmetic
from .common import DataWithCoords, _contains_datetime_like_objects
from .computation import unify_chunks
from .coordinates import (
DatasetCoordinates,
assert_coordinate_consistent,
Expand Down Expand Up @@ -6566,37 +6567,7 @@ def unify_chunks(self) -> "Dataset":
dask.array.core.unify_chunks
"""

try:
self.chunks
except ValueError: # "inconsistent chunks"
pass
else:
# No variables with dask backend, or all chunks are already aligned
return self.copy()

# import dask is placed after the quick exit test above to allow
# running this method if dask isn't installed and there are no chunks
import dask.array

ds = self.copy()

dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)}

dask_array_names = []
dask_unify_args = []
for name, variable in ds.variables.items():
if isinstance(variable.data, dask.array.Array):
dims_tuple = [dims_pos_map[dim] for dim in variable.dims]
dask_array_names.append(name)
dask_unify_args.append(variable.data)
dask_unify_args.append(dims_tuple)

_, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args)

for name, new_array in zip(dask_array_names, rechunked_arrays):
ds.variables[name]._data = new_array

return ds
return unify_chunks(self)[0]

def map_blocks(
self,
Expand Down
18 changes: 16 additions & 2 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,12 +1069,26 @@ def test_unify_chunks(map_ds):
with pytest.raises(ValueError, match=r"inconsistent chunks"):
ds_copy.chunks

expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)}
expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
with raise_if_dask_computes():
actual_chunks = ds_copy.unify_chunks().chunks
expected_chunks == actual_chunks
assert actual_chunks == expected_chunks
assert_identical(map_ds, ds_copy.unify_chunks())

out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy"))
assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
assert out_b.chunks == expected_chunks

# Test unordered dims
da = ds_copy["cxy"]
out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))

# Test mismatch
with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
xr.unify_chunks(da, da.isel(x=slice(2)))


@pytest.mark.parametrize("obj", [make_ds(), make_da()])
@pytest.mark.parametrize(
Expand Down

0 comments on commit fe87162

Please sign in to comment.