Skip to content

Commit

Permalink
Refactor index vs. coordinate variable(s) (#5636)
Browse files Browse the repository at this point in the history
* split index / coordinate variable(s)

- Pass Variable objects to xarray.Index constructor
- The index should create IndexVariable objects (`coords` attribute)
- PandasIndex: IndexVariable wraps PandasIndexingAdpater wraps pd.Index

* one PandasIndexingAdapter subclass for multiindex

* fastpath Index init + from_pandas_index classmethods

* use classmethod constructors instead

* add Index.copy and Index.__getitem__ methods

* wip: clean-up

Revert some changes made in #5102 + additional (temporary) fixes.

* clean-up

* add PandasIndex and PandasMultiIndex tests

* remove unused import

* doc: update what's new

* use xindexes in map_blocks + temp fix

Dataset constructor doesn't accept xarray indexes yet. Create new
coordinates from the underlying pandas indexes.

* update what's new with #5670

* typo
  • Loading branch information
benbovy authored Aug 9, 2021
1 parent 08b3e80 commit 4bb9d9c
Show file tree
Hide file tree
Showing 17 changed files with 608 additions and 282 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`).
By `Benoit Bovy <https://github.com/benbovy>`_.
- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`)
By `Jimmy Westling <https://github.com/illviljan>`_.

Expand Down
46 changes: 30 additions & 16 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pandas as pd

from . import dtypes
from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index
from .indexes import Index, PandasIndex, get_indexer_nd
from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index
from .variable import IndexVariable, Variable

Expand Down Expand Up @@ -53,7 +53,10 @@ def _get_joiner(join, index_cls):
def _override_indexes(objects, all_indexes, exclude):
for dim, dim_indexes in all_indexes.items():
if dim not in exclude:
lengths = {index.size for index in dim_indexes}
lengths = {
getattr(index, "size", index.to_pandas_index().size)
for index in dim_indexes
}
if len(lengths) != 1:
raise ValueError(
f"Indexes along dimension {dim!r} don't have the same length."
Expand Down Expand Up @@ -300,16 +303,14 @@ def align(
joined_indexes = {}
for dim, matching_indexes in all_indexes.items():
if dim in indexes:
# TODO: benbovy - flexible indexes. maybe move this logic in util func
if isinstance(indexes[dim], Index):
index = indexes[dim]
else:
index = PandasIndex(safe_cast_to_index(indexes[dim]))
index, _ = PandasIndex.from_pandas_index(
safe_cast_to_index(indexes[dim]), dim
)
if (
any(not index.equals(other) for other in matching_indexes)
or dim in unlabeled_dim_sizes
):
joined_indexes[dim] = index
joined_indexes[dim] = indexes[dim]
else:
if (
any(
Expand All @@ -323,17 +324,18 @@ def align(
joiner = _get_joiner(join, type(matching_indexes[0]))
index = joiner(matching_indexes)
# make sure str coords are not cast to object
index = maybe_coerce_to_str(index, all_coords[dim])
index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim])
joined_indexes[dim] = index
else:
index = all_coords[dim][0]

if dim in unlabeled_dim_sizes:
unlabeled_sizes = unlabeled_dim_sizes[dim]
# TODO: benbovy - flexible indexes: expose a size property for xarray.Index?
# Some indexes may not have a defined size (e.g., built from multiple coords of
# different sizes)
labeled_size = index.size
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
if isinstance(index, PandasIndex):
labeled_size = index.to_pandas_index().size
else:
labeled_size = index.size
if len(unlabeled_sizes | {labeled_size}) > 1:
raise ValueError(
f"arguments without labels along dimension {dim!r} cannot be "
Expand All @@ -350,7 +352,14 @@ def align(

result = []
for obj in objects:
valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
valid_indexers = {}
for k, index in joined_indexes.items():
if k in obj.dims:
if isinstance(index, Index):
valid_indexers[k] = index.to_pandas_index()
else:
valid_indexers[k] = index
if not valid_indexers:
# fast path for no reindexing necessary
new_obj = obj.copy(deep=copy)
Expand Down Expand Up @@ -471,7 +480,11 @@ def reindex_like_indexers(
ValueError
If any dimensions without labels have different sizes.
"""
indexers = {k: v for k, v in other.xindexes.items() if k in target.dims}
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647
# this doesn't support yet indexes other than pd.Index
indexers = {
k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims
}

for dim in other.dims:
if dim not in indexers and dim in target.dims:
Expand Down Expand Up @@ -560,7 +573,8 @@ def reindex_variables(
"from that to be indexed along {:s}".format(str(indexer.dims), dim)
)

target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim]))
target = safe_cast_to_index(indexers[dim])
new_indexes[dim] = PandasIndex(target, dim)

if dim in indexes:
# TODO (benbovy - flexible indexes): support other indexes than pd.Index?
Expand Down
5 changes: 2 additions & 3 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets):
"inferring concatenation order"
)

# TODO (benbovy, flexible indexes): all indexes should be Pandas.Index
# get pd.Index objects from Index objects
indexes = [index.array for index in indexes]
# TODO (benbovy, flexible indexes): support flexible indexes?
indexes = [index.to_pandas_index() for index in indexes]

# If dimension coordinate values are same on every dataset then
# should be leaving this dimension alone (it's just a "bystander")
Expand Down
22 changes: 5 additions & 17 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@
)
from .dataset import Dataset, split_indexes
from .formatting import format_item
from .indexes import (
Index,
Indexes,
default_indexes,
propagate_indexes,
wrap_pandas_index,
)
from .indexes import Index, Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS, _get_keep_attrs
Expand Down Expand Up @@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
return self
coords = self._coords.copy()
for name, idx in indexes.items():
coords[name] = IndexVariable(name, idx)
coords[name] = IndexVariable(name, idx.to_pandas_index())
obj = self._replace(coords=coords)

# switch from dimension to level names, if necessary
dim_names: Dict[Any, str] = {}
for dim, idx in indexes.items():
# TODO: benbovy - flexible indexes: update when MultiIndex has its own class
pd_idx = idx.array
if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim:
pd_idx = idx.to_pandas_index()
if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
Expand Down Expand Up @@ -1046,12 +1039,7 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
if self._indexes is None:
indexes = self._indexes
else:
# TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index)
# xarray Index needs a copy method.
indexes = {
k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep))
for k, v in self._indexes.items()
}
indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()}
return self._replace(variable, coords, indexes=indexes)

def __copy__(self) -> "DataArray":
Expand Down
52 changes: 32 additions & 20 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
propagate_indexes,
remove_unused_levels_categories,
roll_index,
wrap_pandas_index,
)
from .indexing import is_fancy_indexer
from .merge import (
Expand Down Expand Up @@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset":
variables = self._variables.copy()
new_indexes = dict(self.xindexes)
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
variables[name] = IndexVariable(name, idx.to_pandas_index())
new_indexes[name] = idx
obj = self._replace(variables, indexes=new_indexes)

Expand Down Expand Up @@ -2474,6 +2473,10 @@ def sel(
pos_indexers, new_indexes = remap_label_indexers(
self, indexers=indexers, method=method, tolerance=tolerance
)
# TODO: benbovy - flexible indexes: also use variables returned by Index.query
# (temporary dirty fix).
new_indexes = {k: v[0] for k, v in new_indexes.items()}

result = self.isel(indexers=pos_indexers, drop=drop)
return result._overwrite_indexes(new_indexes)

Expand Down Expand Up @@ -3297,20 +3300,21 @@ def _rename_dims(self, name_dict):
return {name_dict.get(k, k): v for k, v in self.dims.items()}

def _rename_indexes(self, name_dict, dims_set):
# TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645
if self._indexes is None:
return None
indexes = {}
for k, v in self.xindexes.items():
# TODO: benbovy - flexible indexes: make it compatible with any xarray Index
index = v.to_pandas_index()
for k, v in self.indexes.items():
new_name = name_dict.get(k, k)
if new_name not in dims_set:
continue
if isinstance(index, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in index.names]
indexes[new_name] = PandasMultiIndex(index.rename(names=new_names))
if isinstance(v, pd.MultiIndex):
new_names = [name_dict.get(k, k) for k in v.names]
indexes[new_name] = PandasMultiIndex(
v.rename(names=new_names), new_name
)
else:
indexes[new_name] = PandasIndex(index.rename(new_name))
indexes[new_name] = PandasIndex(v.rename(new_name), new_name)
return indexes

def _rename_all(self, name_dict, dims_dict):
Expand Down Expand Up @@ -3539,7 +3543,10 @@ def swap_dims(
if new_index.nlevels == 1:
# make sure index name matches dimension name
new_index = new_index.rename(k)
indexes[k] = wrap_pandas_index(new_index)
if isinstance(new_index, pd.MultiIndex):
indexes[k] = PandasMultiIndex(new_index, k)
else:
indexes[k] = PandasIndex(new_index, k)
else:
var = v.to_base_variable()
var.dims = dims
Expand Down Expand Up @@ -3812,7 +3819,7 @@ def reorder_levels(
raise ValueError(f"coordinate {dim} has no MultiIndex")
new_index = index.reorder_levels(order)
variables[dim] = IndexVariable(coord.dims, new_index)
indexes[dim] = PandasMultiIndex(new_index)
indexes[dim] = PandasMultiIndex(new_index, dim)

return self._replace(variables, indexes=indexes)

Expand Down Expand Up @@ -3840,7 +3847,7 @@ def _stack_once(self, dims, new_dim):
coord_names = set(self._coord_names) - set(dims) | {new_dim}

indexes = {k: v for k, v in self.xindexes.items() if k not in dims}
indexes[new_dim] = wrap_pandas_index(idx)
indexes[new_dim] = PandasMultiIndex(idx, new_dim)

return self._replace_with_new_dims(
variables, coord_names=coord_names, indexes=indexes
Expand Down Expand Up @@ -4029,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
variables[name] = var

for name, lev in zip(index.names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(index.names)

Expand Down Expand Up @@ -4068,8 +4076,9 @@ def _unstack_full_reindex(
variables[name] = var

for name, lev in zip(new_dim_names, index.levels):
variables[name] = IndexVariable(name, lev)
indexes[name] = PandasIndex(lev)
idx, idx_vars = PandasIndex.from_pandas_index(lev, name)
variables[name] = idx_vars[name]
indexes[name] = idx

coord_names = set(self._coord_names) - {dim} | set(new_dim_names)

Expand Down Expand Up @@ -5839,10 +5848,13 @@ def diff(self, dim, n=1, label="upper"):

indexes = dict(self.xindexes)
if dim in indexes:
# TODO: benbovy - flexible indexes: check slicing of xarray indexes?
# or only allow this for pandas indexes?
index = indexes[dim].to_pandas_index()
indexes[dim] = PandasIndex(index[kwargs_new[dim]])
if isinstance(indexes[dim], PandasIndex):
# maybe optimize? (pandas index already indexed above with var.isel)
new_index = indexes[dim].index[kwargs_new[dim]]
if isinstance(new_index, pd.MultiIndex):
indexes[dim] = PandasMultiIndex(new_index, dim)
else:
indexes[dim] = PandasIndex(new_index, dim)

difference = self._replace_with_new_dims(variables, indexes=indexes)

Expand Down
Loading

0 comments on commit 4bb9d9c

Please sign in to comment.