diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index f038911ddbe..0425da2ffc7 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -44,6 +44,7 @@ from xarray.core.utils import Frozen GroupKey = Any + GroupIndex = int | slice | list[int] T_GroupIndicesListInt = list[list[int]] T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray] @@ -129,11 +130,11 @@ def _dummy_copy(xarray_obj): return res -def _is_one_or_none(obj): +def _is_one_or_none(obj) -> bool: return obj == 1 or obj is None -def _consolidate_slices(slices): +def _consolidate_slices(slices: list[slice]) -> list[slice]: """Consolidate adjacent slices in a list of slices.""" result = [] last_slice = slice(None) @@ -191,7 +192,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name self.coords = coords self.size = obj.sizes[name] - self.dataarray = obj[name] @property def dims(self) -> tuple[Hashable]: @@ -228,6 +228,13 @@ def __getitem__(self, key): def copy(self, deep: bool = True, data: Any = None): raise NotImplementedError + def as_dataarray(self) -> DataArray: + from xarray.core.dataarray import DataArray + + return DataArray( + data=self.data, dims=(self.name,), coords=self.coords, name=self.name + ) + T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) @@ -294,14 +301,16 @@ def _apply_loffset( class Grouper: - def __init__(self, group: T_Group): - self.group : T_Group | None = group - self.codes : np.ndarry | None = None + def __init__(self, group: T_Group | Hashable): + self.group: T_Group | Hashable = group + self.labels = None - self.group_indices : list[list[int, ...]] | None= None - self.unique_coord = None - self.full_index : pd.Index | None = None - self._group_as_index = None + self._group_as_index: pd.Index | None = None + + self.codes: DataArray + self.group_indices: list[int] | list[slice] | list[list[int]] + self.unique_coord: IndexVariable | _DummyGroup + self.full_index: pd.Index @property def name(self) -> Hashable: @@ -334,10 +343,9 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None: + def _resolve_group(self, obj: T_Xarray): from xarray.core.dataarray import DataArray - group: T_Group group = self.group if not isinstance(group, (DataArray, IndexVariable)): if not hashable(group): @@ -346,15 +354,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None: "name of an xarray variable or dimension. " f"Received {group!r} instead." ) - group_da : T_DataArray = obj[group] - if len(group_da) == 0: - raise ValueError(f"{group_da.name} must not be empty") - - if group_da.name not in obj.coords and group_da.name in obj.dims: + group = obj[group] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + if group.name not in obj._indexes and group.name in obj.dims: # DummyGroups should not appear on groupby results group = _DummyGroup(obj, group.name, group.coords) - if getattr(group, "name", None) is None: + elif getattr(group, "name", None) is None: group.name = "group" self.group = group @@ -408,10 +415,10 @@ def _factorize_dummy(self, squeeze) -> None: # equivalent to: group_indices = group_indices.reshape(-1, 1) self.group_indices = [slice(i, i + 1) for i in range(size)] else: - self.group_indices = np.arange(size) + self.group_indices = list(range(size)) codes = np.arange(size) if isinstance(self.group, _DummyGroup): - self.codes = self.group.dataarray.copy(data=codes) + self.codes = self.group.as_dataarray().copy(data=codes) else: self.codes = self.group.copy(data=codes) self.unique_coord = self.group @@ -489,7 +496,7 @@ def __init__( raise ValueError("index must be monotonic for resampling") if isinstance(group_as_index, CFTimeIndex): - self.grouper = CFTimeGrouper( + grouper = CFTimeGrouper( freq=self.freq, closed=self.closed, label=self.label, @@ -498,15 +505,16 @@ def __init__( loffset=self.loffset, ) else: - self.grouper = pd.Grouper( + grouper = pd.Grouper( freq=self.freq, closed=self.closed, label=self.label, origin=self.origin, offset=self.offset, ) + self.grouper: CFTimeGrouper | pd.Grouper = grouper - def _get_index_and_items(self): + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() full_index = first_items.index if first_items.isnull().any(): @@ -515,7 +523,7 @@ def _get_index_and_items(self): full_index = full_index.rename("__resample_dim__") return full_index, first_items, codes - def first_items(self): + def first_items(self) -> tuple[pd.Series, np.ndarray]: from xarray import CFTimeIndex if isinstance(self.group_as_index, CFTimeIndex): @@ -670,7 +678,7 @@ def reduce( raise NotImplementedError() @property - def groups(self) -> dict[GroupKey, slice | int | list[int]]: + def groups(self) -> dict[GroupKey, GroupIndex]: """ Mapping from group labels to indices. The indices can be used to index the underlying object. """ @@ -735,7 +743,7 @@ def _binary_op(self, other, f, reflexive=False): dims = group.dims if isinstance(group, _DummyGroup): - group = coord = group.dataarray + group = coord = group.as_dataarray() else: coord = grouper.unique_coord if not isinstance(coord, DataArray):