Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flexible coordinate transform #9543

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
where,
)
from xarray.core.concat import concat
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
Expand Down Expand Up @@ -109,6 +110,7 @@
"CFTimeIndex",
"Context",
"Coordinates",
"CoordinateTransform",
"DataArray",
"Dataset",
"DataTree",
Expand Down
78 changes: 78 additions & 0 deletions xarray/core/coordinate_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from collections.abc import Hashable, Iterable, Mapping
from typing import Any

import numpy as np


class CoordinateTransform:
"""Abstract coordinate transform with dimension & coordinate names."""

coord_names: tuple[Hashable, ...]
dims: tuple[str, ...]
dim_size: dict[str, int]
dtype: Any

def __init__(
self,
coord_names: Iterable[Hashable],
dim_size: Mapping[str, int],
dtype: Any = None,
):
self.coord_names = tuple(coord_names)
self.dims = tuple(dim_size)
self.dim_size = dict(dim_size)

if dtype is None:
dtype = np.dtype(np.float64)
self.dtype = dtype

def forward(self, dim_positions: dict[str, Any]) -> dict[Hashable, Any]:
"""Perform grid -> world coordinate transformation.

Parameters
----------
dim_positions : dict
Grid location(s) along each dimension (axis).

Returns
-------
coord_labels : dict
World coordinate labels.

"""
# TODO: cache the results in order to avoid re-computing
# all labels when accessing the values of each coordinate one at a time
raise NotImplementedError

def reverse(self, coord_labels: dict[Hashable, Any]) -> dict[str, Any]:
"""Perform world -> grid coordinate reverse transformation.

Parameters
----------
labels : dict
World coordinate labels.

Returns
-------
dim_positions : dict
Grid relative location(s) along each dimension (axis).

"""
raise NotImplementedError

def equals(self, other: "CoordinateTransform") -> bool:
"""Check equality with another CoordinateTransform of the same kind."""
raise NotImplementedError

def generate_coords(self, dims: tuple[str] | None = None) -> dict[Hashable, Any]:
"""Compute all coordinate labels at once."""
if dims is None:
dims = self.dims

positions = np.meshgrid(
*[np.arange(self.dim_size[d]) for d in dims],
indexing="ij",
)
dim_positions = {dim: positions[i] for i, dim in enumerate(dims)}

return self.forward(dim_positions)
129 changes: 129 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import pandas as pd

from xarray.core import formatting, nputils, utils
from xarray.core.coordinate_transform import CoordinateTransform
from xarray.core.indexing import (
CoordinateTransformIndexingAdapter,
IndexSelResult,
PandasIndexingAdapter,
PandasMultiIndexingAdapter,
Expand All @@ -24,6 +26,7 @@
)

if TYPE_CHECKING:
from xarray.core.coordinate import Coordinates
from xarray.core.types import ErrorOptions, JoinOptions, Self
from xarray.core.variable import Variable

Expand Down Expand Up @@ -1372,6 +1375,132 @@ def rename(self, name_dict, dims_dict):
)


class CoordinateTransformIndex(Index):
"""Helper class for creating Xarray indexes based on coordinate transforms.

- wraps a :py:class:`CoordinateTransform` instance
- takes care of creating the index (lazy) coordinates
- supports point-wise label-based selection
- supports exact alignment only, by comparing indexes based on their transform
(not on their explicit coordinate labels)

"""

transform: CoordinateTransform

def __init__(
self,
transform: CoordinateTransform,
):
self.transform = transform

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
) -> IndexVars:
from xarray.core.variable import Variable

new_variables = {}

for name in self.transform.coord_names:
# copy attributes, if any
attrs: Mapping[Hashable, Any] | None

if variables is not None and name in variables:
var = variables[name]
attrs = var.attrs
else:
attrs = None

data = CoordinateTransformIndexingAdapter(self.transform, name)
new_variables[name] = Variable(self.transform.dims, data, attrs=attrs)

return new_variables

