From 4a595dff198edfc6163fd9bb6d2b3c095320ac2b Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 16:22:12 +0200 Subject: [PATCH 1/8] Add coordinate transform classes from prototype --- xarray/core/coordinate_transform.py | 74 ++++++++++++++ xarray/core/indexes.py | 111 +++++++++++++++++++++ xarray/core/indexing.py | 145 ++++++++++++++++++++++++++++ 3 files changed, 330 insertions(+) create mode 100644 xarray/core/coordinate_transform.py diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py new file mode 100644 index 00000000000..1d4db3e9b7e --- /dev/null +++ b/xarray/core/coordinate_transform.py @@ -0,0 +1,74 @@ +from typing import Any, Iterable, Hashable, Mapping + +import numpy as np + + +class CoordinateTransform: + """Abstract coordinate transform with dimension & coordinate names.""" + + coord_names: tuple[Hashable, ...] + dims: tuple[str, ...] + dim_size: dict[str, int] + dtype: Any + + def __init__( + self, + coord_names: Iterable[Hashable], + dim_size: Mapping[str, int], + dtype: Any = np.dtype(np.float64), + ): + self.coord_names = tuple(coord_names) + self.dims = tuple(dim_size) + self.dim_size = dict(dim_size) + self.dtype = dtype + + def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: + """Perform grid -> world coordinate transformation. + + Parameters + ---------- + dim_positions : dict + Grid location(s) along each dimension (axis). + + Returns + ------- + coord_labels : dict + World coordinate labels. + + """ + # TODO: cache the results in order to avoid re-computing + # all labels when accessing the values of each coordinate one at a time + raise NotImplementedError + + def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]: + """Perform world -> grid coordinate reverse transformation. + + Parameters + ---------- + labels : dict + World coordinate labels. + + Returns + ------- + dim_positions : dict + Grid relative location(s) along each dimension (axis). + + """ + raise NotImplementedError + + def equals(self, other: "CoordinateTransform") -> bool: + """Check equality with another CoordinateTransform of the same kind.""" + raise NotImplementedError + + def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: + """Returns all "world" coordinate labels.""" + if dims is None: + dims = self.dims + + positions = np.meshgrid( + *[np.arange(self.dim_size[d]) for d in dims], + indexing="ij", + ) + dim_positions = {dim: positions[i] for i, dim in enumerate(dims)} + + return self.forward(dim_positions) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 5abc2129e3e..8d90c955bfe 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -10,6 +10,7 @@ import pandas as pd from xarray.core import formatting, nputils, utils +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( IndexSelResult, PandasIndexingAdapter, @@ -1372,6 +1373,116 @@ def rename(self, name_dict, dims_dict): ) +class CoordinateTransformIndex(Index): + """Xarray index abstract class for transformation between "pixel" + and "world" coordinates. + + """ + + transform: CoordinateTransform + + def __init__( + self, + transform: CoordinateTransform, + ): + self.transform = transform + + def create_variables( + self, variables: Mapping[Any, Variable] | None = None + ) -> IndexVars: + new_variables = {} + + for name in self.transform.coord_names: + # copy attributes, if any + attrs: Mapping[Hashable, Any] | None + + if variables is not None and name in variables: + var = variables[name] + attrs = var.attrs + else: + attrs = None + + data = CoordinateTransformIndexingAdapter(self.transform, name) + new_variables[name] = Variable(self.transform.dims, data, attrs=attrs) + + return new_variables + + def create_coordinates(self) -> Coordinates: + # TODO: move this in xarray.Index base class? + variables = self.create_variables() + indexes = {name: self for name in variables} + return xr.Coordinates(coords=variables, indexes=indexes) + + def isel( + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: + # TODO: support returning a new index (e.g., possible to re-calculate the + # the transform or calculate another transform on a reduced dimension space) + return None + + def sel( + self, labels: dict[Any, Any], method=None, tolerance=None + ) -> IndexSelResult: + if method != "nearest": + raise ValueError( + "CoordinateTransformIndex only supports selection with method='nearest'" + ) + + labels_set = set(labels) + coord_names_set = set(self.transform.coord_names) + + missing_labels = coord_names_set - labels_set + if missing_labels: + raise ValueError( + f"missing labels for coordinate(s): {','.join(missing_labels)}." + ) + + label0_obj = next(iter(labels.values())) + dim_size0 = getattr(label0_obj, "sizes", None) + + is_xr_obj = [ + isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + ] + if not all(is_xr_obj): + raise TypeError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with either xarray.DataArray or xarray.Variable objects." + ) + dim_size = [getattr(label, "sizes", None) for label in labels.values()] + if any([ds != dim_size0 for ds in dim_size]): + raise ValueError( + "CoordinateTransformIndex only supports advanced (point-wise) indexing " + "with xarray.DataArray or xarray.Variable objects of macthing dimensions." + ) + + coord_labels = { + name: labels[name].values for name in self.transform.coord_names + } + dim_positions = self.transform.reverse(coord_labels) + + results = {} + for dim, pos in dim_positions.items(): + if isinstance(label0_obj, Variable): + xr_pos = Variable(label.dims, idx) + else: + # dataarray + xr_pos = DataArray(idx, dims=label.dims) + results[dim] = idx + + return IndexSelResult(results) + + def equals(self, other: Self) -> bool: + return self.transform.equals(other.transform) + + def rename( + self, + name_dict: Mapping[Any, Hashable], + dims_dict: Mapping[Any, Hashable], + ) -> Self: + # TODO: maybe update self.transform coord_names, dim_size and dims attributes + return self + + def create_default_index_implicit( dim_variable: Variable, all_variables: Mapping | Iterable[Hashable] | None = None, diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 67912908a2b..35fd2597b85 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -15,6 +15,7 @@ import pandas as pd from xarray.core import duck_array_ops +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS from xarray.core.types import T_Xarray @@ -1303,6 +1304,42 @@ def _decompose_outer_indexer( return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer))) +def _posify_indices(indices: np.typing.ArrayLike, size: int) -> np.ndarray: + """Convert negative indices by their equivalent positive indices. + + Note: the resulting indices may still be out of bounds (< 0 or >= size). + + """ + return np.where(indices < 0, size + indices, indices) + + +def _check_bounds(indices, size): + """Check if the given indices are all within the array boundaries.""" + if np.any((indices < 0) | (indices >= size)): + raise IndexError("out of bounds index") + + +def _arrayize_outer_indexer(indexer: OuterIndexer, shape) -> OuterIndexer: + """Return a similar oindex with after replacing slices by arrays and + negative indices by their corresponding positive indices. + + Also check if array indices are within bounds. + + """ + new_key = [] + + for axis, value in enumerate(indexer.tuple): + size = shape[axis] + if isinstance(value, slice): + value = _expand_slice(value, size) + else: + value = _posify_indices(value, size) + _check_bounds(value, size) + new_key.append(value) + + return OuterIndexer(tuple(new_key)) + + def _arrayize_vectorized_indexer( indexer: VectorizedIndexer, shape: _Shape ) -> VectorizedIndexer: @@ -1921,3 +1958,111 @@ def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) + + +class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a CoordinateTransform to support explicit indexing and + lazy coordinate labels. + + """ + + _transform: CoordinateTransform + _coord_name: Hashable + _dims: tuple[str, ...] + + def __init__( + self, + transform: CoordinateTransform, + coord_name: Hashable, + dims: tuple[str] | None = None, + ): + self._transform = transform + self._coord_name = coord_name + self._dims = dims or transform.dims + + @property + def dtype(self) -> np.dtype: + return self._transform.dtype + + @property + def shape(self): + return tuple(self._transform.dim_size.values()) + + def get_duck_array(self) -> np.ndarray: + all_coords = self._transform.generate_coords(dims=self._dims) + return np.asarray(all_coords[self._coord_name]) + + def _oindex_get(self, indexer: OuterIndexer): + expanded_indexer_ = OuterIndexer(expanded_indexer(indexer.tuple, self.ndim)) + array_indexer = _arrayize_outer_indexer(expanded_indexer_, self.shape) + + positions = np.meshgrid(*array_indexer.tuple, indexing="ij") + dim_positions = { + dim: pos for dim, pos in zip(self._dims, positions, strict=False) + } + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]).squeeze() + + def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def _vindex_get(self, indexer: VectorizedIndexer): + expanded_indexer_ = VectorizedIndexer( + expanded_indexer(indexer.tuple, self.ndim) + ) + array_indexer = _arrayize_vectorized_indexer(expanded_indexer_, self.shape) + + dim_positions = {} + for i, (dim, pos) in enumerate( + zip(self._dims, array_indexer.tuple, strict=False) + ): + pos = _posify_indices(pos, self.shape[i]) + _check_bounds(pos, self.shape[i]) + dim_positions[dim] = pos + + result = self._transform.forward(dim_positions) + return np.asarray(result[self._coord_name]) + + def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def __getitem__(self, indexer: ExplicitIndexer): + # TODO: make it lazy (i.e., re-calculate and re-wrap the transform) when possible? + self._check_and_raise_if_non_basic_indexer(indexer) + + # also works with basic indexing + return self._oindex_get(indexer) + + def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: + raise TypeError( + "setting values is not supported on coordinate transform arrays." + ) + + def transpose(self, order): + new_dims = tuple([self._dims[i] for i in order]) + return type(self)(self._transform, self._coord_name, new_dims) + + def __repr__(self: Any) -> str: + return f"{type(self).__name__}(transform={self._transform!r})" + + def _get_array_subset(self) -> np.ndarray: + threshold = max(100, OPTIONS["display_values_threshold"] + 2) + if self.size > threshold: + pos = threshold // 2 + indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) + subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + else: + subset = self + + return np.asarray(subset) + + def _repr_inline_(self, max_width: int) -> str: + """Good to see some labels even for a lazy coordinate.""" + from xarray.core.formatting import format_array_flat + + return format_array_flat(self._get_array_subset(), max_width) From 0b545cf61cf192cd1037e2e1d312921a8ab5843c Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:26:15 +0200 Subject: [PATCH 2/8] lint, public API and docstrings --- xarray/__init__.py | 2 ++ xarray/core/coordinate_transform.py | 10 ++++++--- xarray/core/indexes.py | 32 ++++++++++++++++++++--------- xarray/core/indexing.py | 7 ++++--- xarray/indexes/__init__.py | 9 ++++++-- 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..b49ab1848b7 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -30,6 +30,7 @@ where, ) from xarray.core.concat import concat +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset @@ -109,6 +110,7 @@ "CFTimeIndex", "Context", "Coordinates", + "CoordinateTransform", "DataArray", "Dataset", "DataTree", diff --git a/xarray/core/coordinate_transform.py b/xarray/core/coordinate_transform.py index 1d4db3e9b7e..40043da46bc 100644 --- a/xarray/core/coordinate_transform.py +++ b/xarray/core/coordinate_transform.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Hashable, Mapping +from collections.abc import Hashable, Iterable, Mapping +from typing import Any import numpy as np @@ -15,11 +16,14 @@ def __init__( self, coord_names: Iterable[Hashable], dim_size: Mapping[str, int], - dtype: Any = np.dtype(np.float64), + dtype: Any = None, ): self.coord_names = tuple(coord_names) self.dims = tuple(dim_size) self.dim_size = dict(dim_size) + + if dtype is None: + dtype = np.dtype(np.float64) self.dtype = dtype def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]: @@ -61,7 +65,7 @@ def equals(self, other: "CoordinateTransform") -> bool: raise NotImplementedError def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]: - """Returns all "world" coordinate labels.""" + """Compute all coordinate labels at once.""" if dims is None: dims = self.dims diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8d90c955bfe..e154b727fc5 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -12,6 +12,7 @@ from xarray.core import formatting, nputils, utils from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexing import ( + CoordinateTransformIndexingAdapter, IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter, @@ -25,6 +26,7 @@ ) if TYPE_CHECKING: + from xarray.core.coordinate import Coordinates from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -1374,8 +1376,13 @@ def rename(self, name_dict, dims_dict): class CoordinateTransformIndex(Index): - """Xarray index abstract class for transformation between "pixel" - and "world" coordinates. + """Helper class for creating Xarray indexes based on coordinate transforms. + + - wraps a :py:class:`CoordinateTransform` instance + - takes care of creating the index (lazy) coordinates + - supports point-wise label-based selection + - supports exact alignment only, by comparing indexes based on their transform + (not on their explicit coordinate labels) """ @@ -1409,9 +1416,11 @@ def create_variables( def create_coordinates(self) -> Coordinates: # TODO: move this in xarray.Index base class? + from xarray.core.coordinates import Coordinates + variables = self.create_variables() indexes = {name: self for name in variables} - return xr.Coordinates(coords=variables, indexes=indexes) + return Coordinates(coords=variables, indexes=indexes) def isel( self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] @@ -1423,6 +1432,9 @@ def isel( def sel( self, labels: dict[Any, Any], method=None, tolerance=None ) -> IndexSelResult: + from xarray.core.dataarray import DataArray + from xarray.core.variable import Variable + if method != "nearest": raise ValueError( "CoordinateTransformIndex only supports selection with method='nearest'" @@ -1433,15 +1445,14 @@ def sel( missing_labels = coord_names_set - labels_set if missing_labels: - raise ValueError( - f"missing labels for coordinate(s): {','.join(missing_labels)}." - ) + missing_labels_str = ",".join([f"{name}" for name in missing_labels]) + raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.") label0_obj = next(iter(labels.values())) dim_size0 = getattr(label0_obj, "sizes", None) is_xr_obj = [ - isinstance(label, (xr.DataArray, xr.Variable)) for label in labels.values() + isinstance(label, DataArray | Variable) for label in labels.values() ] if not all(is_xr_obj): raise TypeError( @@ -1461,13 +1472,14 @@ def sel( dim_positions = self.transform.reverse(coord_labels) results = {} + dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): if isinstance(label0_obj, Variable): - xr_pos = Variable(label.dims, idx) + xr_pos = Variable(dims0, pos) else: # dataarray - xr_pos = DataArray(idx, dims=label.dims) - results[dim] = idx + xr_pos = DataArray(pos, dims=dims0) + results[dim] = xr_pos return IndexSelResult(results) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 35fd2597b85..047e16c240d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1961,8 +1961,9 @@ def copy(self, deep: bool = True) -> Self: class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): - """Wrap a CoordinateTransform to support explicit indexing and - lazy coordinate labels. + """Wrap a CoordinateTransform as a lazy coordinate array. + + Supports explicit indexing (both outer and vectorized). """ @@ -2036,7 +2037,7 @@ def __getitem__(self, indexer: ExplicitIndexer): self._check_and_raise_if_non_basic_indexer(indexer) # also works with basic indexing - return self._oindex_get(indexer) + return self._oindex_get(OuterIndexer(indexer.tuple)) def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None: raise TypeError( diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py index b1bf7a1af11..e2857b8602b 100644 --- a/xarray/indexes/__init__.py +++ b/xarray/indexes/__init__.py @@ -3,6 +3,11 @@ """ -from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex +from xarray.core.indexes import ( + CoordinateTransformIndex, + Index, + PandasIndex, + PandasMultiIndex, +) -__all__ = ["Index", "PandasIndex", "PandasMultiIndex"] +__all__ = ["CoordinateTransformIndex", "Index", "PandasIndex", "PandasMultiIndex"] From 8af6614086f8ca181ec070859fca1e019663c837 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:30:52 +0200 Subject: [PATCH 3/8] missing import --- xarray/core/indexes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e154b727fc5..b56d1faf295 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1397,6 +1397,8 @@ def __init__( def create_variables( self, variables: Mapping[Any, Variable] | None = None ) -> IndexVars: + from xarray.core.variable import Variable + new_variables = {} for name in self.transform.coord_names: From e9a11ef6df072c4f61eea7ea7be00e12d7cee5da Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 20:48:25 +0200 Subject: [PATCH 4/8] sel: convert inverse transform results to ints --- xarray/core/indexes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index b56d1faf295..ab725f86833 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,7 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) else: From 0b3fd9ee751f64b9695609a601ae31b336c1e0a0 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Tue, 24 Sep 2024 22:13:40 +0200 Subject: [PATCH 5/8] sel: add todo note about rounding decimal pos --- xarray/core/indexes.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index ab725f86833..987039e1f87 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1476,6 +1476,9 @@ def sel( results = {} dims0 = tuple(dim_size0) for dim, pos in dim_positions.items(): + # TODO: rounding the decimal positions is not always the behavior we expect + # (there are different ways to represent implicit intervals) + # we should probably make this customizable. pos = np.round(pos).astype("int") if isinstance(label0_obj, Variable): xr_pos = Variable(dims0, pos) From acf1c478c68fcadcc6bfbdd4414bc97b8667383f Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:37:04 +0200 Subject: [PATCH 6/8] rename create_coordinates -> create_coords More consistent with the rest of Xarray API where `coords` is used everywhere. --- xarray/core/indexes.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 987039e1f87..4f2bba20844 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1417,6 +1417,11 @@ def create_variables( return new_variables def create_coordinates(self) -> Coordinates: + # TODO: remove this alias before merging https://github.com/pydata/xarray/pull/9543! + # (we keep it there so it doesn't break the code of those who are experimenting with this) + return self.create_coords() + + def create_coords(self) -> Coordinates: # TODO: move this in xarray.Index base class? from xarray.core.coordinates import Coordinates From e101585e9fb30a3b73d6d37a7bc0be1607f991b1 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Thu, 26 Sep 2024 10:46:13 +0200 Subject: [PATCH 7/8] add a Coordinates.from_transform convenient method --- xarray/core/coordinates.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index a6dec863aec..af622aaca8b 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -14,7 +14,9 @@ from xarray.core import formatting from xarray.core.alignment import Aligner +from xarray.core.coordinate_transform import CoordinateTransform from xarray.core.indexes import ( + CoordinateTransformIndex, Index, Indexes, PandasIndex, @@ -356,7 +358,7 @@ def _construct_direct( def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: """Wrap a pandas multi-index as Xarray coordinates (dimension + levels). - The returned coordinates can be directly assigned to a + The returned coordinate variables can be directly assigned to a :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the ``coords`` argument of their constructor. @@ -380,6 +382,28 @@ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: Hashable) -> Self: return cls(coords=variables, indexes=indexes) + @classmethod + def from_transform(cls, transform: CoordinateTransform) -> Self: + """Wrap a coordinate transform as Xarray (lazy) coordinates. + + The returned coordinate variables can be directly assigned to a + :py:class:`~xarray.Dataset` or :py:class:`~xarray.DataArray` via the + ``coords`` argument of their constructor. + + Parameters + ---------- + transform : :py:class:`CoordinateTransform` + Xarray coordinate transform object. + + Returns + ------- + coords : Coordinates + A collection of Xarray indexed coordinates created from the transform. + + """ + index = CoordinateTransformIndex(transform) + return index.create_coords() + @property def _names(self) -> set[Hashable]: return self._data._coord_names From 09667c5da4e2de2f1db6896e3acce0205e3608e3 Mon Sep 17 00:00:00 2001 From: Benoit Bovy Date: Wed, 2 Oct 2024 14:50:55 +0200 Subject: [PATCH 8/8] fix repr (extract subset values of any n-d array) --- xarray/core/indexing.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 047e16c240d..04677bb8d60 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2055,8 +2055,12 @@ def _get_array_subset(self) -> np.ndarray: threshold = max(100, OPTIONS["display_values_threshold"] + 2) if self.size > threshold: pos = threshold // 2 - indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)]) - subset = self.vindex[VectorizedIndexer((indices,) * self.ndim)] + flat_indices = np.concatenate( + [np.arange(0, pos), np.arange(self.size - pos, self.size)] + ) + subset = self.vindex[ + VectorizedIndexer(np.unravel_index(flat_indices, self.shape)) + ] else: subset = self