Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: implement groupby std, min, max as EA methods #51116

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def group_any_all(
mask: np.ndarray, # const uint8_t[::1]
val_test: Literal["any", "all"],
skipna: bool,
nullable: bool = ...,
) -> None: ...
def group_sum(
out: np.ndarray, # complexfloatingintuint_t[:, ::1]
Expand Down
76 changes: 75 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@

import numpy as np

from pandas._libs import lib
from pandas._libs import (
groupby as libgroupby,
lib,
)
from pandas._typing import (
ArrayLike,
AstypeArg,
Expand Down Expand Up @@ -58,6 +61,7 @@
is_datetime64_dtype,
is_dtype_equal,
is_list_like,
is_object_dtype,
is_scalar,
is_timedelta64_dtype,
pandas_dtype,
Expand Down Expand Up @@ -1688,6 +1692,76 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):

return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)

# ------------------------------------------------------------------------
# GroupBy Methods

def groupby_any_all(
    self,
    *,
    ngroups: int,
    ids: npt.NDArray[np.intp],
    val_test: Literal["any", "all"],
    skipna: bool,
):
    """
    Grouped any/all reduction, dispatching to the cython group_any_all kernel.

    Parameters
    ----------
    ngroups : int
        Number of groups.
    ids : np.ndarray[intp]
        Group labels aligned with ``self``.
    val_test : {"any", "all"}
        Which reduction the kernel performs.
    skipna : bool
        Whether NA entries are skipped.

    Returns
    -------
    np.ndarray of bool, one entry per group (transposed for 2D input).
    """
    n_columns = self.shape[0] if self.ndim != 1 else 1
    out = np.zeros((ngroups, n_columns), dtype=np.int8)

    source = self
    if skipna and is_object_dtype(self.dtype):
        # GH#37501: don't raise on pd.NA when skipna=True
        na_positions = self.isna()
        if na_positions.any():
            # mask on original values computed separately
            source = source.copy()
            source[na_positions] = True

    bool_vals = source.astype(bool, copy=False).view(np.int8)
    na_mask = np.asarray(self.isna()).view(np.uint8)

    # The kernel writes its answer into ``out`` in place.
    libgroupby.group_any_all(
        out=out,
        labels=ids,
        values=np.atleast_2d(bool_vals).T,
        mask=np.atleast_2d(na_mask).T,
        val_test=val_test,
        skipna=skipna,
        nullable=False,
    )

    if self.ndim == 1:
        assert out.shape[1] == 1, out.shape
        out = out[:, 0]

    return out.astype(bool, copy=False).T

def groupby_std(self, *, ngroups: int, ids: npt.NDArray[np.intp], ddof: int):
    """
    Grouped standard deviation, dispatching to the cython group_var kernel.

    Parameters
    ----------
    ngroups : int
        Number of groups.
    ids : np.ndarray[intp]
        Group labels aligned with ``self``.
    ddof : int
        Delta degrees of freedom passed through to the variance kernel.

    Returns
    -------
    np.ndarray of float64 standard deviations, one per group
    (transposed for 2D input).
    """
    out_dtype = np.dtype(np.float64)
    n_columns = self.shape[0] if self.ndim != 1 else 1

    out = np.zeros((ngroups, n_columns), dtype=out_dtype)
    group_counts = np.zeros(ngroups, dtype=np.int64)
    float_vals = self.astype(out_dtype, copy=False)

    # group_var fills ``out`` with per-group variances in place.
    libgroupby.group_var(
        out=out,
        labels=ids,
        values=np.atleast_2d(float_vals).T,
        counts=group_counts,
        ddof=ddof,
    )

    if self.ndim == 1:
        assert out.shape[1] == 1, out.shape
        out = out[:, 0]

    # std is the square root of the variance computed by the kernel.
    return np.sqrt(out).T


class ExtensionArraySupportsAnyAll(ExtensionArray):
def any(self, *, skipna: bool = True) -> bool:
Expand Down
62 changes: 62 additions & 0 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np

