Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alignment: allow flexible index coordinate order #8111

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ v2023.08.1 (unreleased)
New Features
~~~~~~~~~~~~

- It is now possible to provide custom, multi-coordinate Xarray indexes that can
be compared or aligned together regardless of the order of their coordinates.
Two "alignable" index objects must still be of the same type and have the same
set of coordinate names and dimensions. There is no change for
``PandasMultiIndex`` objects, though: they can be aligned only if their level
names and order match (:pull:`8111`).
By `Benoît Bovy <https://github.com/benbovy>`_.


Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def reindex_variables(
return new_variables


CoordNamesAndDims = tuple[tuple[Hashable, tuple[Hashable, ...]], ...]
MatchingIndexKey = tuple[CoordNamesAndDims, type[Index]]
SortedCoordNamesAndDims = tuple[tuple[Hashable, tuple[Hashable, ...]], ...]
MatchingIndexKey = tuple[SortedCoordNamesAndDims, type[Index]]
NormalizedIndexes = dict[MatchingIndexKey, Index]
NormalizedIndexVars = dict[MatchingIndexKey, dict[Hashable, Variable]]

Expand Down Expand Up @@ -227,6 +227,10 @@ def _normalize_indexes(
f"{incl_dims_str}"
)

# sort by coordinate name so that finding matching indexes
# doesn't rely on coordinate order
coord_names_and_dims.sort(key=lambda i: str(i[0]))

key = (tuple(coord_names_and_dims), type(idx))
normalized_indexes[key] = idx
normalized_index_vars[key] = index_vars
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,8 @@ def diff_coords_repr(a, b, compat, col_width=None):
"Coordinates",
summarize_variable,
col_width=col_width,
a_indexes=a.indexes,
b_indexes=b.indexes,
a_indexes=a.xindexes,
b_indexes=b.xindexes,
)


Expand Down
13 changes: 13 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,7 +1290,20 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult:
else:
return IndexSelResult({self.dim: indexer})

def equals(self, other: Index):
    """Compare this index with ``other``.

    The parent class comparison is evaluated first. When it succeeds and
    ``other`` is a ``PandasMultiIndex``, the level names of both wrapped
    indexes must additionally be equal (same names, same order).
    """
    parent_result = super().equals(other)
    if not parent_result:
        # Propagate the parent's (falsy) answer unchanged.
        return parent_result
    if isinstance(other, PandasMultiIndex):
        # Swapped level names must not be treated as equal.
        return self.index.names == other.index.names
    return parent_result

def join(self, other, how: str = "inner"):
if other.index.names != self.index.names:
raise ValueError(
f"cannot join together a PandasMultiIndex with levels {tuple(self.index.names)!r} and "
f"another PandasMultiIndex with levels {tuple(other.index.names)!r} "
"(level order mismatch)."
)

if how == "outer":
# bug in pandas? need to reset index.name
other_index = other.index.copy()
Expand Down
67 changes: 67 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,6 +2439,73 @@ def test_align_index_var_attrs(self, join) -> None:
assert ds.x.attrs == {"units": "m"}
assert ds_noattr.x.attrs == {}

def test_align_custom_index_no_coord_order(self) -> None:
    """Align datasets whose custom index was set from the same coordinates
    listed in a different order.
    """

    class CustomIndex(Index):
        """A meta-index wrapping a dict of PandasIndex objects where the
        order of the coordinates doesn't matter.
        """

        def __init__(self, indexes: dict[Hashable, PandasIndex]):
            self._indexes = indexes

        @classmethod
        def from_variables(cls, variables, *, options):
            # One PandasIndex per coordinate variable.
            per_coord = {}
            for name, var in variables.items():
                per_coord[name] = PandasIndex.from_variables(
                    {name: var}, options=options
                )
            return cls(per_coord)

        def create_variables(self, variables=None):
            source = {} if variables is None else variables
            created = {}
            for name, var in source.items():
                wrapped = self._indexes[name]
                created.update(wrapped.create_variables({name: var}))
            return created

        def equals(self, other: CustomIndex):
            # Evaluate every per-coordinate comparison before reducing.
            outcomes = []
            for name in self._indexes:
                outcomes.append(self._indexes[name].equals(other._indexes[name]))
            return all(outcomes)

        def join(self, other: CustomIndex, how="inner"):
            joined = {}
            for name in self._indexes:
                joined[name] = self._indexes[name].join(other._indexes[name], how=how)
            return CustomIndex(joined)

        def reindex_like(self, other, method=None, tolerance=None):
            indexers = {}
            for name, wrapped in self._indexes.items():
                partial = wrapped.reindex_like(
                    other._indexes[name], method=method, tolerance=tolerance
                )
                indexers.update(partial)
            return indexers

    def make_dataset(coords, index_coord_order):
        # Replace the default indexes with a single CustomIndex built from
        # the coordinates in the requested order.
        ds = Dataset(coords=coords)
        ds = ds.drop_indexes(["x", "y"])
        return ds.set_xindex(index_coord_order, CustomIndex)

    ds1 = make_dataset({"x": [1, 2], "y": [1, 2, 3, 4]}, ["x", "y"])
    ds2 = make_dataset({"y": [3, 4, 5, 6], "x": [1, 2]}, ["y", "x"])
    expected = make_dataset({"x": [1, 2], "y": [3, 4]}, ["x", "y"])

    actual1, actual2 = xr.align(ds1, ds2, join="inner")
    for actual in (actual1, actual2):
        assert_identical(actual, expected, check_default_indexes=False)

def test_broadcast(self) -> None:
ds = Dataset(
{"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])}
Expand Down
17 changes: 16 additions & 1 deletion xarray/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def test_sel(self) -> None:
with pytest.raises(IndexError):
index.sel({"x": (slice(None), 1, "no_level")})

def test_join(self):
def test_join(self) -> None:
midx = pd.MultiIndex.from_product([["a", "aa"], [1, 2]], names=("one", "two"))
level_coords_dtype = {"one": "=U2", "two": "i"}
index1 = PandasMultiIndex(midx, "x", level_coords_dtype=level_coords_dtype)
Expand All @@ -501,6 +501,21 @@ def test_join(self):
assert actual.equals(index1)
assert actual.level_coords_dtype == level_coords_dtype

def test_swap_index_levels(self) -> None:
    # when the order of level names doesn't match
    # - equals should return False
    # - join should fail
    # TODO: remove when fixed upstream
    def make_index(level_names):
        # Same level values each time; only the level names differ.
        midx = pd.MultiIndex.from_product(
            [["a", "b"], [0, 1]], names=level_names
        )
        return PandasMultiIndex(midx, "x")

    idx1 = make_index(("one", "two"))
    idx2 = make_index(("two", "one"))

    assert idx1.equals(idx2) is False

    with pytest.raises(ValueError, match=".*level order mismatch"):
        idx1.join(idx2)

def test_rename(self) -> None:
level_coords_dtype = {"one": "<U1", "two": np.int32}
index = PandasMultiIndex(
Expand Down