Skip to content

Commit

Permalink
Remove internal use of IndexVariable
Browse files: browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 14, 2024
1 parent 599b779 commit fc85ba2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 20 deletions.
38 changes: 29 additions & 9 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
codes: DataArray = field(init=False)
full_index: pd.Index = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
unique_coord: Variable | _DummyGroup = field(init=False)

# _ensure_1d:
group1d: T_Group = field(init=False)
Expand Down Expand Up @@ -328,14 +328,18 @@ def __post_init__(self) -> None:

@property
def name(self) -> Hashable:
"""Name for the grouped coordinate after reduction."""
# The name has to come from unique_coord because we need the `_bins` suffix for BinGrouper.
return self.unique_coord.name
(name,) = self.unique_coord.dims
return name

@property
def size(self) -> int:
    """Return the number of groups.

    This is the same value as ``len(self)``.
    """
    n_groups = len(self)
    return n_groups

def __len__(self) -> int:
    """Return the number of groups, i.e. the length of ``full_index``."""
    index = self.full_index
    return len(index)

@property
Expand All @@ -358,8 +362,8 @@ def factorize(self) -> None:
]
if encoded.unique_coord is None:
unique_values = self.full_index[np.unique(encoded.codes)]
self.unique_coord = IndexVariable(
self.codes.name, unique_values, attrs=self.group.attrs
self.unique_coord = Variable(
dims=self.codes.name, data=unique_values, attrs=self.group.attrs
)
else:
self.unique_coord = encoded.unique_coord
Expand Down Expand Up @@ -620,6 +624,8 @@ def _iter_grouped(self, warn_squeeze=True) -> Iterator[T_Xarray]:
yield self._obj.isel({self._group_dim: indices})

def _infer_concat_args(self, applied_example):
from xarray.core.groupers import BinGrouper

(grouper,) = self.groupers
if self._group_dim in applied_example.dims:
coord = grouper.group1d
Expand All @@ -628,7 +634,10 @@ def _infer_concat_args(self, applied_example):
coord = grouper.unique_coord
positions = None
(dim,) = coord.dims
if isinstance(coord, _DummyGroup):
if isinstance(grouper.group, _DummyGroup) and not isinstance(
grouper.grouper, BinGrouper
):
# When binning we actually do set the index
coord = None
coord = getattr(coord, "variable", coord)
return coord, dim, positions
Expand All @@ -645,12 +654,20 @@ def _binary_op(self, other, f, reflexive=False):
codes = self._codes
dims = group.dims

if isinstance(group, _DummyGroup):
if isinstance(grouper.group, _DummyGroup):
group = coord = group.to_dataarray()
else:
coord = grouper.unique_coord
if not isinstance(coord, DataArray):
coord = DataArray(grouper.unique_coord)
if isinstance(coord, Variable):
assert coord.ndim == 1
(coord_dim,) = coord.dims
# TODO: explicitly create Index here
coord = DataArray(
dims=coord_dim,
data=coord.data,
attrs=coord.attrs,
coords={coord_dim: coord.data},
)
name = grouper.name

if not isinstance(other, (Dataset, DataArray)):
Expand Down Expand Up @@ -848,7 +865,10 @@ def _flox_reduce(
# in the grouped variable
group_dims = grouper.group.dims
if set(group_dims).issubset(set(parsed_dim)):
result[grouper.name] = output_index
new_coord = Variable(
dims=grouper.name, data=np.array(output_index), attrs=self._codes.attrs
)
result = result.assign_coords({grouper.name: new_coord})
result = result.drop_vars(unindexed_dims)

# broadcast and restore non-numeric data variables (backcompat)
Expand Down
37 changes: 26 additions & 11 deletions xarray/core/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from xarray.core.resample_cftime import CFTimeGrouper
from xarray.core.types import DatetimeLike, SideOptions, T_GroupIndices
from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import IndexVariable
from xarray.core.variable import Variable

__all__ = [
"EncodedGroups",
Expand Down Expand Up @@ -55,7 +55,17 @@ class EncodedGroups:
codes: DataArray
full_index: pd.Index
group_indices: T_GroupIndices | None = field(default=None)
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)
unique_coord: Variable | _DummyGroup | None = field(default=None)

def __post_init__(self):
    """Validate the fields of a newly constructed ``EncodedGroups``.

    Raises
    ------
    TypeError
        If ``codes`` is not a ``DataArray``, if ``full_index`` is not a
        ``pandas.Index``, or if ``unique_coord`` is neither ``None`` nor a
        ``Variable`` / ``_DummyGroup``.
    ValueError
        If ``codes`` has no name.
    """
    # Explicit raises rather than ``assert``: assert statements are removed
    # when Python runs with ``-O``, which would silently disable these checks.
    if not isinstance(self.codes, DataArray):
        raise TypeError(
            f"`codes` must be a DataArray, got {type(self.codes).__name__}."
        )
    if self.codes.name is None:
        raise ValueError("Please set a name on the array you are grouping by.")
    if not isinstance(self.full_index, pd.Index):
        raise TypeError(
            f"`full_index` must be a pandas.Index, "
            f"got {type(self.full_index).__name__}."
        )
    if self.unique_coord is not None and not isinstance(
        self.unique_coord, (Variable, _DummyGroup)
    ):
        raise TypeError(
            f"`unique_coord` must be a Variable, a _DummyGroup or None, "
            f"got {type(self.unique_coord).__name__}."
        )


class Grouper(ABC):
Expand Down Expand Up @@ -134,10 +144,10 @@ def _factorize_unique(self) -> EncodedGroups:
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
codes = self.group.copy(data=codes_)
unique_coord = IndexVariable(
self.group.name, unique_values, attrs=self.group.attrs
unique_coord = Variable(
dims=codes.name, data=unique_values, attrs=self.group.attrs
)
full_index = unique_coord
full_index = pd.Index(unique_values)

return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
Expand All @@ -152,12 +162,13 @@ def _factorize_dummy(self) -> EncodedGroups:
size_range = np.arange(size)
if isinstance(self.group, _DummyGroup):
codes = self.group.to_dataarray().copy(data=size_range)
unique_coord = self.group
full_index = pd.RangeIndex(self.group.size)
else:
codes = self.group.copy(data=size_range)
unique_coord = self.group
full_index = IndexVariable(
self.group.name, unique_coord.values, self.group.attrs
)
unique_coord = self.group.variable
full_index = pd.Index(unique_coord.data)

return EncodedGroups(
codes=codes,
group_indices=group_indices,
Expand Down Expand Up @@ -201,7 +212,9 @@ def factorize(self, group) -> EncodedGroups:
codes = DataArray(
binned_codes, getattr(group, "coords", None), name=new_dim_name
)
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
unique_coord = Variable(
dims=new_dim_name, data=pd.Index(unique_values), attrs=group.attrs
)
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)
Expand Down Expand Up @@ -318,7 +331,9 @@ def factorize(self, group) -> EncodedGroups:
]
group_indices += [slice(sbins[-1], None)]

unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
unique_coord = Variable(
dims=group.name, data=first_items.index, attrs=group.attrs
)
codes = group.copy(data=codes_)

return EncodedGroups(
Expand Down

0 comments on commit fc85ba2

Please sign in to comment.