From 22ad7fa7607cb83832935533a55df1f73c65811d Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 18 Mar 2023 20:40:22 -0600 Subject: [PATCH] [WIP] --- xarray/core/common.py | 77 ++++++++++++++++++++++++++++++++++-- xarray/core/dataarray.py | 52 ++++++------------------ xarray/core/dataset.py | 43 +++++++------------- xarray/core/groupby.py | 85 ++++++++++++++++++++++++---------------- 4 files changed, 153 insertions(+), 104 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 1c6118e8d4c..cb7df60cffa 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -814,6 +814,74 @@ def rolling_exp( return rolling_exp.RollingExp(self, window, window_type) + def _groupby(self, groupby_cls, group, squeeze: bool, restore_coord_dims): + from xarray.core.groupby import UniqueGrouper, _validate_group + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + newobj, name = _validate_group(self, group) + + grouper = UniqueGrouper() + return groupby_cls( + newobj, + {name: grouper}, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + + def _groupby_bins( + self, + groupby_cls, + group: Hashable | DataArray | IndexVariable, + bins: ArrayLike, + right: bool = True, + labels: ArrayLike | None = None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool = True, + restore_coord_dims: bool = False, + ): + from xarray.core.groupby import BinGrouper, _validate_group + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + newobj, name = _validate_group(self, group) + + grouper = BinGrouper( + bins=bins, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + + return groupby_cls( + newobj, + {name: grouper}, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, + ) + def _resample( self, resample_cls: type[T_Resample], @@ -1000,12 +1068,13 @@ def _resample( if base is not None: offset = _convert_base_to_offset(base, freq, index) + name = RESAMPLE_DIM group = DataArray( - dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM + dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=name ) + newobj = self.copy().assign_coords({name: group}) grouper = TimeResampleGrouper( - group=group, freq=freq, closed=closed, label=label, @@ -1015,8 +1084,8 @@ def _resample( ) return resample_cls( - self, - grouper=grouper, + newobj, + {name: grouper}, dim=dim_name, resample_dim=RESAMPLE_DIM, restore_coord_dims=restore_coord_dims, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ed1ae078710..b6bf27f0000 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6256,22 +6256,13 @@ def groupby( core.groupby.DataArrayGroupBy pandas.DataFrame.groupby """ - from xarray.core.groupby import DataArrayGroupBy, UniqueGrouper - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DataArrayGroupBy - grouper = UniqueGrouper(group) - return DataArrayGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + return self._groupby( + groupby_cls=DataArrayGroupBy, + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -6342,33 +6333,16 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import BinGrouper, DataArrayGroupBy - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DataArrayGroupBy - grouper = BinGrouper( + return self._groupby_bins( + groupby_cls=DataArrayGroupBy, group=group, bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - - return DataArrayGroupBy( - self, - grouper, + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, squeeze=squeeze, restore_coord_dims=restore_coord_dims, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 005801573ff..ec1d857f563 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8943,23 +8943,13 @@ def groupby( Dataset.resample DataArray.resample """ - from xarray.core.groupby import DatasetGroupBy, UniqueGrouper - - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) + from xarray.core.groupby import DatasetGroupBy - grouper = UniqueGrouper(group) - - return DatasetGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + return self._groupby( + groupby_cls=DatasetGroupBy, + group=group, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def groupby_bins( @@ -9030,21 +9020,18 @@ def groupby_bins( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html """ - from xarray.core.groupby import BinGrouper, DatasetGroupBy + from xarray.core.groupby import DatasetGroupBy - grouper = BinGrouper( + return self._groupby_bins( + groupby_cls=DatasetGroupBy, group=group, bins=bins, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - - return DatasetGroupBy( - self, grouper, squeeze=squeeze, restore_coord_dims=restore_coord_dims + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, + squeeze=squeeze, + restore_coord_dims=restore_coord_dims, ) def weighted(self, weights: DataArray) -> DatasetWeighted: diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0425da2ffc7..c907ebf9e0f 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -301,9 +301,7 @@ def _apply_loffset( class Grouper: - def __init__(self, group: T_Group | Hashable): - self.group: T_Group | Hashable = group - + def __init__(self): self.labels = None self._group_as_index: pd.Index | None = None @@ -343,26 +341,13 @@ def group_as_index(self) -> pd.Index: self._group_as_index = safe_cast_to_index(self.group1d) return self._group_as_index - def _resolve_group(self, obj: T_Xarray): - from xarray.core.dataarray import DataArray - - group = self.group - if not isinstance(group, (DataArray, IndexVariable)): - if not hashable(group): - raise TypeError( - "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension. " - f"Received {group!r} instead." - ) - group = obj[group] - if len(group) == 0: - raise ValueError(f"{group.name} must not be empty") - if group.name not in obj._indexes and group.name in obj.dims: - # DummyGroups should not appear on groupby results - group = _DummyGroup(obj, group.name, group.coords) - - elif getattr(group, "name", None) is None: - group.name = "group" + def _resolve_group(self, obj: T_Xarray, group_name: Hashable): + group = obj[group_name] + if len(group) == 0: + raise ValueError(f"{group.name} must not be empty") + if group.name not in obj._indexes and group.name in obj.dims: + # DummyGroups should not appear on groupby results + group = _DummyGroup(obj, group.name, group.coords) self.group = group @@ -381,6 +366,14 @@ def _resolve_group(self, obj: T_Xarray): return self, stacked_obj + def copy(self, deep=False): + import copy + + if deep: + return copy.deepcopy(self) + else: + return copy.copy(self) + class UniqueGrouper(Grouper): def factorize(self, squeeze) -> None: @@ -433,7 +426,6 @@ def __init__(self, group, bins, cut_kwargs: Mapping | None): if cut_kwargs is None: cut_kwargs = {} - self.group = group self.bins = bins self.cut_kwargs = cut_kwargs @@ -459,7 +451,7 @@ def factorize(self, squeeze: bool) -> None: binned, getattr(self.group1d, "coords", None), name=new_dim_name ) self.unique_coord = IndexVariable( - self.group1d.name, unique_values, self.group.attrs + new_dim_name, unique_values, self.group.attrs ) self.codes = self.group1d.copy(data=codes) # TODO: support IntervalIndex in IndexVariable @@ -470,7 +462,6 @@ def factorize(self, squeeze: bool) -> None: class TimeResampleGrouper(Grouper): def __init__( self, - group, freq: str, closed: SideOptions | None, label: SideOptions | None, @@ -478,16 +469,19 @@ def __init__( offset: pd.Timedelta | datetime.timedelta | str | None, loffset: datetime.timedelta | str | None, ): - from xarray import CFTimeIndex - from xarray.core.resample_cftime import CFTimeGrouper - - self.group = group self.freq = freq self.closed = closed self.label = label self.origin = origin self.offset = offset self.loffset = loffset + + def _resolve_group(self, obj, group_name): + from xarray import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper + + group = obj[group_name] + self.group = group self._group_as_index = safe_cast_to_index(group) group_as_index = self._group_as_index @@ -514,6 +508,12 @@ def __init__( ) self.grouper: CFTimeGrouper | pd.Grouper = grouper + self.group1d, stacked_obj, self.stacked_dim, self.inserted_dims = _ensure_1d( + group, obj + ) + + return self, stacked_obj + def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: first_items, codes = self.first_items() full_index = first_items.index @@ -553,6 +553,25 @@ def factorize(self, squeeze: bool) -> None: self.codes = self.group.copy(data=codes) +def _validate_group(obj, group): + from xarray.core.dataarray import DataArray + + if isinstance(group, (DataArray, IndexVariable)): + name = group.name or "group" + newobj = obj.copy().assign_coords({name: group}) + else: + if not hashable(group): + raise TypeError( + "`group` must be an xarray.DataArray or the " + "name of an xarray variable or dimension. " + f"Received {group!r} instead." + ) + name = group + newobj = obj + + return newobj, name + + class GroupBy(Generic[T_Xarray]): """A object that implements the split-apply-combine pattern. @@ -596,8 +615,7 @@ class GroupBy(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - grouper: Grouper, - *, + groupers: Dict[Hashable, Grouper], squeeze: bool = False, restore_coord_dims: bool = True, ) -> None: @@ -615,7 +633,8 @@ def __init__( """ self._original_obj: T_Xarray = obj - grouper, obj = grouper._resolve_group(obj) + for group_name, grouper_ in groupers.items(): + grouper, obj = grouper_.copy()._resolve_group(obj, group_name) self._original_group = grouper.group self.groupers = (grouper,)