Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MultiPandasIndex helper class #7182

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xarray/indexes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@

"""
from ..core.indexes import Index, PandasIndex, PandasMultiIndex
from .multipandasindex import MultiPandasIndex

__all__ = ["Index", "PandasIndex", "PandasMultiIndex"]
__all__ = ["Index", "MultiPandasIndex", "PandasIndex", "PandasMultiIndex"]
175 changes: 175 additions & 0 deletions xarray/indexes/multipandasindex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

from typing import Any, Hashable, Mapping, TypeVar

import numpy as np

from ..core.indexes import Index, IndexVars, PandasIndex
from ..core.indexing import IndexSelResult, merge_sel_results
from ..core.utils import Frozen
from ..core.variable import Variable

T_MultiPandasIndex = TypeVar("T_MultiPandasIndex", bound="MultiPandasIndex")


class MultiPandasIndex(Index):
"""Helper class to implement meta-indexes encapsulating
one or more (single) pandas indexes.

Each pandas index must relate to a separate dimension.

This class shoudn't be instantiated directly.

"""

indexes: Frozen[Hashable, PandasIndex]
dims: Frozen[Hashable, int]

__slots__ = ("indexes", "dims")

def __init__(self, indexes: Mapping[Hashable, PandasIndex]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Always better to use Any for Mapping inputs, due to covariance.
Use Mapping[Any, PandasIndex]

dims = {idx.dim: idx.index.size for idx in indexes.values()}

seen = set()
dup_dims = []
for d in dims:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since dims is a dict, shouldn't the keys be unique? I don't think you will have repetitions here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, sure 🙂. But we should probably check for duplicate dimensions so a dict is not what we want for dims.

Copy link
Collaborator

@headtr1ck headtr1ck Oct 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could just do

dim_names = [idx.dim for idk in indexes]
if len(dim_names) != len(set(dim_names)): raise...

if d in seen:
dup_dims.append(d)
else:
seen.add(d)

if dup_dims:
raise ValueError(
f"cannot create a {self.__class__.__name__} from coordinates "
f"sharing common dimension(s): {dup_dims}"
)

self.indexes = Frozen(indexes)
self.dims = Frozen(dims)

@classmethod
def from_variables(
cls: type[T_MultiPandasIndex], variables: Mapping[Any, Variable], options
) -> T_MultiPandasIndex:
indexes = {
k: PandasIndex.from_variables({k: v}, options={})
for k, v in variables.items()
}

return cls(indexes)

def create_variables(
self, variables: Mapping[Any, Variable] | None = None
) -> IndexVars:

idx_variables = {}

for idx in self.indexes.values():
idx_variables.update(idx.create_variables(variables))

return idx_variables

def isel(
self: T_MultiPandasIndex,
indexers: Mapping[Any, int | slice | np.ndarray | Variable],
) -> T_MultiPandasIndex | PandasIndex | None:
new_indexes = {}

for k, idx in self.indexes.items():
if k in indexers:
new_idx = idx.isel({k: indexers[k]})
if new_idx is not None:
new_indexes[k] = new_idx
else:
new_indexes[k] = idx

#
# How should we deal with dropped index(es) (scalar selection)?
# - drop the whole index?
# - always return a MultiPandasIndex with remaining index(es)?
# - return either a MultiPandasIndex or a PandasIndex?
#

if not len(new_indexes):
return None
elif len(new_indexes) == 1:
return next(iter(new_indexes.values()))
else:
return type(self)(new_indexes)

def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Mapping[Any, Any]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method should be called internally in Xarray (always with a dict), but maybe there are some cases where one wants to call it directly so Mapping[Any, Any] would be better..

results: list[IndexSelResult] = []

for k, idx in self.indexes.items():
if k in labels:
results.append(idx.sel({k: labels[k]}, **kwargs))

return merge_sel_results(results)

def _get_unmatched_names(
self: T_MultiPandasIndex, other: T_MultiPandasIndex
) -> set:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use set[Hashable]

return set(self.indexes).symmetric_difference(other.indexes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always find the operator notation easier to read: set(self.indexes) ^ set(other.indexes)


def equals(self: T_MultiPandasIndex, other: T_MultiPandasIndex) -> bool:
# We probably don't need to check for matching coordinate names
# as this is already done during alignment when finding matching indexes.
# This may change in the future, though.
# see https://github.com/pydata/xarray/issues/7002
if self._get_unmatched_names(other):
return False
else:
return all(
[idx.equals(other.indexes[k]) for k, idx in self.indexes.items()]
)

def join(
self: T_MultiPandasIndex, other: T_MultiPandasIndex, how: str = "inner"
) -> T_MultiPandasIndex:
new_indexes = {}

for k, idx in self.indexes.items():
new_indexes[k] = idx.join(other.indexes[k], how=how)

return type(self)(new_indexes)

def reindex_like(
self: T_MultiPandasIndex, other: T_MultiPandasIndex
) -> dict[Hashable, Any]:
dim_indexers = {}

for k, idx in self.indexes.items():
dim_indexers.update(idx.reindex_like(other.indexes[k]))

return dim_indexers

def roll(self: T_MultiPandasIndex, shifts: Mapping[Any, int]) -> T_MultiPandasIndex:
new_indexes = {}

for k, idx in self.indexes.items():
if k in shifts:
new_indexes[k] = idx.roll({k: shifts[k]})
else:
new_indexes[k] = idx

return type(self)(new_indexes)

def rename(
self: T_MultiPandasIndex,
name_dict: Mapping[Any, Hashable],
dims_dict: Mapping[Any, Hashable],
) -> T_MultiPandasIndex:
new_indexes = {}

for k, idx in self.indexes.items():
new_indexes[k] = idx.rename(name_dict, dims_dict)

return type(self)(new_indexes)

def copy(self: T_MultiPandasIndex, deep: bool = True) -> T_MultiPandasIndex:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be updated following #7140.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might as well implement the correct deep-copy behavior which passes the memo dict.

new_indexes = {}

for k, idx in self.indexes.items():
new_indexes[k] = idx.copy(deep=deep)

return type(self)(new_indexes)