Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flox: Properly propagate multiindex #9649

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Bug fixes
the non-missing times could in theory be encoded with integers
(:issue:`9488`, :pull:`9497`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
- Fix a few bugs affecting groupby reductions with ``flox``. (:issue:`8090`, :issue:`9398`, :issue:`9648`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix the safe_chunks validation option on the to_zarr method
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak
Expand Down
11 changes: 11 additions & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes(
new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)

return new_coords


def _coordinates_from_variable(variable: Variable) -> Coordinates:
    """Build a ``Coordinates`` object from a single one-dimensional variable.

    A default index is created for ``variable`` and registered under every
    name returned alongside it by ``create_default_index_implicit``. The
    coordinate variables are then regenerated from that index, and the
    attributes of ``variable`` are assigned to the coordinate named after
    its dimension.

    Parameters
    ----------
    variable : Variable
        Must have exactly one dimension; the tuple unpacking below raises
        ``ValueError`` otherwise.

    Returns
    -------
    Coordinates
        The index-derived variables together with their shared index.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from xarray.core.indexes import create_default_index_implicit

    (dim_name,) = variable.dims
    default_index, level_vars = create_default_index_implicit(variable)
    # Every name handed back with the index points at the same index object.
    index_map = dict.fromkeys(level_vars, default_index)
    coord_vars = default_index.create_variables()
    # Freshly created variables do not carry the input's attributes.
    coord_vars[dim_name].attrs = variable.attrs
    return Coordinates(coord_vars, index_map)
40 changes: 17 additions & 23 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
PandasIndex,
PandasMultiIndex,
filter_indexes_from_coords,
)
from xarray.core.merge import merge_coords
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
Dims,
Expand Down Expand Up @@ -851,7 +851,6 @@ def _flox_reduce(
from flox.xarray import xarray_reduce

from xarray.core.dataset import Dataset
from xarray.groupers import BinGrouper

obj = self._original_obj
variables = (
Expand Down Expand Up @@ -901,13 +900,6 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

unindexed_dims: tuple[Hashable, ...] = tuple(
grouper.name
for grouper in self.groupers
if isinstance(grouper.group, _DummyGroup)
and not isinstance(grouper.grouper, BinGrouper)
)

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
Expand Down Expand Up @@ -963,26 +955,28 @@ def _flox_reduce(
# we did end up reducing over dimension(s) that are
# in the grouped variable
group_dims = set(grouper.group.dims)
new_coords = {}
new_coords = []
if group_dims.issubset(set(parsed_dim)):
new_indexes = {}
for grouper in self.groupers:
output_index = grouper.full_index
if isinstance(output_index, pd.RangeIndex):
# flox always assigns an index so we must drop it here if we don't need it.
result = result.drop_vars(grouper.name)
continue
name = grouper.name
new_coords[name] = IndexVariable(
dims=name, data=np.array(output_index), attrs=grouper.codes.attrs
new_coords.append(
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
# all associated levels properly.
_coordinates_from_variable(
IndexVariable(
dims=grouper.name,
data=output_index,
attrs=grouper.codes.attrs,
)
)
)
index_cls = (
PandasIndex
if not isinstance(output_index, pd.MultiIndex)
else PandasMultiIndex
)
new_indexes[name] = index_cls(output_index, dim=name)
result = result.assign_coords(
Coordinates(new_coords, new_indexes)
).drop_vars(unindexed_dims)
Coordinates._construct_direct(*merge_coords(new_coords))
)

# broadcast any non-dim coord variables that don't
# share all dimensions with the grouper
Expand Down
13 changes: 1 addition & 12 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
Expand All @@ -42,17 +42,6 @@
RESAMPLE_DIM = "__resample_dim__"


def _coordinates_from_variable(variable: Variable) -> Coordinates:
    """Build a ``Coordinates`` object from a single one-dimensional variable.

    A default index is created for ``variable`` and registered under every
    name returned alongside it by ``create_default_index_implicit``. The
    coordinate variables are then regenerated from that index, and the
    attributes of ``variable`` are assigned to the coordinate named after
    its dimension.

    Parameters
    ----------
    variable : Variable
        Must have exactly one dimension; the tuple unpacking below raises
        ``ValueError`` otherwise.

    Returns
    -------
    Coordinates
        The index-derived variables together with their shared index.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from xarray.core.indexes import create_default_index_implicit

    (dim_name,) = variable.dims
    default_index, level_vars = create_default_index_implicit(variable)
    # Every name handed back with the index points at the same index object.
    index_map = dict.fromkeys(level_vars, default_index)
    coord_vars = default_index.create_variables()
    # Freshly created variables do not carry the input's attributes.
    coord_vars[dim_name].attrs = variable.attrs
    return Coordinates(coord_vars, index_map)


@dataclass(init=False)
class EncodedGroups:
"""
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@
assert_equal(expected, actual)


def test_multi_index_propagation():
# regression test for GH9648
times = pd.date_range("2023-01-01", periods=4)
locations = ["A", "B"]
data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]]

da = xr.DataArray(
data, dims=["time", "location"], coords={"time": times, "location": locations}
)
da = da.stack(multiindex=["time", "location"])
grouped = da.groupby("multiindex")

with xr.set_options(use_flox=True):
actual = grouped.sum()

Check failure on line 161 in xarray/tests/test_groupby.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 min-all-deps

test_multi_index_propagation ValueError: Buffer dtype mismatch, expected 'Python object' but got 'long'
with xr.set_options(use_flox=False):
expected = grouped.first()
assert_identical(actual, expected)


def test_groupby_da_datetime() -> None:
# test groupby with a DataArray of dtype datetime for GH1132
# create test data
Expand Down
Loading