Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to specify chunks in iris.util.broadcast_to_shape #5620

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ This document explains the changes made to Iris for this release
lazy data from file. This will also speed up coordinate comparison.
(:pull:`5610`)

#. `@bouweandela`_ added the option to specify the Dask chunks of the target
array in :func:`iris.util.broadcast_to_shape`. (:pull:`5620`)

🔥 Deprecations
===============
Expand Down
25 changes: 25 additions & 0 deletions lib/iris/tests/unit/util/test_broadcast_to_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,31 @@ def test_lazy_masked(self, mocked_compute):
for j in range(4):
self.assertMaskedArrayEqual(b[i, :, j, :].compute().T, m.compute())

@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
def test_lazy_chunks(self, mocked_compute):
# chunks can be specified along with the target shape and are only used
# along new dimensions or on dimensions that have size 1 in the source
# array.
m = da.ma.masked_array(
data=[[1, 2, 3, 4, 5]],
mask=[[0, 1, 0, 0, 0]],
).rechunk((1, 2))
b = broadcast_to_shape(
m,
dim_map=(1, 2),
shape=(3, 4, 5),
chunks=(
1, # used because target is new dim
2, # used because input size 1
3, # not used because broadcast does not rechunk
),
)
mocked_compute.assert_not_called()
for i in range(3):
for j in range(4):
self.assertMaskedArrayEqual(b[i, j, :].compute(), m[0].compute())
assert b.chunks == ((1, 1, 1), (2, 2), (2, 2, 1))

def test_masked_degenerate(self):
# masked arrays can have degenerate masks too
a = np.random.random([2, 3])
Expand Down
28 changes: 24 additions & 4 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import iris.exceptions


def broadcast_to_shape(array, shape, dim_map):
def broadcast_to_shape(array, shape, dim_map, chunks=None):
"""Broadcast an array to a given shape.

Each dimension of the array must correspond to a dimension in the
Expand All @@ -46,6 +46,14 @@ def broadcast_to_shape(array, shape, dim_map):
the index in *shape* which the dimension of *array* corresponds
to, so the first element of *dim_map* gives the index of *shape*
that corresponds to the first dimension of *array* etc.
chunks : :class:`tuple`, optional
HGWright marked this conversation as resolved.
Show resolved Hide resolved
If the source array is a :class:`dask.array.Array` and a value is
provided, then the result will use these chunks instead of the same
chunks as the source array. Setting chunks explicitly as part of
broadcast_to_shape is more efficient than rechunking afterwards. See
also :func:`dask.array.broadcast_to`. Note that the values provided
here will only be used along dimensions that are new on the result or
have size 1 on the source array.

Examples
--------
Expand All @@ -68,27 +76,39 @@ def broadcast_to_shape(array, shape, dim_map):
See more at :doc:`/userguide/real_and_lazy_data`.

"""
if isinstance(array, da.Array):
if chunks is not None:
chunks = list(chunks)
for src_idx, tgt_idx in enumerate(dim_map):
# Only use the specified chunks along new dimensions or on
# dimensions that have size 1 in the source array.
if array.shape[src_idx] != 1:
chunks[tgt_idx] = array.chunks[src_idx]
broadcast = functools.partial(da.broadcast_to, shape=shape, chunks=chunks)
else:
broadcast = functools.partial(np.broadcast_to, shape=shape)

n_orig_dims = len(array.shape)
n_new_dims = len(shape) - n_orig_dims
array = array.reshape(array.shape + (1,) * n_new_dims)

# Get dims in required order.
array = np.moveaxis(array, range(n_orig_dims), dim_map)
new_array = np.broadcast_to(array, shape)
new_array = broadcast(array)

if ma.isMA(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = ma.getmask(array)
if mask is ma.nomask:
new_mask = ma.nomask
else:
new_mask = np.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = ma.array(new_array, mask=new_mask)

elif is_lazy_masked_data(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = da.ma.getmaskarray(array)
new_mask = da.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = da.ma.masked_array(new_array, new_mask)

return new_array
Expand Down
Loading