diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 2542b2d7071cb..db0f0dbe17386 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -24,7 +24,7 @@ from pandas._libs import NaT, iNaT, lib import pandas._libs.groupby as libgroupby import pandas._libs.reduction as libreduction -from pandas._typing import F, FrameOrSeries, Label, Shape +from pandas._typing import ArrayLike, F, FrameOrSeries, Label, Shape from pandas.errors import AbstractMethodError from pandas.util._decorators import cache_readonly @@ -445,6 +445,68 @@ def _get_cython_func_and_vals( raise return func, values + def _disallow_invalid_ops(self, values: ArrayLike, how: str): + """ + Check if we can do this operation with our cython functions. + + Raises + ------ + NotImplementedError + This is either not a valid function for this dtype, or + valid but not implemented in cython. + """ + dtype = values.dtype + + if is_categorical_dtype(dtype) or is_sparse(dtype): + # categoricals are only 1d, so we + # are not setup for dim transforming + raise NotImplementedError(f"{dtype} dtype not supported") + elif is_datetime64_any_dtype(dtype): + # we raise NotImplemented if this is an invalid operation + # entirely, e.g. adding datetimes + if how in ["add", "prod", "cumsum", "cumprod"]: + raise NotImplementedError( + f"datetime64 type does not support {how} operations" + ) + elif is_timedelta64_dtype(dtype): + if how in ["prod", "cumprod"]: + raise NotImplementedError( + f"timedelta64 type does not support {how} operations" + ) + + def _ea_wrap_cython_operation( + self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs + ) -> Tuple[np.ndarray, Optional[List[str]]]: + """ + If we have an ExtensionArray, unwrap, call _cython_operation, and + re-wrap if appropriate. + """ + # TODO: general case implementation overrideable by EAs. + orig_values = values + + if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype): + # All of the functions implemented here are ordinal, so we can + # operate on the tz-naive equivalents + values = values.view("M8[ns]") + res_values, names = self._cython_operation( + kind, values, how, axis, min_count, **kwargs + ) + res_values = res_values.astype("i8", copy=False) + # FIXME: this is wrong for rank, but not tested. + result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype) + return result, names + + elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype): + # IntegerArray or BooleanArray + values = ensure_int_or_float(values) + res_values, names = self._cython_operation( + kind, values, how, axis, min_count, **kwargs + ) + result = maybe_cast_result(result=res_values, obj=orig_values, how=how) + return result, names + + raise NotImplementedError(values.dtype) + def _cython_operation( self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs ) -> Tuple[np.ndarray, Optional[List[str]]]: @@ -454,8 +516,8 @@ def _cython_operation( Names is only useful when dealing with 2D results, like ohlc (see self._name_functions). """ - assert kind in ["transform", "aggregate"] orig_values = values + assert kind in ["transform", "aggregate"] if values.ndim > 2: raise NotImplementedError("number of dimensions is currently limited to 2") @@ -466,30 +528,12 @@ def _cython_operation( # can we do this operation with our cython functions # if not raise NotImplementedError + self._disallow_invalid_ops(values, how) - # we raise NotImplemented if this is an invalid operation - # entirely, e.g. adding datetimes - - # categoricals are only 1d, so we - # are not setup for dim transforming - if is_categorical_dtype(values.dtype) or is_sparse(values.dtype): - raise NotImplementedError(f"{values.dtype} dtype not supported") - elif is_datetime64_any_dtype(values.dtype): - if how in ["add", "prod", "cumsum", "cumprod"]: - raise NotImplementedError( - f"datetime64 type does not support {how} operations" - ) - elif is_timedelta64_dtype(values.dtype): - if how in ["prod", "cumprod"]: - raise NotImplementedError( - f"timedelta64 type does not support {how} operations" - ) - - if is_datetime64tz_dtype(values.dtype): - # Cast to naive; we'll cast back at the end of the function - # TODO: possible need to reshape? - # TODO(EA2D):kludge can be avoided when 2D EA is allowed. - values = values.view("M8[ns]") + if is_extension_array_dtype(values.dtype): + return self._ea_wrap_cython_operation( + kind, values, how, axis, min_count, **kwargs + ) is_datetimelike = needs_i8_conversion(values.dtype) is_numeric = is_numeric_dtype(values.dtype) @@ -573,19 +617,9 @@ def _cython_operation( if swapped: result = result.swapaxes(0, axis) - if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype( - orig_values.dtype - ): - # We need to use the constructors directly for these dtypes - # since numpy won't recognize them - # https://github.com/pandas-dev/pandas/issues/31471 - result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype) - elif is_datetimelike and kind == "aggregate": + if is_datetimelike and kind == "aggregate": result = result.astype(orig_values.dtype) - if is_extension_array_dtype(orig_values.dtype): - result = maybe_cast_result(result=result, obj=orig_values, how=how) - return result, names def _aggregate(