From d9760f30662b219182cdc8dedc04dfbe7771942b Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Fri, 23 Feb 2024 15:34:42 -0800 Subject: [PATCH] refactor `indexing.py`: introduce `.oindex` for Explicitly Indexed Arrays (#8750) Co-authored-by: Deepak Cherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 4 ++ xarray/core/indexing.py | 83 ++++++++++++++++++++++++++++++----------- xarray/core/variable.py | 22 +++++++---- 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ece209e09ae..80e53a5ee22 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v2024.03.0 (unreleased) New Features ~~~~~~~~~~~~ +- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) + By `Anderson Banihirwe <https://github.com/andersy005>`_. + Breaking changes ~~~~~~~~~~~~~~~~ @@ -44,6 +47,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ + .. _whats-new.2024.02.0: v2024.02.0 (Feb 19, 2024) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7331ab1a056..43867bc552b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -325,6 +325,21 @@ def as_integer_slice(value): return slice(start, stop, step) +class IndexCallable: + """Provide getitem syntax for a callable object.""" + + __slots__ = ("func",) + + def __init__(self, func): + self.func = func + + def __getitem__(self, key): + return self.func(key) + + def __setitem__(self, key, value): + raise NotImplementedError + + class BasicIndexer(ExplicitIndexer): """Tuple for basic indexing. 
@@ -470,6 +485,13 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: # Note this is the base class for all lazy indexing classes return np.asarray(self.get_duck_array(), dtype=dtype) + def _oindex_get(self, key): + raise NotImplementedError("This method should be overridden") + + @property + def oindex(self): + return IndexCallable(self._oindex_get) + class ImplicitToExplicitIndexingAdapter(NDArrayMixin): """Wrap an array, converting tuples into the indicated explicit indexer.""" @@ -560,6 +582,9 @@ def get_duck_array(self): def transpose(self, order): return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order) + def _oindex_get(self, indexer): + return type(self)(self.array, self._updated_key(indexer)) + def __getitem__(self, indexer): if isinstance(indexer, VectorizedIndexer): array = LazilyVectorizedIndexedArray(self.array, self.key) @@ -663,6 +688,9 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() + def _oindex_get(self, key): + return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) @@ -696,6 +724,9 @@ def get_duck_array(self): self._ensure_cached() return self.array.get_duck_array() + def _oindex_get(self, key): + return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) @@ -1332,6 +1363,10 @@ def _indexing_array_and_key(self, key): def transpose(self, order): return self.array.transpose(order) + def _oindex_get(self, key): + array, key = self._indexing_array_and_key(key) + return array[key] + def __getitem__(self, key): array, key = self._indexing_array_and_key(key) return array[key] @@ -1376,16 +1411,19 @@ def __init__(self, array): ) self.array = array + def _oindex_get(self, key): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = key.tuple + value = self.array + for axis, subkey in 
reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value + def __getitem__(self, key): if isinstance(key, BasicIndexer): return self.array[key.tuple] elif isinstance(key, OuterIndexer): - # manual orthogonal indexing (implemented like DaskIndexingAdapter) - key = key.tuple - value = self.array - for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey, Ellipsis)] - return value + return self.oindex[key] else: if isinstance(key, VectorizedIndexer): raise TypeError("Vectorized indexing is not supported") @@ -1395,11 +1433,10 @@ def __getitem__(self, key): def __setitem__(self, key, value): if isinstance(key, (BasicIndexer, OuterIndexer)): self.array[key.tuple] = value + elif isinstance(key, VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") else: - if isinstance(key, VectorizedIndexer): - raise TypeError("Vectorized indexing is not supported") - else: - raise TypeError(f"Unrecognized indexer: {key}") + raise TypeError(f"Unrecognized indexer: {key}") def transpose(self, order): xp = self.array.__array_namespace__() @@ -1417,24 +1454,25 @@ def __init__(self, array): """ self.array = array - def __getitem__(self, key): + def _oindex_get(self, key): + key = key.tuple + try: + return self.array[key] + except NotImplementedError: + # manual orthogonal indexing + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey,)] + return value + def __getitem__(self, key): if isinstance(key, BasicIndexer): return self.array[key.tuple] elif isinstance(key, VectorizedIndexer): return self.array.vindex[key.tuple] else: assert isinstance(key, OuterIndexer) - key = key.tuple - try: - return self.array[key] - except NotImplementedError: - # manual orthogonal indexing. - # TODO: port this upstream into dask in a saner way. 
- value = self.array - for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey,)] - return value + return self.oindex[key] def __setitem__(self, key, value): if isinstance(key, BasicIndexer): @@ -1510,6 +1548,9 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) + def _oindex_get(self, key): + return self.__getitem__(key) + def __getitem__( self, indexer ) -> ( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8d76cfbe004..6834931fa11 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -41,11 +41,7 @@ maybe_coerce_to_str, ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions -from xarray.namedarray.pycompat import ( - integer_types, - is_0d_dask_array, - to_duck_array, -) +from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, @@ -761,7 +757,14 @@ def __getitem__(self, key) -> Self: array `x.values` directly. """ dims, indexer, new_order = self._broadcast_indexes(key) - data = as_indexable(self._data)[indexer] + indexable = as_indexable(self._data) + + if isinstance(indexer, BasicIndexer): + data = indexable[indexer] + elif isinstance(indexer, OuterIndexer): + data = indexable.oindex[indexer] + else: + data = indexable[indexer] if new_order: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) @@ -794,7 +797,12 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): else: actual_indexer = indexer - data = as_indexable(self._data)[actual_indexer] + indexable = as_indexable(self._data) + + if isinstance(indexer, OuterIndexer): + data = indexable.oindex[indexer] + else: + data = indexable[actual_indexer] mask = indexing.create_mask(indexer, self.shape, data) # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit