From 9f5e31b06e8480eea28e3c8406b56d0952014622 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 24 Jun 2022 15:38:07 -0600 Subject: [PATCH] backcompat: vendor np.broadcast_shapes --- xarray/core/npcompat.py | 63 +++++++++++++++++++++++++++++++++++++++++ xarray/core/nputils.py | 3 +- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index 85a8f88aba6..118b1c74fc1 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -245,3 +245,66 @@ def sliding_window_view( "midpoint", "nearest", ] + + +if Version(np.__version__) < Version("1.20"): + + def _broadcast_shape(*args): + """Returns the shape of the arrays that would result from broadcasting the + supplied arrays against each other. + """ + # use the old-iterator because np.nditer does not handle size 0 arrays + # consistently + b = np.broadcast(*args[:32]) + # unfortunately, it cannot handle 32 or more arguments directly + for pos in range(32, len(args), 31): + # ironically, np.broadcast does not properly handle np.broadcast + # objects (it treats them as scalars) + # use broadcasting to avoid allocating the full array + b = np.broadcast_to(0, b.shape) + b = np.broadcast(b, *args[pos : (pos + 31)]) + return b.shape + + def broadcast_shapes(*args): + """ + Broadcast the input shapes into a single shape. + + :ref:`Learn more about broadcasting here `. + + .. versionadded:: 1.20.0 + + Parameters + ---------- + `*args` : tuples of ints, or ints + The shapes to be broadcast against each other. + + Returns + ------- + tuple + Broadcasted shape. + + Raises + ------ + ValueError + If the shapes are not compatible and cannot be broadcast according + to NumPy's broadcasting rules. + + See Also + -------- + broadcast + broadcast_arrays + broadcast_to + + Examples + -------- + >>> np.broadcast_shapes((1, 2), (3, 1), (3, 2)) + (3, 2) + + >>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7)) + (5, 6, 7) + """ + arrays = [np.empty(x, dtype=[]) for x in args] + return _broadcast_shape(*arrays) + +else: + from numpy import broadcast_shapes # noqa diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index df79018f12a..5d219c579ff 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -6,6 +6,7 @@ import pandas as pd from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] +from .npcompat import broadcast_shapes from .options import OPTIONS try: @@ -109,7 +110,7 @@ def _advanced_indexer_subspaces(key): return (), () non_slices = [k for k in key if not isinstance(k, slice)] - ndim = len(np.broadcast_shapes(*[item.shape for item in non_slices])) + ndim = len(broadcast_shapes(*[item.shape for item in non_slices])) mixed_positions = advanced_index_positions[0] + np.arange(ndim) vindex_positions = np.arange(ndim) return mixed_positions, vindex_positions