Skip to content

REF: Groupby.quantile allow EA dispatch #51003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
92 changes: 92 additions & 0 deletions pandas/core/array_algos/quantile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from functools import partial
from typing import Literal

import numpy as np

from pandas._libs import groupby as libgroupby
from pandas._typing import (
ArrayLike,
Scalar,
Expand All @@ -15,6 +19,94 @@
)


def groupby_quantile_ndim_compat(
    *,
    qs: npt.NDArray[np.float64],
    interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
    ngroups: int,
    ids: npt.NDArray[np.intp],
    labels_for_lexsort: npt.NDArray[np.intp],
    npy_vals: np.ndarray,
    mask: npt.NDArray[np.bool_],
    result_mask: npt.NDArray[np.bool_] | None,
) -> np.ndarray:
    """
    Run ``libgroupby.group_quantile`` over 1D arrays or 2D ndarrays.

    Shared by GroupBy.quantile and ExtensionArray._groupby_quantile so that
    both can dispatch here once the values have been cast to numpy.

    Parameters
    ----------
    qs : np.ndarray[float64]
        Values between 0 and 1 providing the quantile(s) to compute.
    interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
        Method to use when the desired quantile falls between two points.
    ngroups : int
        The number of groupby groups.
    ids : np.ndarray[intp]
        Group labels.
    labels_for_lexsort : np.ndarray[intp]
        Group labels, but with -1s moved to the end to sort NAs last.
    npy_vals : np.ndarray
        The values for which we are computing quantiles.
    mask : np.ndarray[bool]
        Locations to treat as NA.
    result_mask : np.ndarray[bool] or None
        If present, set to True for locations that should be treated as
        missing in the result. Modified in-place. Must be None when
        ``npy_vals`` is not 1D.

    Returns
    -------
    np.ndarray
        float64 results, flattened to ``(ngroups * len(qs),)`` for 1D input
        and to ``(ncols, ngroups * len(qs))`` otherwise.
    """
    num_quantiles = len(qs)

    if npy_vals.ndim == 2:
        # One row of values per column; every row is sorted against the
        # same group labels, so broadcast the labels to the values' shape.
        num_columns = npy_vals.shape[0]
        lexsort_labels = np.broadcast_to(
            labels_for_lexsort, (num_columns, len(labels_for_lexsort))
        )
    else:
        num_columns = 1
        lexsort_labels = labels_for_lexsort

    out = np.empty((num_columns, ngroups, num_quantiles), dtype=np.float64)

    # Indexer that orders by group label first and by value within a group
    # (np.lexsort treats the *last* key as the primary one).
    sorter = np.lexsort((npy_vals, lexsort_labels)).astype(np.intp, copy=False)

    compute = partial(
        libgroupby.group_quantile,
        labels=ids,
        qs=qs,
        interpolation=interpolation,
    )

    if npy_vals.ndim == 1:
        compute(
            out[0],
            values=npy_vals,
            mask=mask.view(np.uint8),
            sort_indexer=sorter,
            result_mask=result_mask,
        )
        return out.reshape(ngroups * num_quantiles)

    # if we ever did get here with non-None result_mask, we'd need to pass
    # result_mask[col] to each per-column call below
    assert result_mask is None
    for col in range(num_columns):
        compute(
            out[col],
            values=npy_vals[col],
            mask=mask[col].view(np.uint8),
            sort_indexer=sorter[col],
        )

    return out.reshape(num_columns, ngroups * num_quantiles)


def quantile_compat(
values: ArrayLike, qs: npt.NDArray[np.float64], interpolation: str
) -> ArrayLike:
Expand Down
65 changes: 64 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AstypeArg,
AxisInt,
Dtype,
DtypeObj,
FillnaOptions,
PositionalIndexer,
ScalarIndexer,
Expand All @@ -55,9 +56,13 @@

from pandas.core.dtypes.cast import maybe_cast_to_extension_array
from pandas.core.dtypes.common import (
is_bool_dtype,
is_datetime64_dtype,
is_dtype_equal,
is_float_dtype,
is_integer_dtype,
is_list_like,
is_object_dtype,
is_scalar,
is_timedelta64_dtype,
pandas_dtype,
Expand All @@ -82,7 +87,10 @@
rank,
unique,
)
from pandas.core.array_algos.quantile import quantile_with_mask
from pandas.core.array_algos.quantile import (
groupby_quantile_ndim_compat,
quantile_with_mask,
)
from pandas.core.sorting import (
nargminmax,
nargsort,
Expand Down Expand Up @@ -1688,6 +1696,61 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):

return arraylike.default_array_ufunc(self, ufunc, method, *inputs, **kwargs)