from pandas._libs import (
groupby as libgroupby,
lib,
missing as libmissing,
)
Expand Down Expand Up @@ -1383,3 +1384,64 @@ def _accumulate(
data, mask = op(data, mask, skipna=skipna, **kwargs)

return type(self)(data, mask, copy=False)

# ------------------------------------------------------------------------
# GroupBy Methods

def groupby_any_all(
    self,
    *,
    ngroups: int,
    ids: npt.NDArray[np.intp],
    val_test: Literal["any", "all"],
    skipna: bool,
):
    """
    Grouped any/all reduction for masked arrays via the nullable path of
    the cython group_any_all kernel.

    Parameters
    ----------
    ngroups : int
        Number of groups.
    ids : np.ndarray[intp]
        Group labels aligned with ``self``.
    val_test : {"any", "all"}
        Which reduction the kernel performs.
    skipna : bool
        Whether NA entries are skipped.

    Returns
    -------
    BooleanArray with one entry per group; kernel sentinel -1 marks NA.
    """
    from pandas.core.arrays import BooleanArray

    out = np.zeros((ngroups, 1), dtype=np.int8)
    bool_vals = self._data.astype(bool, copy=False).view(np.int8).reshape(-1, 1)
    na_mask = self.isna().view(np.uint8).reshape(-1, 1)

    # The kernel writes into ``out`` in place; with nullable=True it uses
    # -1 to flag group results that are NA.
    libgroupby.group_any_all(
        out=out,
        labels=ids,
        values=bool_vals,
        mask=na_mask,
        val_test=val_test,
        skipna=skipna,
        nullable=True,
    )

    assert out.shape[1] == 1, out.shape
    flat = out[:, 0]

    return BooleanArray(flat.astype(bool, copy=False), flat == -1)

def groupby_std(self, *, ngroups: int, ids: npt.NDArray[np.intp], ddof: int):
    """
    Grouped standard deviation for masked arrays via the masked path of
    the cython group_var kernel.

    Parameters
    ----------
    ngroups : int
        Number of groups.
    ids : np.ndarray[intp]
        Group labels aligned with ``self``.
    ddof : int
        Delta degrees of freedom passed through to the variance kernel.

    Returns
    -------
    FloatingArray of standard deviations with the kernel-produced NA mask.
    """
    from pandas.core.arrays import FloatingArray

    out = np.zeros((ngroups, 1), dtype=np.float64)
    group_counts = np.zeros(ngroups, dtype=np.int64)
    float_vals = self._data.astype(np.float64, copy=False).reshape(-1, 1)
    na_mask = self.isna().view(np.uint8).reshape(-1, 1)
    out_mask = np.zeros(out.shape, dtype=np.bool_)

    # group_var fills ``out`` and ``out_mask`` in place.
    libgroupby.group_var(
        out=out,
        labels=ids,
        values=float_vals,
        mask=na_mask,
        counts=group_counts,
        result_mask=out_mask,
        ddof=ddof,
    )

    assert out.shape[1] == 1, out.shape
    assert out_mask.shape[1] == 1, out_mask.shape

    # std is the square root of the variance computed by the kernel.
    return FloatingArray(np.sqrt(out[:, 0]), out_mask[:, 0].view(np.bool_))
148 changes: 18 additions & 130 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ class providing the base-class of operations.
from pandas.core._numba import executor
from pandas.core.arrays import (
BaseMaskedArray,
BooleanArray,
Categorical,
ExtensionArray,
FloatingArray,
PandasArray,
)
from pandas.core.base import (
PandasObject,
Expand Down Expand Up @@ -1799,36 +1799,8 @@ def _bool_agg(self, val_test: Literal["any", "all"], skipna: bool):
"""
Shared func to call any / all Cython GroupBy implementations.
"""

def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
if is_object_dtype(vals.dtype) and skipna:
# GH#37501: don't raise on pd.NA when skipna=True
mask = isna(vals)
if mask.any():
# mask on original values computed separately
vals = vals.copy()
vals[mask] = True
elif isinstance(vals, BaseMaskedArray):
vals = vals._data
vals = vals.astype(bool, copy=False)
return vals.view(np.int8), bool

def result_to_bool(
result: np.ndarray,
inference: type,
nullable: bool = False,
) -> ArrayLike:
if nullable:
return BooleanArray(result.astype(bool, copy=False), result == -1)
else:
return result.astype(inference, copy=False)

return self._get_cythonized_result(
libgroupby.group_any_all,
numeric_only=False,
cython_dtype=np.dtype(np.int8),
pre_processing=objs_to_bool,
post_processing=result_to_bool,
val_test=val_test,
skipna=skipna,
)
Expand Down Expand Up @@ -2105,27 +2077,8 @@ def std(
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
)

def _preprocessing(values):
if isinstance(values, BaseMaskedArray):
return values._data, None
return values, None

def _postprocessing(
vals, inference, nullable: bool = False, result_mask=None
) -> ArrayLike:
if nullable:
if result_mask.ndim == 2:
result_mask = result_mask[:, 0]
return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_))
return np.sqrt(vals)

