From 6af547cdd9beac3b18420ccb204f801603e11519 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Fri, 22 Mar 2024 19:30:44 -0700 Subject: [PATCH] Handle .oindex and .vindex for the PandasMultiIndexingAdapter and PandasIndexingAdapter (#8869) --- xarray/core/indexing.py | 103 ++++++++++++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 14 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 82ee4ccb0e4..e26c50c8b90 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1680,11 +1680,65 @@ def _convert_scalar(self, item): # a NumPy array. return to_0d_array(item) - def _oindex_get(self, indexer: OuterIndexer): - return self.__getitem__(indexer) + def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]: + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key - def _vindex_get(self, indexer: VectorizedIndexer): - return self.__getitem__(indexer) + return key + + def _handle_result( + self, result: Any + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + if isinstance(result, pd.Index): + return type(self)(result, dtype=self.dtype) + else: + return self._convert_scalar(result) + + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.oindex[indexer] + + result = self.array[key] + + return self._handle_result(result) + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + key = self._prepare_key(indexer.tuple) + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + indexable = NumpyIndexingAdapter(np.asarray(self)) + return indexable.vindex[indexer] + + result = self.array[key] + + return self._handle_result(result) def __getitem__( self, indexer: ExplicitIndexer @@ -1695,22 +1749,15 @@ def __getitem__( | np.datetime64 | np.timedelta64 ): - key = indexer.tuple - if isinstance(key, tuple) and len(key) == 1: - # unpack key so it can index a pandas.Index object (pandas.Index - # objects don't like tuples) - (key,) = key + key = self._prepare_key(indexer.tuple) if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional indexable = NumpyIndexingAdapter(np.asarray(self)) - return apply_indexer(indexable, indexer) + return indexable[indexer] result = self.array[key] - if isinstance(result, pd.Index): - return type(self)(result, dtype=self.dtype) - else: - return self._convert_scalar(result) + return self._handle_result(result) def transpose(self, order) -> pd.Index: return self.array # self.array should be always one-dimensional @@ -1766,6 +1813,34 @@ def _convert_scalar(self, item): item = item[idx] return super()._convert_scalar(item) + def _oindex_get( + self, indexer: OuterIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._oindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + + def _vindex_get( + self, indexer: VectorizedIndexer + ) -> ( + PandasIndexingAdapter + | NumpyIndexingAdapter + | np.ndarray + | np.datetime64 + | np.timedelta64 + ): + result = super()._vindex_get(indexer) + if isinstance(result, type(self)): + result.level = self.level + return result + def __getitem__(self, indexer: ExplicitIndexer): result = super().__getitem__(indexer) if isinstance(result, type(self)):