def _groupby_quantile(
    self,
    *,
    qs: npt.NDArray[np.float64],
    interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
    ngroups: int,
    ids: npt.NDArray[np.intp],
    labels_for_lexsort: npt.NDArray[np.intp],
):
    """
    Compute per-group quantiles of this array for GroupBy.quantile.

    The values are converted to a numpy array, handed to
    ``groupby_quantile_ndim_compat`` together with ``self.isna()`` as the
    NA mask, and the float64 output is cast back where a result dtype was
    inferred.

    Parameters
    ----------
    qs : np.ndarray[float64]
        The quantile(s) to compute.
    interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
        Method to use when the desired quantile falls between two points.
    ngroups : int
        The number of groupby groups.
    ids : np.ndarray[intp]
        Group labels.
    labels_for_lexsort : np.ndarray[intp]
        Group labels arranged for lexsorting; forwarded unchanged.

    Returns
    -------
    np.ndarray
        Cast to the inferred dtype (int64, float64, or the original
        datetime64/timedelta64 dtype) when one was inferred, except that
        integer input with 'linear' or 'midpoint' interpolation is left
        as computed.

    Raises
    ------
    TypeError
        If the array has object dtype.
    """
    na_mask = self.isna()

    # dtype to cast the result back to; None means "return as computed"
    inference: DtypeObj | None = None

    # TODO: 2023-01-26 we only have tests for the dt64/td64 cases here
    if is_object_dtype(self.dtype):
        raise TypeError("'quantile' cannot be performed against 'object' dtypes!")

    if is_integer_dtype(self.dtype):
        inference = np.dtype(np.int64)
        npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
    elif is_bool_dtype(self.dtype):
        npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
    elif is_datetime64_dtype(self.dtype) or is_timedelta64_dtype(self.dtype):
        # computed on the float view of the values, then cast back below
        inference = self.dtype
        npy_vals = np.asarray(self).astype(float)
    elif is_float_dtype(self):
        inference = np.dtype(np.float64)
        npy_vals = self.to_numpy(dtype=float, na_value=np.nan)
    else:
        npy_vals = np.asarray(self)

    npy_out = groupby_quantile_ndim_compat(
        qs=qs,
        interpolation=interpolation,
        ngroups=ngroups,
        ids=ids,
        labels_for_lexsort=labels_for_lexsort,
        npy_vals=npy_vals,
        mask=np.asarray(na_mask),
        result_mask=None,
    )

    if inference is None:
        return npy_out

    # Edge case: these interpolations can land between two integers, so the
    # result is not cast back to an integer dtype.
    if is_integer_dtype(inference) and interpolation in {"linear", "midpoint"}:
        return npy_out

    assert isinstance(inference, np.dtype)  # for mypy
    return npy_out.astype(inference)


class ExtensionArraySupportsAnyAll(ExtensionArray):
def any(self, *, skipna: bool = True) -> bool:
Expand Down
47 changes: 46 additions & 1 deletion pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@
masked_accumulations,
masked_reductions,
)
from pandas.core.array_algos.quantile import quantile_with_mask
from pandas.core.array_algos.quantile import (
groupby_quantile_ndim_compat,
quantile_with_mask,
)
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays import ExtensionArray
from pandas.core.construction import ensure_wrapped_if_datetimelike
Expand Down Expand Up @@ -1383,3 +1386,45 @@ def _accumulate(
data, mask = op(data, mask, skipna=skipna, **kwargs)

return type(self)(data, mask, copy=False)

# ------------------------------------------------------------------

@doc(ExtensionArray._groupby_quantile)
def _groupby_quantile(
    self,
    *,
    qs: npt.NDArray[np.float64],
    interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
    ngroups: int,
    ids: npt.NDArray[np.intp],
    labels_for_lexsort: npt.NDArray[np.intp],
):
    # Masked-array implementation: the quantile routine receives a
    # (ngroups, nqs) result mask that it fills in-place, and that mask
    # (flattened) becomes the mask of the returned masked array.
    n_quantiles = len(qs)
    out_mask = np.zeros((ngroups, n_quantiles), dtype=np.bool_)

    # Masked entries are passed as NaN in a float ndarray.
    float_vals = self.to_numpy(dtype=float, na_value=np.nan)

    npy_out = groupby_quantile_ndim_compat(
        qs=qs,
        interpolation=interpolation,
        ngroups=ngroups,
        ids=ids,
        labels_for_lexsort=labels_for_lexsort,
        npy_vals=float_vals,
        mask=self._mask,
        result_mask=out_mask,
    )
    flat_mask = out_mask.reshape(ngroups * n_quantiles)

    interpolates = interpolation in {"linear", "midpoint"}
    if not interpolates or is_float_dtype(self.dtype):
        # Result stays in this array's own dtype.
        return type(self)(
            npy_out.astype(self.dtype.numpy_dtype),
            flat_mask,
        )

    # Non-float data with 'linear'/'midpoint' interpolation: return the
    # float results as a FloatingArray instead of casting back.
    from pandas.core.arrays import FloatingArray

    return FloatingArray(npy_out, flat_mask)
Loading