Skip to content

Commit

Permalink
Set pd.MultiIndex name from dim name
Browse files Browse the repository at this point in the history
  • Loading branch information
benbovy committed Sep 15, 2021
1 parent e50978e commit 8b1e4d5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
5 changes: 4 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None):
def _replace(self, index, dim=None, level_coords_dtype=None) -> "PandasMultiIndex":
if dim is None:
dim = self.dim
index.name = dim
if level_coords_dtype is None:
level_coords_dtype = self.level_coords_dtype
return type(self)(index, dim, level_coords_dtype)
Expand All @@ -410,6 +411,7 @@ def from_variables(
index = pd.MultiIndex.from_arrays(
[var.values for var in variables.values()], names=variables.keys()
)
index.name = dim
level_coords_dtype = {name: var.dtype for name, var in variables.items()}
obj = cls(index, dim, level_coords_dtype=level_coords_dtype)

Expand Down Expand Up @@ -529,6 +531,7 @@ def from_pandas_index(
level_coords_dtype = {k: var_meta[k]["dtype"] for k in names}

index = index.rename(names)
index.name = dim
index_vars = _create_variables_from_multiindex(index, dim, var_meta=var_meta)
return cls(index, dim, level_coords_dtype=level_coords_dtype), index_vars

Expand Down Expand Up @@ -628,7 +631,7 @@ def query(self, labels, method=None, tolerance=None) -> QueryResult:
var_meta = {k: {"dtype": v} for k, v in self.level_coords_dtype.items()}

if isinstance(new_index, pd.MultiIndex):
new_index, new_vars = PandasMultiIndex.from_pandas_index(
new_index, new_vars = self.from_pandas_index(
new_index, self.dim, var_meta=var_meta
)
dims_dict = {}
Expand Down
6 changes: 4 additions & 2 deletions xarray/tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_from_variables(self) -> None:
expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data])
assert index.dim == "x"
assert index.index.equals(expected_idx)
assert index.index.name == "x"

assert list(index_vars) == ["x", "level1", "level2"]
xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"])
Expand All @@ -200,6 +201,7 @@ def test_from_pandas_index(self) -> None:
assert index.dim == "x"
assert index.index.equals(pd_idx)
assert index.index.names == ("foo", "bar")
assert index.index.name == "x"
xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx))
xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", foo_data))
xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", bar_data))
Expand Down Expand Up @@ -231,7 +233,7 @@ def test_query(self) -> None:
index.query({"x": (slice(None), 1, "no_level")})

def test_rename(self) -> None:
level_coords_dtype = {"one": "U<1", "two": np.int32}
level_coords_dtype = {"one": "<U1", "two": np.int32}
index = PandasMultiIndex(
pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")),
"x",
Expand All @@ -246,7 +248,7 @@ def test_rename(self) -> None:
new_index, index_vars = index.rename({"two": "three"}, {})
assert new_index.index.names == ["one", "three"]
assert new_index.dim == "x"
assert new_index.level_coords_dtype == {"one": "U<1", "three": np.int32}
assert new_index.level_coords_dtype == {"one": "<U1", "three": np.int32}
assert list(index_vars.keys()) == ["x", "one", "three"]
for v in index_vars.values():
assert v.dims == ("x",)
Expand Down

0 comments on commit 8b1e4d5

Please sign in to comment.