Skip to content

Commit

Permalink
REF: implement _ea_wrap_cython_operation (#38162)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Dec 2, 2020
1 parent c4c1dc3 commit 40ca2b9
Showing 1 changed file with 70 additions and 36 deletions.
106 changes: 70 additions & 36 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pandas._libs import NaT, iNaT, lib
import pandas._libs.groupby as libgroupby
import pandas._libs.reduction as libreduction
from pandas._typing import F, FrameOrSeries, Label, Shape
from pandas._typing import ArrayLike, F, FrameOrSeries, Label, Shape
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -445,6 +445,68 @@ def _get_cython_func_and_vals(
raise
return func, values

def _disallow_invalid_ops(self, values: ArrayLike, how: str):
"""
Check if we can do this operation with our cython functions.
Raises
------
NotImplementedError
This is either not a valid function for this dtype, or
valid but not implemented in cython.
"""
dtype = values.dtype

if is_categorical_dtype(dtype) or is_sparse(dtype):
# categoricals are only 1d, so we
# are not setup for dim transforming
raise NotImplementedError(f"{dtype} dtype not supported")
elif is_datetime64_any_dtype(dtype):
# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
)

def _ea_wrap_cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
) -> Tuple[np.ndarray, Optional[List[str]]]:
"""
If we have an ExtensionArray, unwrap, call _cython_operation, and
re-wrap if appropriate.
"""
# TODO: general case implementation overrideable by EAs.
orig_values = values

if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
values = values.view("M8[ns]")
res_values, names = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
res_values = res_values.astype("i8", copy=False)
# FIXME: this is wrong for rank, but not tested.
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
return result, names

elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
# IntegerArray or BooleanArray
values = ensure_int_or_float(values)
res_values, names = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
return result, names

raise NotImplementedError(values.dtype)

def _cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
) -> Tuple[np.ndarray, Optional[List[str]]]:
Expand All @@ -454,8 +516,8 @@ def _cython_operation(
Names is only useful when dealing with 2D results, like ohlc
(see self._name_functions).
"""
assert kind in ["transform", "aggregate"]
orig_values = values
assert kind in ["transform", "aggregate"]

if values.ndim > 2:
raise NotImplementedError("number of dimensions is currently limited to 2")
Expand All @@ -466,30 +528,12 @@ def _cython_operation(

# can we do this operation with our cython functions
# if not raise NotImplementedError
self._disallow_invalid_ops(values, how)

# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values.dtype):
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(values.dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
)

if is_datetime64tz_dtype(values.dtype):
# Cast to naive; we'll cast back at the end of the function
# TODO: possible need to reshape?
# TODO(EA2D):kludge can be avoided when 2D EA is allowed.
values = values.view("M8[ns]")
if is_extension_array_dtype(values.dtype):
return self._ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
)

is_datetimelike = needs_i8_conversion(values.dtype)
is_numeric = is_numeric_dtype(values.dtype)
Expand Down Expand Up @@ -573,19 +617,9 @@ def _cython_operation(
if swapped:
result = result.swapaxes(0, axis)

if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype(
orig_values.dtype
):
# We need to use the constructors directly for these dtypes
# since numpy won't recognize them
# https://github.com/pandas-dev/pandas/issues/31471
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
elif is_datetimelike and kind == "aggregate":
if is_datetimelike and kind == "aggregate":
result = result.astype(orig_values.dtype)

if is_extension_array_dtype(orig_values.dtype):
result = maybe_cast_result(result=result, obj=orig_values, how=how)

return result, names

def _aggregate(
Expand Down

0 comments on commit 40ca2b9

Please sign in to comment.