From 2645d7f6d95abffb04393c7ef1692125ee4ba869 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 21 Jun 2024 16:35:13 -0600 Subject: [PATCH] groupby: remove some internal use of IndexVariable (#9123) * Remove internal use of IndexVariable * cleanup * cleanup more * cleanup --- xarray/core/groupby.py | 63 +++++++++++++++++++++++++++-------------- xarray/core/groupers.py | 37 +++++++++++++++++------- 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index eceb4e62199..42e7f01a526 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -19,8 +19,10 @@ from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat +from xarray.core.coordinates import Coordinates from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( + PandasIndex, create_default_index_implicit, filter_indexes_from_coords, ) @@ -246,7 +248,7 @@ def to_array(self) -> DataArray: return self.to_dataarray() -T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup] +T_Group = Union["T_DataArray", _DummyGroup] def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ @@ -256,7 +258,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ list[Hashable], ]: # 1D cases: do nothing - if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: + if isinstance(group, _DummyGroup) or group.ndim == 1: return group, obj, None, [] from xarray.core.dataarray import DataArray @@ -271,9 +273,7 @@ def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[ newobj = obj.stack({stacked_dim: orig_dims}) return newgroup, newobj, stacked_dim, inserted_dims - raise TypeError( - f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." - ) + raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.") @dataclass @@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): codes: DataArray = field(init=False) full_index: pd.Index = field(init=False) group_indices: T_GroupIndices = field(init=False) - unique_coord: IndexVariable | _DummyGroup = field(init=False) + unique_coord: Variable | _DummyGroup = field(init=False) # _ensure_1d: group1d: T_Group = field(init=False) @@ -315,7 +315,7 @@ def __post_init__(self) -> None: # might be used multiple times. self.grouper = copy.deepcopy(self.grouper) - self.group: T_Group = _resolve_group(self.obj, self.group) + self.group = _resolve_group(self.obj, self.group) ( self.group1d, @@ -328,14 +328,18 @@ def __post_init__(self) -> None: @property def name(self) -> Hashable: + """Name for the grouped coordinate after reduction.""" # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper - return self.unique_coord.name + (name,) = self.unique_coord.dims + return name @property def size(self) -> int: + """Number of groups.""" return len(self) def __len__(self) -> int: + """Number of groups.""" return len(self.full_index) @property @@ -358,8 +362,8 @@ def factorize(self) -> None: ] if encoded.unique_coord is None: unique_values = self.full_index[np.unique(encoded.codes)] - self.unique_coord = IndexVariable( - self.codes.name, unique_values, attrs=self.group.attrs + self.unique_coord = Variable( + dims=self.codes.name, data=unique_values, attrs=self.group.attrs ) else: self.unique_coord = encoded.unique_coord @@ -378,7 +382,9 @@ def _validate_groupby_squeeze(squeeze: bool | None) -> None: ) -def _resolve_group(obj: T_DataWithCoords, group: T_Group | Hashable) -> T_Group: +def _resolve_group( + obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable +) -> T_Group: from xarray.core.dataarray import DataArray error_msg = ( @@ -620,6 +626,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]: yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): + from xarray.core.groupers import BinGrouper + (grouper,) = self.groupers if self._group_dim in applied_example.dims: coord = grouper.group1d @@ -628,7 +636,10 @@ def _infer_concat_args(self, applied_example): coord = grouper.unique_coord positions = None (dim,) = coord.dims - if isinstance(coord, _DummyGroup): + if isinstance(grouper.group, _DummyGroup) and not isinstance( + grouper.grouper, BinGrouper + ): + # When binning we actually do set the index coord = None coord = getattr(coord, "variable", coord) return coord, dim, positions @@ -641,6 +652,7 @@ def _binary_op(self, other, f, reflexive=False): (grouper,) = self.groupers obj = self._original_obj + name = grouper.name group = grouper.group codes = self._codes dims = group.dims @@ -649,9 +661,11 @@ def _binary_op(self, other, f, reflexive=False): group = coord = group.to_dataarray() else: coord = grouper.unique_coord - if not isinstance(coord, DataArray): - coord = DataArray(grouper.unique_coord) - name = grouper.name + if isinstance(coord, Variable): + assert coord.ndim == 1 + (coord_dim,) = coord.dims + # TODO: explicitly create Index here + coord = DataArray(coord, coords={coord_dim: coord.data}) if not isinstance(other, (Dataset, DataArray)): raise TypeError( @@ -766,6 +780,7 @@ def _flox_reduce( obj = self._original_obj (grouper,) = self.groupers + name = grouper.name isbin = isinstance(grouper.grouper, BinGrouper) if keep_attrs is None: @@ -797,14 +812,14 @@ def _flox_reduce( # weird backcompat # reducing along a unique indexed dimension with squeeze=True # should raise an error - if (dim is None or dim == grouper.name) and grouper.name in obj.xindexes: - index = obj.indexes[grouper.name] + if (dim is None or dim == name) and name in obj.xindexes: + index = obj.indexes[name] if index.is_unique and self._squeeze: - raise ValueError(f"cannot reduce over dimensions {grouper.name!r}") + raise ValueError(f"cannot reduce over dimensions {name!r}") unindexed_dims: tuple[Hashable, ...] = tuple() if isinstance(grouper.group, _DummyGroup) and not isbin: - unindexed_dims = (grouper.name,) + unindexed_dims = (name,) parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): @@ -848,15 +863,19 @@ def _flox_reduce( # in the grouped variable group_dims = grouper.group.dims if set(group_dims).issubset(set(parsed_dim)): - result[grouper.name] = output_index + result = result.assign_coords( + Coordinates( + coords={name: (name, np.array(output_index))}, + indexes={name: PandasIndex(output_index, dim=name)}, + ) + ) result = result.drop_vars(unindexed_dims) # broadcast and restore non-numeric data variables (backcompat) for name, var in non_numeric.items(): if all(d not in var.dims for d in parsed_dim): result[name] = var.variable.set_dims( - (grouper.name,) + var.dims, - (result.sizes[grouper.name],) + var.shape, + (name,) + var.dims, (result.sizes[name],) + var.shape ) if not isinstance(result, Dataset): diff --git a/xarray/core/groupers.py b/xarray/core/groupers.py index e33cd3ad99f..075afd9f62f 100644 --- a/xarray/core/groupers.py +++ b/xarray/core/groupers.py @@ -23,7 +23,7 @@ from xarray.core.resample_cftime import CFTimeGrouper from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices from xarray.core.utils import emit_user_level_warning -from xarray.core.variable import IndexVariable +from xarray.core.variable import Variable __all__ = [ "EncodedGroups", @@ -55,7 +55,17 @@ class EncodedGroups: codes: DataArray full_index: pd.Index group_indices: T_GroupIndices | None = field(default=None) - unique_coord: IndexVariable | _DummyGroup | None = field(default=None) + unique_coord: Variable | _DummyGroup | None = field(default=None) + + def __post_init__(self): + assert isinstance(self.codes, DataArray) + if self.codes.name is None: + raise ValueError("Please set a name on the array you are grouping by.") + assert isinstance(self.full_index, pd.Index) + assert ( + isinstance(self.unique_coord, (Variable, _DummyGroup)) + or self.unique_coord is None + ) class Grouper(ABC): @@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups: "Failed to group data. Are you grouping by a variable that is all NaN?" ) codes = self.group.copy(data=codes_) - unique_coord = IndexVariable( - self.group.name, unique_values, attrs=self.group.attrs + unique_coord = Variable( + dims=codes.name, data=unique_values, attrs=self.group.attrs ) - full_index = unique_coord + full_index = pd.Index(unique_values) return EncodedGroups( codes=codes, full_index=full_index, unique_coord=unique_coord @@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups: size_range = np.arange(size) if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) + unique_coord = self.group + full_index = pd.RangeIndex(self.group.size) else: codes = self.group.copy(data=size_range) - unique_coord = self.group - full_index = IndexVariable( - self.group.name, unique_coord.values, self.group.attrs - ) + unique_coord = self.group.variable.to_base_variable() + full_index = pd.Index(unique_coord.data) + return EncodedGroups( codes=codes, group_indices=group_indices, @@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups: codes = DataArray( binned_codes, getattr(group, "coords", None), name=new_dim_name ) - unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) + unique_coord = Variable( + dims=new_dim_name, data=unique_values, attrs=group.attrs + ) return EncodedGroups( codes=codes, full_index=full_index, unique_coord=unique_coord ) @@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups: ] group_indices += [slice(sbins[-1], None)] - unique_coord = IndexVariable(group.name, first_items.index, group.attrs) + unique_coord = Variable( + dims=group.name, data=first_items.index, attrs=group.attrs + ) codes = group.copy(data=codes_) return EncodedGroups(