Skip to content

Commit

Permalink
backcompat: vendor np.broadcast_shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 24, 2022
1 parent 8df0c2a commit 9f5e31b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
63 changes: 63 additions & 0 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,66 @@ def sliding_window_view(
"midpoint",
"nearest",
]


if Version(np.__version__) < Version("1.20"):

def _broadcast_shape(*args):
"""Returns the shape of the arrays that would result from broadcasting the
supplied arrays against each other.
"""
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])
# unfortunately, it cannot handle 32 or more arguments directly
for pos in range(32, len(args), 31):
# ironically, np.broadcast does not properly handle np.broadcast
# objects (it treats them as scalars)
# use broadcasting to avoid allocating the full array
b = np.broadcast_to(0, b.shape)
b = np.broadcast(b, *args[pos : (pos + 31)])
return b.shape

def broadcast_shapes(*args):
"""
Broadcast the input shapes into a single shape.
:ref:`Learn more about broadcasting here <basics.broadcasting>`.
.. versionadded:: 1.20.0
Parameters
----------
`*args` : tuples of ints, or ints
The shapes to be broadcast against each other.
Returns
-------
tuple
Broadcasted shape.
Raises
------
ValueError
If the shapes are not compatible and cannot be broadcast according
to NumPy's broadcasting rules.
See Also
--------
broadcast
broadcast_arrays
broadcast_to
Examples
--------
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
(3, 2)
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
(5, 6, 7)
"""
arrays = [np.empty(x, dtype=[]) for x in args]
return _broadcast_shape(*arrays)

else:
from numpy import broadcast_shapes # noqa
3 changes: 2 additions & 1 deletion xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

from .npcompat import broadcast_shapes
from .options import OPTIONS

try:
Expand Down Expand Up @@ -109,7 +110,7 @@ def _advanced_indexer_subspaces(key):
return (), ()

non_slices = [k for k in key if not isinstance(k, slice)]
ndim = len(np.broadcast_shapes(*[item.shape for item in non_slices]))
ndim = len(broadcast_shapes(*[item.shape for item in non_slices]))
mixed_positions = advanced_index_positions[0] + np.arange(ndim)
vindex_positions = np.arange(ndim)
return mixed_positions, vindex_positions
Expand Down

0 comments on commit 9f5e31b

Please sign in to comment.