forked from pydata/xarray

Commit

Cleanup
dcherian committed Mar 30, 2023
1 parent 1168ab7 commit c905b74
Showing 1 changed file with 34 additions and 26 deletions.
60 changes: 34 additions & 26 deletions xarray/core/groupby.py
@@ -44,6 +44,7 @@
 from xarray.core.utils import Frozen

 GroupKey = Any
+GroupIndex = int | slice | list[int]

 T_GroupIndicesListInt = list[list[int]]
 T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray]
@@ -129,11 +130,11 @@ def _dummy_copy(xarray_obj):
     return res


-def _is_one_or_none(obj):
+def _is_one_or_none(obj) -> bool:
     return obj == 1 or obj is None


-def _consolidate_slices(slices):
+def _consolidate_slices(slices: list[slice]) -> list[slice]:
     """Consolidate adjacent slices in a list of slices."""
     result = []
     last_slice = slice(None)
@@ -191,7 +192,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
         self.name = name
         self.coords = coords
         self.size = obj.sizes[name]
-        self.dataarray = obj[name]

     @property
     def dims(self) -> tuple[Hashable]:
@@ -228,6 +228,13 @@ def __getitem__(self, key):
     def copy(self, deep: bool = True, data: Any = None):
         raise NotImplementedError

+    def as_dataarray(self) -> DataArray:
+        from xarray.core.dataarray import DataArray
+
+        return DataArray(
+            data=self.data, dims=(self.name,), coords=self.coords, name=self.name
+        )
+

 T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])

@@ -294,14 +301,16 @@ def _apply_loffset(


 class Grouper:
-    def __init__(self, group: T_Group):
-        self.group : T_Group | None = group
-        self.codes : np.ndarry | None = None
+    def __init__(self, group: T_Group | Hashable):
+        self.group: T_Group | Hashable = group
+
         self.labels = None
-        self.group_indices : list[list[int, ...]] | None= None
-        self.unique_coord = None
-        self.full_index : pd.Index | None = None
-        self._group_as_index = None
+        self._group_as_index: pd.Index | None = None
+
+        self.codes: DataArray
+        self.group_indices: list[int] | list[slice] | list[list[int]]
+        self.unique_coord: IndexVariable | _DummyGroup
+        self.full_index: pd.Index

     @property
     def name(self) -> Hashable:
@@ -334,10 +343,9 @@ def group_as_index(self) -> pd.Index:
             self._group_as_index = safe_cast_to_index(self.group1d)
         return self._group_as_index

-    def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
+    def _resolve_group(self, obj: T_Xarray):
         from xarray.core.dataarray import DataArray

-        group: T_Group
         group = self.group
         if not isinstance(group, (DataArray, IndexVariable)):
             if not hashable(group):
@@ -346,15 +354,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
                     "name of an xarray variable or dimension. "
                     f"Received {group!r} instead."
                 )
-            group_da : T_DataArray = obj[group]
-            if len(group_da) == 0:
-                raise ValueError(f"{group_da.name} must not be empty")
-
-            if group_da.name not in obj.coords and group_da.name in obj.dims:
+            group = obj[group]
+            if len(group) == 0:
+                raise ValueError(f"{group.name} must not be empty")
+            if group.name not in obj._indexes and group.name in obj.dims:
                 # DummyGroups should not appear on groupby results
                 group = _DummyGroup(obj, group.name, group.coords)

-        if getattr(group, "name", None) is None:
+        elif getattr(group, "name", None) is None:
             group.name = "group"

         self.group = group
@@ -408,10 +415,10 @@ def _factorize_dummy(self, squeeze) -> None:
             # equivalent to: group_indices = group_indices.reshape(-1, 1)
             self.group_indices = [slice(i, i + 1) for i in range(size)]
         else:
-            self.group_indices = np.arange(size)
+            self.group_indices = list(range(size))
         codes = np.arange(size)
         if isinstance(self.group, _DummyGroup):
-            self.codes = self.group.dataarray.copy(data=codes)
+            self.codes = self.group.as_dataarray().copy(data=codes)
         else:
             self.codes = self.group.copy(data=codes)
         self.unique_coord = self.group
@@ -489,7 +496,7 @@ def __init__(
             raise ValueError("index must be monotonic for resampling")

         if isinstance(group_as_index, CFTimeIndex):
-            self.grouper = CFTimeGrouper(
+            grouper = CFTimeGrouper(
                 freq=self.freq,
                 closed=self.closed,
                 label=self.label,
@@ -498,15 +505,16 @@ def __init__(
                 loffset=self.loffset,
             )
         else:
-            self.grouper = pd.Grouper(
+            grouper = pd.Grouper(
                 freq=self.freq,
                 closed=self.closed,
                 label=self.label,
                 origin=self.origin,
                 offset=self.offset,
             )
+        self.grouper: CFTimeGrouper | pd.Grouper = grouper

-    def _get_index_and_items(self):
+    def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]:
         first_items, codes = self.first_items()
         full_index = first_items.index
         if first_items.isnull().any():
@@ -515,7 +523,7 @@ def _get_index_and_items(self):
         full_index = full_index.rename("__resample_dim__")
         return full_index, first_items, codes

-    def first_items(self):
+    def first_items(self) -> tuple[pd.Series, np.ndarray]:
         from xarray import CFTimeIndex

         if isinstance(self.group_as_index, CFTimeIndex):
@@ -670,7 +678,7 @@ def reduce(
         raise NotImplementedError()

     @property
-    def groups(self) -> dict[GroupKey, slice | int | list[int]]:
+    def groups(self) -> dict[GroupKey, GroupIndex]:
         """
         Mapping from group labels to indices. The indices can be used to index the underlying object.
         """
@@ -735,7 +743,7 @@ def _binary_op(self, other, f, reflexive=False):
             dims = group.dims

         if isinstance(group, _DummyGroup):
-            group = coord = group.dataarray
+            group = coord = group.as_dataarray()
         else:
             coord = grouper.unique_coord
             if not isinstance(coord, DataArray):

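Note: the main functional change above is that _DummyGroup no longer caches a DataArray at construction (the removed self.dataarray = obj[name]); it now builds one on demand through the new as_dataarray() method, which _factorize_dummy and _binary_op call. The minimal, self-contained sketch below illustrates that lazy-construction pattern; DummyGroupSketch is a hypothetical stand-in, not xarray's internal _DummyGroup, and only the public numpy/xarray APIs are assumed.

# Sketch of the "build the DataArray only when needed" pattern from this commit.
# DummyGroupSketch is hypothetical; it is not the real _DummyGroup class.
import numpy as np
import xarray as xr


class DummyGroupSketch:
    """Placeholder group for a dimension that has no coordinate labels."""

    def __init__(self, name: str, size: int):
        self.name = name
        self.size = size
        self.coords = {}  # a dummy group carries no real coordinates

    @property
    def data(self) -> np.ndarray:
        # positional labels 0..size-1 stand in for the missing group values
        return np.arange(self.size)

    def as_dataarray(self) -> xr.DataArray:
        # construct the DataArray lazily instead of storing it eagerly
        return xr.DataArray(
            data=self.data, dims=(self.name,), coords=self.coords, name=self.name
        )


dummy = DummyGroupSketch("x", size=4)
codes = dummy.as_dataarray().copy(data=np.arange(dummy.size))  # mirrors _factorize_dummy
print(codes)

Keeping the cheap placeholder object around and deferring DataArray construction to the few call sites that need it is the design choice the diff makes.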