groupby: remove some internal use of IndexVariable (#9123)
* Remove internal use of IndexVariable

* cleanup

* cleanup more

* cleanup
dcherian authored Jun 21, 2024
1 parent af722f0 commit 2645d7f
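
For context, a minimal sketch of the distinction this refactor leans on (illustrative, not part of the diff): `IndexVariable` eagerly wraps its data in a `pandas.Index`, while a plain `Variable` just stores the array and leaves index creation as a separate, explicit step.

```python
import numpy as np
import xarray as xr

values = np.array(["a", "b", "c"])

# Old pattern: IndexVariable builds a pandas.Index at construction time.
iv = xr.IndexVariable("letters", values)

# New pattern: a plain Variable stores the data as-is; any pandas.Index
# is created explicitly, and only where it is actually needed.
v = xr.Variable(dims="letters", data=values)
print(type(iv).__name__, type(v).__name__)  # IndexVariable Variable
```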
Showing 2 changed files with 67 additions and 33 deletions.
63 changes: 41 additions & 22 deletions xarray/core/groupby.py
@@ -19,8 +19,10 @@
 from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
 from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
 from xarray.core.concat import concat
+from xarray.core.coordinates import Coordinates
 from xarray.core.formatting import format_array_flat
 from xarray.core.indexes import (
+    PandasIndex,
     create_default_index_implicit,
     filter_indexes_from_coords,
 )
@@ -246,7 +248,7 @@ def to_array(self) -> DataArray:
         return self.to_dataarray()


-T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]
+T_Group = Union["T_DataArray", _DummyGroup]


 def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
@@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
     list[Hashable],
 ]:
     # 1D cases: do nothing
-    if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
+    if isinstance(group, _DummyGroup) or group.ndim == 1:
         return group, obj, None, []

     from xarray.core.dataarray import DataArray
@@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
         newobj = obj.stack({stacked_dim: orig_dims})
         return newgroup, newobj, stacked_dim, inserted_dims

-    raise TypeError(
-        f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
-    )
+    raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.")


 @dataclass
@@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
     codes: DataArray = field(init=False)
     full_index: pd.Index = field(init=False)
     group_indices: T_GroupIndices = field(init=False)
-    unique_coord: IndexVariable | _DummyGroup = field(init=False)
+    unique_coord: Variable | _DummyGroup = field(init=False)

     # _ensure_1d:
     group1d: T_Group = field(init=False)
@@ -315,7 +315,7 @@ def __post_init__(self) -> None:
         # might be used multiple times.
         self.grouper = copy.deepcopy(self.grouper)

-        self.group: T_Group = _resolve_group(self.obj, self.group)
+        self.group = _resolve_group(self.obj, self.group)

         (
             self.group1d,
@@ -328,14 +328,18 @@ def __post_init__(self) -> None:

     @property
     def name(self) -> Hashable:
+        """Name for the grouped coordinate after reduction."""
         # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
-        return self.unique_coord.name
+        (name,) = self.unique_coord.dims
+        return name

     @property
     def size(self) -> int:
+        """Number of groups."""
         return len(self)

     def __len__(self) -> int:
+        """Number of groups."""
         return len(self.full_index)

     @property
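
The `_bins` suffix mentioned in the comment comes from binned grouping, where the reduced coordinate is not named after the original variable, so the name is now read off `unique_coord.dims`. A short illustration of that naming (standard `groupby_bins` behavior, not part of the diff):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="x", coords={"x": [0.1, 0.4, 0.6, 0.9]})
binned = da.groupby_bins("x", bins=[0.0, 0.5, 1.0]).mean()
# The reduced coordinate is named "x_bins", not "x", so the grouper's name
# must come from unique_coord rather than from the original group variable.
print(binned.coords)
```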
@@ -358,8 +362,8 @@ def factorize(self) -> None:
             ]
         if encoded.unique_coord is None:
             unique_values = self.full_index[np.unique(encoded.codes)]
-            self.unique_coord = IndexVariable(
-                self.codes.name, unique_values, attrs=self.group.attrs
+            self.unique_coord = Variable(
+                dims=self.codes.name, data=unique_values, attrs=self.group.attrs
             )
         else:
             self.unique_coord = encoded.unique_coord
@@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None:
         )


-def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group:
+def _resolve_group(
+    obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable
+) -> T_Group:
     from xarray.core.dataarray import DataArray

     error_msg = (
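
`_resolve_group` normalizes the two forms users pass to `.groupby`: a hashable coordinate name is looked up on the object, while a DataArray is used directly. A sketch of the two equivalent call forms it serves (illustrative):

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [1.0, 2.0, 3.0])}, coords={"x": [10, 10, 20]})
ds.groupby("x").mean()      # group passed as a hashable name, resolved on ds
ds.groupby(ds["x"]).mean()  # group passed as a DataArray, used directly
```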
@@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
             yield self._obj.isel({self._group_dim: indices})

     def _infer_concat_args(self, applied_example):
+        from xarray.core.groupers import BinGrouper
+
         (grouper,) = self.groupers
         if self._group_dim in applied_example.dims:
             coord = grouper.group1d
@@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example):
             coord = grouper.unique_coord
             positions = None
         (dim,) = coord.dims
-        if isinstance(coord, _DummyGroup):
+        if isinstance(grouper.group, _DummyGroup) and not isinstance(
+            grouper.grouper, BinGrouper
+        ):
+            # When binning we actually do set the index
             coord = None
         coord = getattr(coord, "variable", coord)
         return coord, dim, positions
@@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False):

         (grouper,) = self.groupers
         obj = self._original_obj
+        name = grouper.name
         group = grouper.group
         codes = self._codes
         dims = group.dims
@@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False):
             group = coord = group.to_dataarray()
         else:
             coord = grouper.unique_coord
-            if not isinstance(coord, DataArray):
-                coord = DataArray(grouper.unique_coord)
-        name = grouper.name
+            if isinstance(coord, Variable):
+                assert coord.ndim == 1
+                (coord_dim,) = coord.dims
+                # TODO: explicitly create Index here
+                coord = DataArray(coord, coords={coord_dim: coord.data})

         if not isinstance(other, (Dataset, DataArray)):
             raise TypeError(
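
Because `unique_coord` is now a plain `Variable`, `_binary_op` rebuilds an indexed DataArray from it by passing the variable's own data as a coordinate for its single dimension. A minimal sketch of that construction (names are illustrative):

```python
import xarray as xr

coord = xr.Variable(dims="labels", data=["a", "b"])
# Supplying the data as a coordinate for its own dimension makes the
# dimension indexed -- what the old IndexVariable provided implicitly.
da = xr.DataArray(coord, coords={"labels": coord.data})
print(da.indexes["labels"])
```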
@@ -766,6 +780,7 @@ def _flox_reduce(

         obj = self._original_obj
         (grouper,) = self.groupers
+        name = grouper.name
         isbin = isinstance(grouper.grouper, BinGrouper)

         if keep_attrs is None:
@@ -797,14 +812,14 @@ def _flox_reduce(
         # weird backcompat
         # reducing along a unique indexed dimension with squeeze=True
         # should raise an error
-        if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes:
-            index = obj.indexes[grouper.name]
+        if (dim is None or dim == name) and name in obj.xindexes:
+            index = obj.indexes[name]
             if index.is_unique and self._squeeze:
-                raise ValueError(f"cannot reduce over dimensions {grouper.name!r}")
+                raise ValueError(f"cannot reduce over dimensions {name!r}")

         unindexed_dims: tuple[Hashable, ...] = tuple()
         if isinstance(grouper.group, _DummyGroup) and not isbin:
-            unindexed_dims = (grouper.name,)
+            unindexed_dims = (name,)

         parsed_dim: tuple[Hashable, ...]
         if isinstance(dim, str):
@@ -848,15 +863,19 @@ def _flox_reduce(
             # in the grouped variable
             group_dims = grouper.group.dims
             if set(group_dims).issubset(set(parsed_dim)):
-                result[grouper.name] = output_index
+                result = result.assign_coords(
+                    Coordinates(
+                        coords={name: (name, np.array(output_index))},
+                        indexes={name: PandasIndex(output_index, dim=name)},
+                    )
+                )
                 result = result.drop_vars(unindexed_dims)

         # broadcast and restore non-numeric data variables (backcompat)
         for name, var in non_numeric.items():
             if all(d not in var.dims for d in parsed_dim):
                 result[name] = var.variable.set_dims(
-                    (grouper.name,) + var.dims,
-                    (result.sizes[grouper.name],) + var.shape,
+                    (name,) + var.dims, (result.sizes[name],) + var.shape
                 )

         if not isinstance(result, Dataset):
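
The `assign_coords(Coordinates(...))` pattern above replaces plain assignment (`result[grouper.name] = output_index`) so that the coordinate variable and its index are created together, explicitly. A hedged sketch of the same pattern outside groupby (dataset and names are illustrative):

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.indexes import PandasIndex

output_index = pd.Index(["a", "b"])
ds = xr.Dataset({"v": ("labels", [1, 2])})
# Attach the coordinate values and an explicit PandasIndex in one step,
# instead of relying on automatic index creation during assignment.
ds = ds.assign_coords(
    xr.Coordinates(
        coords={"labels": ("labels", np.array(output_index))},
        indexes={"labels": PandasIndex(output_index, dim="labels")},
    )
)
print(ds.xindexes)
```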
37 changes: 26 additions & 11 deletions xarray/core/groupers.py
@@ -23,7 +23,7 @@
 from xarray.core.resample_cftime import CFTimeGrouper
 from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
 from xarray.core.utils import emit_user_level_warning
-from xarray.core.variable import IndexVariable
+from xarray.core.variable import Variable

 __all__ = [
     "EncodedGroups",
@@ -55,7 +55,17 @@ class EncodedGroups:
     codes: DataArray
     full_index: pd.Index
     group_indices: T_GroupIndices | None = field(default=None)
-    unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
+    unique_coord: Variable | _DummyGroup | None = field(default=None)
+
+    def __post_init__(self):
+        assert isinstance(self.codes, DataArray)
+        if self.codes.name is None:
+            raise ValueError("Please set a name on the array you are grouping by.")
+        assert isinstance(self.full_index, pd.Index)
+        assert (
+            isinstance(self.unique_coord, (Variable, _DummyGroup))
+            or self.unique_coord is None
+        )


 class Grouper(ABC):
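
The new `__post_init__` turns implicit expectations into loud failures, most notably that the codes array must be named, since that name seeds the reduced coordinate. A sketch of the check (constructing the internal dataclass directly; exact validation details may differ at other commits):

```python
import pandas as pd
import xarray as xr
from xarray.core.groupers import EncodedGroups

codes = xr.DataArray([0, 1, 0], dims="x", name="labels")
EncodedGroups(codes=codes, full_index=pd.Index(["a", "b"]))  # passes validation

unnamed = xr.DataArray([0, 1, 0], dims="x")
# EncodedGroups(codes=unnamed, full_index=pd.Index(["a", "b"]))
# -> ValueError: Please set a name on the array you are grouping by.
```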
@@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
                 "Failed to group data. Are you grouping by a variable that is all NaN?"
             )
         codes = self.group.copy(data=codes_)
-        unique_coord = IndexVariable(
-            self.group.name, unique_values, attrs=self.group.attrs
+        unique_coord = Variable(
+            dims=codes.name, data=unique_values, attrs=self.group.attrs
         )
-        full_index = unique_coord
+        full_index = pd.Index(unique_values)

         return EncodedGroups(
             codes=codes, full_index=full_index, unique_coord=unique_coord
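
`_factorize_unique` boils down to producing integer codes plus the unique values that back both `unique_coord` and `full_index`. The shape of that factorization, shown with the analogous pandas primitive (illustrative; the grouper uses its own helper internally):

```python
import numpy as np
import pandas as pd

values = np.array(["b", "a", "b", "c"])
codes, unique_values = pd.factorize(values)
# codes         -> array([0, 1, 0, 2]); -1 marks missing values
# unique_values -> array(['b', 'a', 'c'], dtype=object)
full_index = pd.Index(unique_values)
```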
@@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
         size_range = np.arange(size)
         if isinstance(self.group, _DummyGroup):
             codes = self.group.to_dataarray().copy(data=size_range)
+            unique_coord = self.group
+            full_index = pd.RangeIndex(self.group.size)
         else:
             codes = self.group.copy(data=size_range)
-        unique_coord = self.group
-        full_index = IndexVariable(
-            self.group.name, unique_coord.values, self.group.attrs
-        )
+            unique_coord = self.group.variable.to_base_variable()
+            full_index = pd.Index(unique_coord.data)
+
         return EncodedGroups(
             codes=codes,
             group_indices=group_indices,
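
`to_base_variable()` is what strips the `IndexVariable` subclass off an indexed coordinate, leaving the plain `Variable` this refactor wants to carry around. A small demonstration (illustrative):

```python
import xarray as xr

da = xr.DataArray([10, 20], coords={"x": [1, 2]}, dims="x")
iv = da["x"].variable      # IndexVariable backing the indexed coordinate
v = iv.to_base_variable()  # plain Variable with the same dims, data, attrs
print(type(iv).__name__, "->", type(v).__name__)  # IndexVariable -> Variable
```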
@@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
         codes = DataArray(
             binned_codes, getattr(group, "coords", None), name=new_dim_name
         )
-        unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
+        unique_coord = Variable(
+            dims=new_dim_name, data=unique_values, attrs=group.attrs
+        )
         return EncodedGroups(
             codes=codes, full_index=full_index, unique_coord=unique_coord
         )
@@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
         ]
         group_indices += [slice(sbins[-1], None)]

-        unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
+        unique_coord = Variable(
+            dims=group.name, data=first_items.index, attrs=group.attrs
+        )
         codes = group.copy(data=codes_)

         return EncodedGroups(