result = self._get_cythonized_result(
libgroupby.group_var,
cython_dtype=np.dtype(np.float64),
numeric_only=numeric_only,
needs_counts=True,
pre_processing=_preprocessing,
post_processing=_postprocessing,
ddof=ddof,
how="std",
)
Expand Down Expand Up @@ -3666,12 +3619,7 @@ def cummax(
@final
def _get_cythonized_result(
self,
base_func: Callable,
cython_dtype: np.dtype,
numeric_only: bool = False,
needs_counts: bool = False,
pre_processing=None,
post_processing=None,
how: str = "any_all",
**kwargs,
):
Expand All @@ -3680,27 +3628,8 @@ def _get_cythonized_result(

Parameters
----------
base_func : callable, Cythonized function to be called
cython_dtype : np.dtype
Type of the array that will be modified by the Cython call.
numeric_only : bool, default False
Whether only numeric datatypes should be computed
needs_counts : bool, default False
Whether the counts should be a part of the Cython call
pre_processing : function, default None
Function to be applied to `values` prior to passing to Cython.
Function should return a tuple where the first element is the
values to be passed to Cython and the second element is an optional
type which the values should be converted to after being returned
by the Cython operation. This function is also responsible for
raising a TypeError if the values have an invalid type. Raises
if `needs_values` is False.
post_processing : function, default None
Function to be applied to result of Cython function. Should accept
an array of values as the first argument and type inferences as its
second argument, i.e. the signature should be
(ndarray, Type). If `needs_nullable=True`, a third argument should be
`nullable`, to allow for processing specific to nullable values.
how : str, default any_all
Determines if any/all cython interface or std interface is used.
**kwargs : dict
Expand All @@ -3710,71 +3639,30 @@ def _get_cythonized_result(
-------
`Series` or `DataFrame` with filled values
"""
if post_processing and not callable(post_processing):
raise ValueError("'post_processing' must be a callable!")
if pre_processing and not callable(pre_processing):
raise ValueError("'pre_processing' must be a callable!")

grouper = self.grouper

ids, _, ngroups = grouper.group_info

base_func = partial(base_func, labels=ids)

def blk_func(values: ArrayLike) -> ArrayLike:
values = values.T
ncols = 1 if values.ndim == 1 else values.shape[1]

result: ArrayLike
result = np.zeros(ngroups * ncols, dtype=cython_dtype)
result = result.reshape((ngroups, ncols))

func = partial(base_func, out=result)

inferences = None

if needs_counts:
counts = np.zeros(ngroups, dtype=np.int64)
func = partial(func, counts=counts)

vals = values
if pre_processing:
vals, inferences = pre_processing(vals)

vals = vals.astype(cython_dtype, copy=False)
if vals.ndim == 1:
vals = vals.reshape((-1, 1))
func = partial(func, values=vals)

if how != "std" or isinstance(values, BaseMaskedArray):
mask = isna(values).view(np.uint8)
if mask.ndim == 1:
mask = mask.reshape(-1, 1)
func = partial(func, mask=mask)

if how != "std":
is_nullable = isinstance(values, BaseMaskedArray)
func = partial(func, nullable=is_nullable)

elif isinstance(values, BaseMaskedArray):
result_mask = np.zeros(result.shape, dtype=np.bool_)
func = partial(func, result_mask=result_mask)

func(**kwargs) # Call func to modify result in place

if values.ndim == 1:
assert result.shape[1] == 1, result.shape
result = result[:, 0]

if post_processing:
pp_kwargs: dict[str, bool | np.ndarray] = {}
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
if how == "std" and pp_kwargs["nullable"]:
pp_kwargs["result_mask"] = result_mask

result = post_processing(result, inferences, **pp_kwargs)
if isinstance(values, ExtensionArray):
if how == "std":
return values.groupby_std(
ngroups=ngroups, ids=ids, ddof=kwargs["ddof"]
)
elif how == "any_all":
return values.groupby_any_all(
ngroups=ngroups,
ids=ids,
skipna=kwargs["skipna"],
val_test=kwargs["val_test"],
)
else:
raise NotImplementedError

return result.T
arr = PandasArray(values)
result = blk_func(arr)
return result

obj = self._obj_with_exclusions

Expand Down