def create_coordinates(self) -> Coordinates:
# TODO: move this in xarray.Index base class?
from xarray.core.coordinates import Coordinates

variables = self.create_variables()
indexes = {name: self for name in variables}
return Coordinates(coords=variables, indexes=indexes)

def isel(
self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
) -> Self | None:
# TODO: support returning a new index (e.g., possible to re-calculate the
# the transform or calculate another transform on a reduced dimension space)
return None

def sel(
self, labels: dict[Any, Any], method=None, tolerance=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to support tolerance in some form? This is a common and useful form of error checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty tricky to support it here I think, probably better to handle it on a per case basis.

For basic transformations I guess it could be possible to calculate a single, uniform tolerance value in decimal array index units and validate the selected elements using those units (cheap). In other cases we would need to compute the forward transformation of the extracted array indices and then validate the selected elements based on distances in physical units (more expensive).

Also, there may be cases where the coordinates of a same transform object don’t have all the same physical units (e.g., both degrees and radians coordinates in an Astropy WCS object). Unless we forbid that in xarray.CoordinateTransform, it doesn’t make much sense to pass a single tolerance value. Passing a dictionary tolerance={coord_name: value} doesn’t look very nice either IMO. A {unit: value} dict looks better but adding explicit support for units here might be opening a can of worms.

) -> IndexSelResult:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable

if method != "nearest":
raise ValueError(
"CoordinateTransformIndex only supports selection with method='nearest'"
)

labels_set = set(labels)
coord_names_set = set(self.transform.coord_names)

missing_labels = coord_names_set - labels_set
if missing_labels:
missing_labels_str = ",".join([f"{name}" for name in missing_labels])
raise ValueError(f"missing labels for coordinate(s): {missing_labels_str}.")

label0_obj = next(iter(labels.values()))
dim_size0 = getattr(label0_obj, "sizes", None)

is_xr_obj = [
isinstance(label, DataArray | Variable) for label in labels.values()
]
if not all(is_xr_obj):
raise TypeError(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
"with either xarray.DataArray or xarray.Variable objects."
)
dim_size = [getattr(label, "sizes", None) for label in labels.values()]
if any([ds != dim_size0 for ds in dim_size]):
raise ValueError(
"CoordinateTransformIndex only supports advanced (point-wise) indexing "
"with xarray.DataArray or xarray.Variable objects of macthing dimensions."
)

coord_labels = {
name: labels[name].values for name in self.transform.coord_names
}
dim_positions = self.transform.reverse(coord_labels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to guarantee that out of bounds indexing raises an informative error, rather than silently attempting to access invalid data (or indexing from the end instead of the start of arrays).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still need to add unit tests but bound checking is done when indexing the coordinate variables in CoordinateTransformIndexingAdapter.


results = {}
dims0 = tuple(dim_size0)
for dim, pos in dim_positions.items():
# TODO: rounding the decimal positions is not always the behavior we expect
# (there are different ways to represent implicit intervals)
# we should probably make this customizable.
pos = np.round(pos).astype("int")
Comment on lines +1484 to +1487
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is important I think.

If the coordinates values correspond to the physical values at the top/left pixel corners in the 2D case, we may rather want np.floor(pos).astype("int") when converting decimal positions (obtained by inverse transformation) to integer indexers.

if isinstance(label0_obj, Variable):
xr_pos = Variable(dims0, pos)
else:
# dataarray
xr_pos = DataArray(pos, dims=dims0)
results[dim] = xr_pos

return IndexSelResult(results)

def equals(self, other: Self) -> bool:
return self.transform.equals(other.transform)

def rename(
self,
name_dict: Mapping[Any, Hashable],
dims_dict: Mapping[Any, Hashable],
) -> Self:
# TODO: maybe update self.transform coord_names, dim_size and dims attributes
return self


def create_default_index_implicit(
dim_variable: Variable,
all_variables: Mapping | Iterable[Hashable] | None = None,
Expand Down
Loading
Loading