Skip to content

Commit

Permalink
refactor reorder_levels
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy committed Sep 14, 2021
1 parent a891e22 commit e50978e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 48 deletions.
13 changes: 2 additions & 11 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,17 +2042,8 @@ def reorder_levels(
Another dataarray, with this dataarray's data but replaced
coordinates.
"""
dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels")
replace_coords = {}
for dim, order in dim_order.items():
coord = self._coords[dim]
index = coord.to_index()
if not isinstance(index, pd.MultiIndex):
raise ValueError(f"coordinate {dim!r} has no MultiIndex")
replace_coords[dim] = IndexVariable(coord.dims, index.reorder_levels(order))
coords = self._coords.copy()
coords.update(replace_coords)
return self._replace(coords=coords)
ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs)
return self._from_temp_dataset(ds)

def stack(
self,
Expand Down
31 changes: 20 additions & 11 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3638,11 +3638,11 @@ def set_index(
current_coord_names = index_coord_names.get(dim, [])

# drop any pre-existing index involved
maybe_drop_indexes.extend(current_coord_names + var_names)
maybe_drop_indexes += current_coord_names + var_names
for k in var_names:
maybe_drop_indexes.extend(index_coord_names.get(k, []))
maybe_drop_indexes += index_coord_names.get(k, [])

drop_variables.extend(var_names)
drop_variables += var_names

if len(var_names) == 1 and (not append or dim not in self.xindexes):
var_name = var_names[0]
Expand Down Expand Up @@ -3800,16 +3800,25 @@ def reorder_levels(
dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels")
variables = self._variables.copy()
indexes = dict(self.xindexes)
new_indexes: Dict[Hashable, Index] = {}
new_variables: Dict[Hashable, IndexVariable] = {}

for dim, order in dim_order.items():
coord = self._variables[dim]
# TODO: benbovy - flexible indexes: update when MultiIndex
# has its own class inherited from xarray.Index
index = self.xindexes[dim].to_pandas_index()
if not isinstance(index, pd.MultiIndex):
index = self.xindexes[dim]

if not isinstance(index, PandasMultiIndex):
raise ValueError(f"coordinate {dim} has no MultiIndex")
new_index = index.reorder_levels(order)
variables[dim] = IndexVariable(coord.dims, new_index)
indexes[dim] = PandasMultiIndex(new_index, dim)

idx, idx_vars = index.reorder_levels({k: self._variables[k] for k in order})

new_variables.update(idx_vars)
new_indexes.update({k: idx for k in idx_vars})

indexes = {k: v for k, v in self.xindexes.items() if k not in new_indexes}
indexes.update(new_indexes)

variables = {k: v for k, v in self._variables.items() if k not in new_variables}
variables.update(new_variables)

return self._replace(variables, indexes=indexes)

Expand Down
56 changes: 30 additions & 26 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ def _check_dim_compat(variables: Mapping[Any, "Variable"]) -> Hashable:
return next(iter(dims))[0]


def _get_var_metadata(variables: Mapping[Any, "Variable"]) -> Dict[Any, Dict[str, Any]]:
return {
name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
for name, var in variables.items()
}


def _create_variables_from_multiindex(index, dim, var_meta=None):
from .variable import IndexVariable

Expand Down Expand Up @@ -406,11 +413,9 @@ def from_variables(
level_coords_dtype = {name: var.dtype for name, var in variables.items()}
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)

var_meta = {
name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding}
for name, var in variables.items()
}
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
index_vars = _create_variables_from_multiindex(
index, dim, var_meta=_get_var_metadata(variables)
)

return obj, index_vars

Expand All @@ -429,17 +434,10 @@ def from_variables_maybe_expand(
names: List[Hashable] = []
codes: List[List[int]] = []
levels: List[List[int]] = []
var_meta: Dict[str, Dict] = {}
level_variables: Dict[Any, "Variable"] = {}

_check_dim_compat({**current_variables, **variables})

def add_level_var(name, var):
var_meta[name] = {
"dtype": var.dtype,
"attrs": var.attrs,
"encoding": var.encoding,
}

if len(current_variables) > 1:
# expand from an existing multi-index
data = cast(
Expand All @@ -450,7 +448,7 @@ def add_level_var(name, var):
codes.extend(current_index.codes)
levels.extend(current_index.levels)
for name in current_index.names:
add_level_var(name, current_variables[name])
level_variables[name] = current_variables[name]

elif len(current_variables) == 1:
# expand from one 1D variable (no multi-index): convert it to an index level
Expand All @@ -460,18 +458,20 @@ def add_level_var(name, var):
cat = pd.Categorical(var.values, ordered=True)
codes.append(cat.codes)
levels.append(cat.categories)
add_level_var(new_var_name, var)
level_variables[new_var_name] = var

for name, var in variables.items():
names.append(name)
cat = pd.Categorical(var.values, ordered=True)
codes.append(cat.codes)
levels.append(cat.categories)
add_level_var(name, var)
level_variables[name] = var

index = pd.MultiIndex(levels, codes, names=names)

return cls.from_pandas_index(index, dim, var_meta=var_meta)
return cls.from_pandas_index(
index, dim, var_meta=_get_var_metadata(level_variables)
)

def keep_levels(
self, level_variables: Mapping[Any, "Variable"]
Expand All @@ -480,15 +480,7 @@ def keep_levels(
corresponding coordinates.
"""
var_meta: Dict[str, Dict] = {}

for name, var in level_variables.items():
var_meta[name] = {
"dtype": var.dtype,
"attrs": var.attrs,
"encoding": var.encoding,
}

var_meta = _get_var_metadata(level_variables)
index = self.index.droplevel(
[k for k in self.index.names if k not in level_variables]
)
Expand All @@ -498,6 +490,18 @@ def keep_levels(
else:
return PandasIndex.from_pandas_index(index, self.dim, var_meta=var_meta)

def reorder_levels(
    self, level_variables: Mapping[Any, "Variable"]
) -> Tuple["PandasMultiIndex", IndexVars]:
    """Return a new multi-index whose levels follow the given order.

    Parameters
    ----------
    level_variables : mapping
        Level coordinate variables keyed by level name. The iteration order
        of the keys is the new level order passed to the wrapped index's
        ``reorder_levels``.

    Returns
    -------
    The result of ``from_pandas_index`` called with the re-ordered index,
    this index's dimension, and the metadata (dtype, attrs, encoding) taken
    from ``level_variables``: a new multi-index together with its
    corresponding coordinate variables.
    """
    new_order = level_variables.keys()
    reordered = self.index.reorder_levels(new_order)
    level_meta = _get_var_metadata(level_variables)
    return self.from_pandas_index(reordered, self.dim, var_meta=level_meta)

@classmethod
def from_pandas_index(
cls,
Expand Down

0 comments on commit e50978e

Please sign in to comment.