Skip to content

Commit

Permalink
refactor indexing.py: introduce .oindex for Explicitly Indexed Ar…
Browse files Browse the repository at this point in the history
…rays (pydata#8750)

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 23, 2024
1 parent 01f7b4f commit d9760f3
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 28 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ v2024.03.0 (unreleased)
New Features
~~~~~~~~~~~~

- Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.


Breaking changes
~~~~~~~~~~~~~~~~
Expand All @@ -44,6 +47,7 @@ Internal Changes
~~~~~~~~~~~~~~~~



.. _whats-new.2024.02.0:

v2024.02.0 (Feb 19, 2024)
Expand Down
83 changes: 62 additions & 21 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,21 @@ def as_integer_slice(value):
return slice(start, stop, step)


class IndexCallable:
    """Adapter that exposes a one-argument callable through subscript syntax.

    ``IndexCallable(func)[key]`` evaluates ``func(key)``. Item assignment is
    not supported and raises ``NotImplementedError``.
    """

    __slots__ = ("func",)

    def __init__(self, func):
        # Stored as-is; it is invoked with the subscript key on every lookup.
        self.func = func

    def __getitem__(self, key):
        """Return whatever the wrapped callable produces for ``key``."""
        getter = self.func
        return getter(key)

    def __setitem__(self, key, value):
        """Reject item assignment; only lookup is provided."""
        raise NotImplementedError()


class BasicIndexer(ExplicitIndexer):
"""Tuple for basic indexing.
Expand Down Expand Up @@ -470,6 +485,13 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
# Note this is the base class for all lazy indexing classes
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, key):
raise NotImplementedError("This method should be overridden")

@property
def oindex(self):
    """Subscriptable accessor: ``obj.oindex[key]`` calls ``obj._oindex_get(key)``."""
    getter = self._oindex_get
    return IndexCallable(getter)


class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
"""Wrap an array, converting tuples into the indicated explicit indexer."""
Expand Down Expand Up @@ -560,6 +582,9 @@ def get_duck_array(self):
def transpose(self, order):
    """Transpose by way of a ``LazilyVectorizedIndexedArray``.

    A ``LazilyVectorizedIndexedArray`` is built from this object's ``array``
    and ``key``, and its ``transpose(order)`` result is returned.
    """
    vectorized = LazilyVectorizedIndexedArray(self.array, self.key)
    return vectorized.transpose(order)

def _oindex_get(self, indexer):
return type(self)(self.array, self._updated_key(indexer))

def __getitem__(self, indexer):
if isinstance(indexer, VectorizedIndexer):
array = LazilyVectorizedIndexedArray(self.array, self.key)
Expand Down Expand Up @@ -663,6 +688,9 @@ def _ensure_copied(self):
def get_duck_array(self):
    """Return the result of ``get_duck_array()`` on the wrapped ``array``."""
    wrapped = self.array
    return wrapped.get_duck_array()

def _oindex_get(self, key):
    """Index the wrapped array with ``key`` and re-wrap the selection.

    The selection is passed through ``_wrap_numpy_scalars`` and then wrapped
    in a new instance of this object's own type.
    """
    cls = type(self)
    selection = self.array[key]
    return cls(_wrap_numpy_scalars(selection))

def __getitem__(self, key):
    """Return ``self.array[key]`` re-wrapped in this object's own type.

    The selection goes through ``_wrap_numpy_scalars`` before being wrapped.
    """
    cls = type(self)
    picked = self.array[key]
    return cls(_wrap_numpy_scalars(picked))

Expand Down Expand Up @@ -696,6 +724,9 @@ def get_duck_array(self):
self._ensure_cached()
return self.array.get_duck_array()

def _oindex_get(self, key):
    """Select ``key`` from the wrapped array and return it in a fresh wrapper.

    The selected value is processed by ``_wrap_numpy_scalars`` and handed to
    ``type(self)`` to build the returned object.
    """
    wrapper_type = type(self)
    subset = self.array[key]
    return wrapper_type(_wrap_numpy_scalars(subset))

def __getitem__(self, key):
    """Index the wrapped array and wrap the outcome in ``type(self)``.

    ``_wrap_numpy_scalars`` is applied to the selection first.
    """
    wrapper_type = type(self)
    item = self.array[key]
    return wrapper_type(_wrap_numpy_scalars(item))

Expand Down Expand Up @@ -1332,6 +1363,10 @@ def _indexing_array_and_key(self, key):
def transpose(self, order):
    """Return the wrapped array transposed according to ``order``."""
    wrapped = self.array
    return wrapped.transpose(order)

def _oindex_get(self, key):
array, key = self._indexing_array_and_key(key)
return array[key]

def __getitem__(self, key):
array, key = self._indexing_array_and_key(key)
return array[key]
Expand Down Expand Up @@ -1376,16 +1411,19 @@ def __init__(self, array):
)
self.array = array

def _oindex_get(self, key):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
return self.oindex[key]
else:
if isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
Expand All @@ -1395,11 +1433,10 @@ def __getitem__(self, key):
def __setitem__(self, key, value):
    """Assign ``value`` into the wrapped array for basic and outer indexers.

    ``BasicIndexer`` and ``OuterIndexer`` keys are applied through
    ``key.tuple``.

    Raises
    ------
    TypeError
        If ``key`` is a ``VectorizedIndexer`` (not supported), or is not one
        of the recognized indexer types.
    """
    if isinstance(key, (BasicIndexer, OuterIndexer)):
        self.array[key.tuple] = value
    elif isinstance(key, VectorizedIndexer):
        raise TypeError("Vectorized indexing is not supported")
    else:
        raise TypeError(f"Unrecognized indexer: {key}")

def transpose(self, order):
xp = self.array.__array_namespace__()
Expand All @@ -1417,24 +1454,25 @@ def __init__(self, array):
"""
self.array = array

def __getitem__(self, key):
def _oindex_get(self, key):
key = key.tuple
try:
return self.array[key]
except NotImplementedError:
# manual orthogonal indexing
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey,)]
return value

