Deprecate squeeze in GroupBy. #8506

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -43,6 +43,8 @@ Breaking changes

Deprecations
~~~~~~~~~~~~
- The ``squeeze`` kwarg to GroupBy is now deprecated. (:issue:`2157`)
By `Deepak Cherian <https://github.com/dcherian>`_.

- As part of an effort to standardize the API, we're renaming the ``dims``
keyword arg to ``dim`` for the minority of functions which current use
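As context (not part of the diff), a minimal sketch of the user-facing change this entry describes: grouping over a unique, monotonic dimension coordinate currently squeezes each group along that dimension, and passing `squeeze=False` opts into the future behaviour and silences the warning.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(4))}, coords={"x": [10, 20, 30, 40]})

# squeeze not passed: the old squeezing behaviour is kept and, with this patch,
# a user-level deprecation warning is emitted.
for label, group in ds.groupby("x"):
    pass  # each `group` has the "x" dimension squeezed away

# squeeze=False: no warning, and "x" survives as a length-1 dimension in each
# group, which is the planned future default.
for label, group in ds.groupby("x", squeeze=False):
    assert group.sizes["x"] == 1
```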
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
@@ -6620,7 +6620,7 @@ def interp_calendar(
def groupby(
self,
group: Hashable | DataArray | IndexVariable,
squeeze: bool = True,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
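The signature change above (`squeeze: bool = True` to `squeeze: bool | None = None`) is the usual sentinel-default deprecation pattern: `None` lets the library tell "the caller did not pass `squeeze`" apart from an explicit value. A standalone, hypothetical sketch of the idea (the helper name is illustrative, not xarray API):

```python
from __future__ import annotations

import warnings


def resolve_squeeze(squeeze: bool | None) -> bool:
    # None means the caller did not pass `squeeze`: keep the old behaviour
    # for now, but tell the user it is going away.
    if squeeze is None:
        warnings.warn(
            "the `squeeze` kwarg is deprecated; pass squeeze=False explicitly",
            FutureWarning,
            stacklevel=2,
        )
        return True
    return squeeze
```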
104 changes: 76 additions & 28 deletions xarray/core/groupby.py
@@ -73,6 +73,21 @@ def check_reduce_dims(reduce_dims, dimensions):
)


def _maybe_squeeze_indices(
indices, squeeze: bool | None, grouper: ResolvedGrouper, warn: bool
):
if squeeze in [None, True] and grouper.can_squeeze:
if squeeze is None and warn:
emit_user_level_warning(
"The `squeeze` kwarg to GroupBy is being removed. "
"Pass .groupby(..., squeeze=False) to silence this warning."
)
if isinstance(indices, slice):
assert indices.stop - indices.start == 1
indices = indices.start
return indices


def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
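The `_maybe_squeeze_indices` helper added above converts a length-1 slice into its integer start so that downstream `.isel` calls drop the grouped dimension, reproducing the legacy `squeeze=True` result. A small sketch of why the slice/integer distinction matters:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(3), dims="x")

# A length-1 slice keeps the dimension (the squeeze=False / future behaviour) ...
assert da.isel(x=slice(0, 1)).dims == ("x",)
# ... while the integer start of that slice drops it (the legacy squeezed behaviour).
assert da.isel(x=0).dims == ()
```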
@@ -366,10 +381,10 @@ def dims(self):
return self.group1d.dims

@abstractmethod
def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def _factorize(self) -> T_FactorizeOut:
raise NotImplementedError

def factorize(self, squeeze: bool) -> None:
def factorize(self) -> None:
# This design makes it clear to mypy that
# codes, group_indices, unique_coord, and full_index
# are set by the factorize method on the derived class.
@@ -378,7 +393,7 @@ def factorize(self, squeeze: bool) -> None:
self.group_indices,
self.unique_coord,
self.full_index,
) = self._factorize(squeeze)
) = self._factorize()

@property
def is_unique_and_monotonic(self) -> bool:
@@ -393,15 +408,19 @@ def group_as_index(self) -> pd.Index:
self._group_as_index = self.group1d.to_index()
return self._group_as_index

@property
def can_squeeze(self):
is_dimension = self.group.dims == (self.group.name,)
return is_dimension and self.is_unique_and_monotonic


@dataclass
class ResolvedUniqueGrouper(ResolvedGrouper):
grouper: UniqueGrouper

def _factorize(self, squeeze) -> T_FactorizeOut:
is_dimension = self.group.dims == (self.group.name,)
if is_dimension and self.is_unique_and_monotonic:
return self._factorize_dummy(squeeze)
def _factorize(self) -> T_FactorizeOut:
if self.can_squeeze:
return self._factorize_dummy()
else:
return self._factorize_unique()
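The new `can_squeeze` property gates the dummy-factorization fast path: it is only true when grouping over a 1-D dimension coordinate whose values are unique and monotonic, i.e. when every element is already its own group. A rough illustration of that case, assuming the behaviour of this patch:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(3))}, coords={"x": [1, 2, 3]})

# "x" is a dimension coordinate with unique, monotonic values, so no real
# factorization is needed; with squeeze=False each label maps to a length-1 slice.
print(ds.groupby("x", squeeze=False).groups)
# e.g. {1: slice(0, 1, None), 2: slice(1, 2, None), 3: slice(2, 3, None)}
```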

@@ -424,15 +443,12 @@ def _factorize_unique(self) -> T_FactorizeOut:

return codes, group_indices, unique_coord, full_index

def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
def _factorize_dummy(self) -> T_FactorizeOut:
size = self.group.size
# no need to factorize
if not squeeze:
# use slices to do views instead of fancy indexing
# equivalent to: group_indices = group_indices.reshape(-1, 1)
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
else:
group_indices = list(range(size))
# use slices to do views instead of fancy indexing
# equivalent to: group_indices = group_indices.reshape(-1, 1)
group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)]
size_range = np.arange(size)
if isinstance(self.group, _DummyGroup):
codes = self.group.to_dataarray().copy(data=size_range)
@@ -448,7 +464,7 @@ def _factorize_dummy(self, squeeze) -> T_FactorizeOut:
class ResolvedBinGrouper(ResolvedGrouper):
grouper: BinGrouper

def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def _factorize(self) -> T_FactorizeOut:
from xarray.core.dataarray import DataArray

data = self.group1d.values
@@ -546,7 +562,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
_apply_loffset(self.grouper.loffset, first_items)
return first_items, codes

def _factorize(self, squeeze: bool) -> T_FactorizeOut:
def _factorize(self) -> T_FactorizeOut:
full_index, first_items, codes_ = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
group_indices: T_GroupIndices = [
@@ -591,14 +607,14 @@ class TimeResampleGrouper(Grouper):
loffset: datetime.timedelta | str | None


def _validate_groupby_squeeze(squeeze: bool) -> None:
def _validate_groupby_squeeze(squeeze: bool | None) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
# consequences hidden enough (strings evaluate as true) to warrant
# checking here.
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if not isinstance(squeeze, bool):
if squeeze is not None and not isinstance(squeeze, bool):
raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
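As the comment in `_validate_groupby_squeeze` explains, this check mainly catches a second dimension name landing in the `squeeze` slot (strings are truthy, so the mistake would otherwise pass silently). A sketch of what that failure mode looks like:

```python
import numpy as np
import pytest
import xarray as xr

ds = xr.Dataset({"a": (("x", "y"), np.zeros((2, 2)))})

# "y" is passed positionally where `squeeze` is expected.
with pytest.raises(TypeError, match="must be True or False"):
    ds.groupby("x", "y")
```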


@@ -730,7 +746,7 @@ def __init__(
self._original_obj = obj

for grouper_ in self.groupers:
grouper_.factorize(squeeze)
grouper_.factorize()

(grouper,) = self.groupers
self._original_group = grouper.group
@@ -762,9 +778,14 @@ def sizes(self) -> Mapping[Hashable, int]:
Dataset.sizes
"""
if self._sizes is None:
self._sizes = self._obj.isel(
{self._group_dim: self._group_indices[0]}
).sizes
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0],
self._squeeze,
grouper,
warn=True,
)
self._sizes = self._obj.isel({self._group_dim: index}).sizes

return self._sizes

@@ -798,14 +819,22 @@ def groups(self) -> dict[GroupKey, GroupIndex]:
# provided to mimic pandas.groupby
if self._groups is None:
(grouper,) = self.groupers
self._groups = dict(zip(grouper.unique_coord.values, self._group_indices))
squeezed_indices = (
_maybe_squeeze_indices(ind, self._squeeze, grouper, warn=idx > 0)
for idx, ind in enumerate(self._group_indices)
)
self._groups = dict(zip(grouper.unique_coord.values, squeezed_indices))
return self._groups

def __getitem__(self, key: GroupKey) -> T_Xarray:
"""
Get DataArray or Dataset corresponding to a particular group label.
"""
return self._obj.isel({self._group_dim: self.groups[key]})
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self.groups[key], self._squeeze, grouper, warn=True
)
return self._obj.isel({self._group_dim: index})

def __len__(self) -> int:
(grouper,) = self.groupers
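With these changes, accessing `.groups`, indexing into the GroupBy object, or iterating over it without an explicit `squeeze` argument emits the user-level warning (when iterating or building `.groups`, the first element is skipped via `warn=idx > 0` to avoid a flood of duplicates), while `squeeze=False` stays silent. A rough sketch of the expected behaviour under this patch:

```python
import warnings

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(3))}, coords={"x": [1, 2, 3]})

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    _ = ds.groupby("x").groups  # squeeze not passed: squeezed result + warning
assert any("squeeze" in str(w.message) for w in record)

with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    _ = ds.groupby("x", squeeze=False).groups  # explicit: no warning expected
assert not record
```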
@@ -826,7 +855,11 @@ def __repr__(self) -> str:

def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
for indices in self._group_indices:
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
)
yield self._obj.isel({self._group_dim: indices})

def _infer_concat_args(self, applied_example):
@@ -1309,7 +1342,11 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
@property
def dims(self) -> tuple[Hashable, ...]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0], self._squeeze, grouper, warn=True
)
self._dims = self._obj.isel({self._group_dim: index}).dims

return self._dims

@@ -1318,7 +1355,11 @@ def _iter_grouped_shortcut(self):
metadata
"""
var = self._obj.variable
for indices in self._group_indices:
(grouper,) = self.groupers
for idx, indices in enumerate(self._group_indices):
indices = _maybe_squeeze_indices(
indices, self._squeeze, grouper, warn=idx > 0
)
yield var[{self._group_dim: indices}]

def _concat_shortcut(self, applied, dim, positions=None):
@@ -1517,7 +1558,14 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
@property
def dims(self) -> Frozen[Hashable, int]:
if self._dims is None:
self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims
(grouper,) = self.groupers
index = _maybe_squeeze_indices(
self._group_indices[0],
self._squeeze,
grouper,
warn=True,
)
self._dims = self._obj.isel({self._group_dim: index}).dims

return self._dims
