From 245697eac194da652081a2448ff7e64669eaad13 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 13 Dec 2024 17:01:36 -0700 Subject: [PATCH] Clear up broadcasting --- xarray/core/missing.py | 56 +++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 45a37353d57..0c077ac5be8 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -3,6 +3,7 @@ import datetime as dt import itertools import warnings +from collections import ChainMap from collections.abc import Callable, Generator, Hashable, Sequence from functools import partial from numbers import Number @@ -710,41 +711,42 @@ def interpolate_variable( func, kwargs = _get_interpolator_nd(method, **kwargs) in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True) - # broadcast out manually to minize confusing behaviour - broadcast_result_coords = broadcast_variables(*result_coords) - result_dims = broadcast_result_coords[0].dims # input coordinates along which we are interpolation are core dimensions # the corresponding output coordinates may or may not have the same name, # so `all_in_core_dims` is also `exclude_dims` all_in_core_dims = set(indexes_coords) + result_dims = OrderedSet(itertools.chain(*(_.dims for _ in result_coords))) + result_sizes = ChainMap(*(_.sizes for _ in result_coords)) + # any dimensions on the output that are present on the input, but are not being # interpolated along are broadcast or loop dimensions along which we automatically # vectorize. Consider the problem in # https://github.com/pydata/xarray/issues/6799#issuecomment-2474126217 # In the following, dimension names are listed out in []. # # da[time, q, lat, lon].interp(q=bar[lat,lon]). Here `lat`, `lon` - # are input dimensions, present on the output, along which we vectorize. + # are input dimensions, present on the output, but are not the coordinates + # we are explicitly interpolating. These are the dimensions along which we vectorize. # We track these as "result broadcast dimensions". # `q` is the only input core dimensions, and changes size (disappears) # so it is in exclude_dims. - result_broadcast_dims = set( - itertools.chain(dim for dim in result_dims if dim not in all_in_core_dims) - ) + vectorize_dims = (result_dims - all_in_core_dims) & set(var.dims) # remove any output broadcast dimensions from the list of core dimensions - output_core_dims = tuple(d for d in result_dims if d not in result_broadcast_dims) + output_core_dims = tuple(d for d in result_dims if d not in vectorize_dims) input_core_dims = ( # all coordinates on the input that we interpolate along [tuple(indexes_coords)] # the input coordinates are always 1D at the moment, so we just need to list out their names + [tuple(_.dims) for _ in in_coords] # The last set of inputs are the coordinates we are interpolating to. - # These have been broadcast already for ease. - + [output_core_dims] * len(result_coords) + + [ + tuple(d for d in coord.dims if d not in vectorize_dims) + for coord in result_coords + ] ) - output_sizes = {k: broadcast_result_coords[0].sizes[k] for k in output_core_dims} + output_sizes = {k: result_sizes[k] for k in output_core_dims} # scipy.interpolate.interp1d always forces to float. dtype = float if not issubclass(var.dtype.type, np.inexact) else var.dtype @@ -752,17 +754,25 @@ def interpolate_variable( _interpnd, var, *in_coords, - *broadcast_result_coords, + *result_coords, input_core_dims=input_core_dims, output_core_dims=[output_core_dims], exclude_dims=all_in_core_dims, dask="parallelized", - kwargs=dict(interp_func=func, interp_kwargs=kwargs), + kwargs=dict( + interp_func=func, + interp_kwargs=kwargs, + # we leave broadcasting up to dask if possible + # but we need broadcasted values in _interpnd, so propagate that + # context (dimension names), and broadcast there + # This would be unnecessary if we could tell apply_ufunc + # to insert size-1 broadcast dimensions + result_coord_core_dims=input_core_dims[-len(result_coords) :], + ), # TODO: deprecate and have the user rechunk themselves dask_gufunc_kwargs=dict(output_sizes=output_sizes, allow_rechunk=True), output_dtypes=[dtype], - # if there are any broadcast dims on the result, we must vectorize on them - vectorize=bool(result_broadcast_dims), + vectorize=bool(vectorize_dims), keep_attrs=True, ) return result @@ -787,7 +797,11 @@ def _interp1d( def _interpnd( - data: np.ndarray, *coords: np.ndarray, interp_func: InterpCallable, interp_kwargs + data: np.ndarray, + *coords: np.ndarray, + interp_func: InterpCallable, + interp_kwargs, + result_coord_core_dims, ) -> np.ndarray: """ Core nD array interpolation routine. @@ -801,10 +815,12 @@ def _interpnd( # Convert everything to Variables, since that makes applying # `_localize` and `_floatize_x` much easier x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] - new_x = [ - Variable([f"dim_{ndim + dim}" for dim in range(_x.ndim)], _x) - for _x in coords[n_x:] - ] + new_x = broadcast_variables( + *( + Variable(dims, _x) + for dims, _x in zip(result_coord_core_dims, coords[n_x:], strict=True) + ) + ) var = Variable([f"dim_{dim}" for dim in range(ndim)], data) if interp_kwargs.get("method") in ["linear", "nearest"]: