Skip to content

Commit

Permalink
Fix for more broadcast dims
Browse files: browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 16, 2024
1 parent 4aa6bb9 commit 633b361
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ def _vindex_like(
else:
da = da.chunk("auto")

like = like_da._variable._data
array = da._variable._data

import dask.array
Expand All @@ -201,30 +200,37 @@ def _vindex_like(
from xarray.core.dask_array_compat import reshape_blockwise

array = clone(array) # FIXME: add to dask

assert array.ndim == 1
dims = indexer.dims
axes = tuple(like_da.get_axis_num(dim) for dim in dims)
to_shape = tuple(size for ax, size in enumerate(like.shape) if ax in axes)
to_chunks = tuple(
chunksize for ax, chunksize in enumerate(like.chunks) if ax in axes
)
idxr = indexer._variable._data
# array = clone(array) # FIXME: add to dask

# dimensions for indexed result
out_dims = tuple(
itertools.chain(*(indexer.dims if this == dim else (this,) for this in da.dims))
)
out_chunks = tuple(
da.chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
out_shape = tuple(da.sizes.get(dim, like_da.sizes[dim]) for dim in out_dims)
idxr = indexer._variable._data

# shuffle indices that can be reshaped blockwise to desired shape
core_dim_chunks = tuple(
chunks
for dim, chunks in zip(out_dims, out_chunks, strict=True)
if dim in indexer.dims
)
flat_indices = [
idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(to_chunks)
idxr[slicer].ravel().tolist() for slicer in slices_from_chunks(core_dim_chunks)
]
shuffled = dask.array.shuffle(
array, flat_indices, axis=da.get_axis_num(dim), chunks="auto"
)
if shuffled.shape != to_shape:
shuffled = reshape_blockwise(shuffled, shape=to_shape, chunks=to_chunks)
# shuffle with `chunks="auto"` could change chunks, so we recalculate out_chunks
new_chunksizes = dict(zip(da.dims, shuffled.chunks, strict=True))
out_chunks = tuple(
new_chunksizes.get(dim, like_da.chunksizes[dim]) for dim in out_dims
)
if shuffled.shape != out_shape:
shuffled = reshape_blockwise(shuffled, shape=out_shape, chunks=out_chunks)
return Variable(dims=out_dims, data=shuffled, attrs=da.attrs)


Expand Down

0 comments on commit 633b361

Please sign in to comment.