Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: implement _ea_wrap_cython_operation #38162

Merged
merged 3 commits into from
Dec 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 70 additions & 36 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pandas._libs import NaT, iNaT, lib
import pandas._libs.groupby as libgroupby
import pandas._libs.reduction as libreduction
from pandas._typing import F, FrameOrSeries, Label, Shape
from pandas._typing import ArrayLike, F, FrameOrSeries, Label, Shape
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

Expand Down Expand Up @@ -445,6 +445,68 @@ def _get_cython_func_and_vals(
raise
return func, values

def _disallow_invalid_ops(self, values: ArrayLike, how: str):
    """
    Reject dtype/operation pairs that the cython groupby kernels cannot run.

    Returns without doing anything when the combination is acceptable.

    Parameters
    ----------
    values : ArrayLike
        The data about to be operated on; only its ``dtype`` is inspected.
    how : str
        Name of the requested operation, e.g. "add", "prod", "cumsum".

    Raises
    ------
    NotImplementedError
        The operation is either meaningless for this dtype, or meaningful
        but has no cython implementation.
    """
    dtype = values.dtype

    # Categorical and sparse data are rejected outright, whatever the op.
    if is_categorical_dtype(dtype) or is_sparse(dtype):
        raise NotImplementedError(f"{dtype} dtype not supported")

    if is_datetime64_any_dtype(dtype):
        # Operations that are invalid for datetimes altogether,
        # e.g. adding two datetimes.
        if how in ("add", "prod", "cumsum", "cumprod"):
            raise NotImplementedError(
                f"datetime64 type does not support {how} operations"
            )
        return

    # Timedeltas are only barred from the product-style operations.
    if is_timedelta64_dtype(dtype) and how in ("prod", "cumprod"):
        raise NotImplementedError(
            f"timedelta64 type does not support {how} operations"
        )

def _ea_wrap_cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
) -> Tuple[np.ndarray, Optional[List[str]]]:
"""
If we have an ExtensionArray, unwrap, call _cython_operation, and
re-wrap if appropriate.
"""
# TODO: general case implementation overrideable by EAs.
orig_values = values

if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
values = values.view("M8[ns]")
res_values, names = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
res_values = res_values.astype("i8", copy=False)
# FIXME: this is wrong for rank, but not tested.
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
return result, names

elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
# IntegerArray or BooleanArray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably also handle FloatingArray?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've got a branch doing that right now; hopefully I will make a PR today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, yes, please try to do that shortly, as that should ideally go into 1.2 (before this PR, FloatingArray went through the cython groupby algos, but it no longer does).

values = ensure_int_or_float(values)
res_values, names = self._cython_operation(
kind, values, how, axis, min_count, **kwargs
)
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
return result, names

raise NotImplementedError(values.dtype)

def _cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
) -> Tuple[np.ndarray, Optional[List[str]]]:
Expand All @@ -454,8 +516,8 @@ def _cython_operation(
Names is only useful when dealing with 2D results, like ohlc
(see self._name_functions).
"""
assert kind in ["transform", "aggregate"]
orig_values = values
assert kind in ["transform", "aggregate"]

if values.ndim > 2:
raise NotImplementedError("number of dimensions is currently limited to 2")
Expand All @@ -466,30 +528,12 @@ def _cython_operation(

# can we do this operation with our cython functions
# if not raise NotImplementedError
self._disallow_invalid_ops(values, how)

# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes

# categoricals are only 1d, so we
# are not setup for dim transforming
if is_categorical_dtype(values.dtype) or is_sparse(values.dtype):
raise NotImplementedError(f"{values.dtype} dtype not supported")
elif is_datetime64_any_dtype(values.dtype):
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(values.dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
)

if is_datetime64tz_dtype(values.dtype):
# Cast to naive; we'll cast back at the end of the function
# TODO: possible need to reshape?
# TODO(EA2D):kludge can be avoided when 2D EA is allowed.
values = values.view("M8[ns]")
if is_extension_array_dtype(values.dtype):
return self._ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
)

is_datetimelike = needs_i8_conversion(values.dtype)
is_numeric = is_numeric_dtype(values.dtype)
Expand Down Expand Up @@ -573,19 +617,9 @@ def _cython_operation(
if swapped:
result = result.swapaxes(0, axis)

if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype(
orig_values.dtype
):
# We need to use the constructors directly for these dtypes
# since numpy won't recognize them
# https://github.com/pandas-dev/pandas/issues/31471
result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
elif is_datetimelike and kind == "aggregate":
if is_datetimelike and kind == "aggregate":
result = result.astype(orig_values.dtype)

if is_extension_array_dtype(orig_values.dtype):
result = maybe_cast_result(result=result, obj=orig_values, how=how)

return result, names

def _aggregate(
Expand Down