def __getitem__(self, key):
    """Index the wrapped array according to the kind of explicit indexer.

    ``BasicIndexer`` keys are applied directly via ``key.tuple``,
    ``VectorizedIndexer`` keys go through the array's ``vindex`` accessor,
    and any other key is asserted to be an ``OuterIndexer`` and delegated to
    ``self.oindex``.
    """
    if isinstance(key, BasicIndexer):
        return self.array[key.tuple]
    elif isinstance(key, VectorizedIndexer):
        return self.array.vindex[key.tuple]
    else:
        assert isinstance(key, OuterIndexer)
        # Orthogonal indexing lives in one place only: the ``oindex``
        # accessor, rather than a second inline copy here.
        return self.oindex[key]

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
Expand Down Expand Up @@ -1510,6 +1548,9 @@ def _convert_scalar(self, item):
# a NumPy array.
return to_0d_array(item)

def _oindex_get(self, key):
return self.__getitem__(key)

def __getitem__(
self, indexer
) -> (
Expand Down
22 changes: 15 additions & 7 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@
maybe_coerce_to_str,
)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import (
integer_types,
is_0d_dask_array,
to_duck_array,
)
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array

NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
Expand Down Expand Up @@ -761,7 +757,14 @@ def __getitem__(self, key) -> Self:
array `x.values` directly.
"""
dims, indexer, new_order = self._broadcast_indexes(key)
data = as_indexable(self._data)[indexer]
indexable = as_indexable(self._data)

if isinstance(indexer, BasicIndexer):
data = indexable[indexer]
elif isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]
else:
data = indexable[indexer]
if new_order:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)
Expand Down Expand Up @@ -794,7 +797,12 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
else:
actual_indexer = indexer

data = as_indexable(self._data)[actual_indexer]
indexable = as_indexable(self._data)

if isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]
else:
data = indexable[actual_indexer]
mask = indexing.create_mask(indexer, self.shape, data)
# we need to invert the mask in order to pass data first. This helps
# pint to choose the correct unit
Expand Down

0 comments on commit d9760f3

Please sign in to comment.