diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index ac45222625569..e8010f1216dd4 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -34,7 +34,7 @@ class providing the base-class of operations. from pandas._config.config import option_context -from pandas._libs import Timestamp +from pandas._libs import Timestamp, lib import pandas._libs.groupby as libgroupby from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar from pandas.compat.numpy import function as nv @@ -61,11 +61,11 @@ class providing the base-class of operations. import pandas.core.common as com from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame -from pandas.core.groupby import base, ops +from pandas.core.groupby import base, numba_, ops from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex from pandas.core.series import Series from pandas.core.sorting import get_group_index_sorter -from pandas.core.util.numba_ import maybe_use_numba +from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba _common_see_also = """ See Also @@ -827,7 +827,12 @@ def __iter__(self): input="dataframe", examples=_apply_docs["dataframe_examples"] ) ) - def apply(self, func, *args, **kwargs): + def apply(self, func, *args, engine=None, engine_kwargs=None, **kwargs): + + if maybe_use_numba(engine): + return self._apply_with_numba( + func, *args, engine_kwargs=engine_kwargs, **kwargs + ) func = self._is_builtin_func(func) @@ -871,6 +876,35 @@ def f(g): return result + def _apply_with_numba(self, func, *args, engine_kwargs=None, **kwargs): + group_keys = self.grouper._get_group_keys() + + with _group_selection_context(self): + # We always drop the column with the groupby key + data = self._selected_obj + labels, _, n_groups = self.grouper.group_info + sorted_index = get_group_index_sorter(labels, n_groups) + sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False) + sorted_data = data.take(sorted_index, axis=self.axis) + starts, ends = lib.generate_slices(sorted_labels, n_groups) + cache_key = (func, "groupby_apply") + if cache_key in NUMBA_FUNC_CACHE: + # Return an already compiled version of roll_apply if available + apply_func = NUMBA_FUNC_CACHE[cache_key] + else: + apply_func = numba_.generate_numba_apply_func( + tuple(args), kwargs, func, engine_kwargs + ) + result = apply_func( + sorted_data.to_numpy(), starts, ends, len(group_keys), len(data.columns) + ) + + if self.grouper.nkeys > 1: + index = MultiIndex.from_tuples(group_keys, names=self.grouper.names) + else: + index = Index(group_keys, name=self.grouper.names[0]) + return self.obj._constructor(result, index=index, columns=data.columns) + def _python_apply_general( self, f: F, data: FrameOrSeriesUnion ) -> FrameOrSeriesUnion: diff --git a/pandas/core/groupby/numba_.py b/pandas/core/groupby/numba_.py new file mode 100644 index 0000000000000..6ba3659985f6b --- /dev/null +++ b/pandas/core/groupby/numba_.py @@ -0,0 +1,73 @@ +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np + +from pandas._typing import Scalar +from pandas.compat._optional import import_optional_dependency + +from pandas.core.util.numba_ import ( + check_kwargs_and_nopython, + get_jit_arguments, + jit_user_function, +) + + +def generate_numba_apply_func( + args: Tuple, + kwargs: Dict[str, Any], + func: Callable[..., Scalar], + engine_kwargs: Optional[Dict[str, bool]], +): + """ + Generate a numba jitted apply function specified by values from engine_kwargs. + + 1. jit the user's function + 2. Return a rolling apply function with the jitted function inline + + Configurations specified in engine_kwargs apply to both the user's + function _AND_ the rolling apply function. + + Parameters + ---------- + args : tuple + *args to be passed into the function + kwargs : dict + **kwargs to be passed into the function + func : function + function to be applied to each window and will be JITed + engine_kwargs : dict + dictionary of arguments to be passed into numba.jit + + Returns + ------- + Numba function + """ + nopython, nogil, parallel = get_jit_arguments(engine_kwargs) + + check_kwargs_and_nopython(kwargs, nopython) + + numba_func = jit_user_function(func, nopython, nogil, parallel) + + numba = import_optional_dependency("numba") + + if parallel: + loop_range = numba.prange + else: + loop_range = range + + @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) + def group_apply( + values: np.ndarray, + begin: np.ndarray, + end: np.ndarray, + num_groups: int, + num_columns: int, + ) -> np.ndarray: + result = np.empty((num_groups, num_columns)) + for i in loop_range(num_groups): + for j in loop_range(num_columns): + group = values[begin[i] : end[i], j] + result[i, j] = numba_func(group, *args) + return result + + return group_apply