diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst
index fa7b945492d5d..c352a36bf6de1 100644
--- a/doc/source/whatsnew/v1.0.0.rst
+++ b/doc/source/whatsnew/v1.0.0.rst
@@ -39,7 +39,7 @@ Backwards incompatible API changes
 
 .. _whatsnew_1000.api.other:
 
--
+- :meth:`pandas.core.groupby.GroupBy.transform` now raises on invalid operation names (:issue:`27489`).
 -
 
 Other API changes
diff --git a/pandas/core/base.py b/pandas/core/base.py
index a2691f66592e9..89a3d9cfea5ab 100644
--- a/pandas/core/base.py
+++ b/pandas/core/base.py
@@ -4,6 +4,7 @@
 import builtins
 from collections import OrderedDict
 import textwrap
+from typing import Callable, Optional, Union
 import warnings
 
 import numpy as np
@@ -566,7 +567,7 @@ def is_any_frame():
         else:
             result = None
 
-        f = self._is_cython_func(arg)
+        f = self._get_cython_func(arg)
         if f and not args and not kwargs:
             return getattr(self, f)(), None
 
@@ -653,7 +654,7 @@ def _shallow_copy(self, obj=None, obj_type=None, **kwargs):
                 kwargs[attr] = getattr(self, attr)
         return obj_type(obj, **kwargs)
 
-    def _is_cython_func(self, arg):
+    def _get_cython_func(self, arg: Union[str, Callable]) -> Optional[str]:
         """
         if we define an internal function for this argument, return it
         """
diff --git a/pandas/core/groupby/base.py b/pandas/core/groupby/base.py
index 5c4f1fa3fbddf..fc3bb69afd0cb 100644
--- a/pandas/core/groupby/base.py
+++ b/pandas/core/groupby/base.py
@@ -98,6 +98,103 @@ def _gotitem(self, key, ndim, subset=None):
 
 dataframe_apply_whitelist = common_apply_whitelist | frozenset(["dtypes", "corrwith"])
 
-cython_transforms = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
+# cythonized transformations or canned "agg+broadcast", which do not
+# require postprocessing of the result by transform.
+cythonized_kernels = frozenset(["cumprod", "cumsum", "shift", "cummin", "cummax"])
 
 cython_cast_blacklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])
+
+# List of aggregation/reduction functions.
+# These map each group to a single value
+reduction_kernels = frozenset(
+    [
+        "all",
+        "any",
+        "count",
+        "first",
+        "idxmax",
+        "idxmin",
+        "last",
+        "mad",
+        "max",
+        "mean",
+        "median",
+        "min",
+        "ngroup",
+        "nth",
+        "nunique",
+        "prod",
+        # as long as `quantile`'s signature accepts only
+        # a single quantile value, it's a reduction.
+        # GH#27526 might change that.
+        "quantile",
+        "sem",
+        "size",
+        "skew",
+        "std",
+        "sum",
+        "var",
+    ]
+)
+
+# List of transformation functions.
+# a transformation is a function that, for each group,
+# produces a result that has the same shape as the group.
+transformation_kernels = frozenset(
+    [
+        "backfill",
+        "bfill",
+        "corrwith",
+        "cumcount",
+        "cummax",
+        "cummin",
+        "cumprod",
+        "cumsum",
+        "diff",
+        "ffill",
+        "fillna",
+        "pad",
+        "pct_change",
+        "rank",
+        "shift",
+        "tshift",
+    ]
+)
+
+# these are all the public methods on Grouper which don't belong
+# in either of the above lists
+groupby_other_methods = frozenset(
+    [
+        "agg",
+        "aggregate",
+        "apply",
+        "boxplot",
+        # corr and cov return ngroups*ncolumns rows, so they
+        # are neither a transformation nor a reduction
+        "corr",
+        "cov",
+        "describe",
+        "dtypes",
+        "expanding",
+        "filter",
+        "get_group",
+        "groups",
+        "head",
+        "hist",
+        "indices",
+        "ndim",
+        "ngroups",
+        "ohlc",
+        "pipe",
+        "plot",
+        "resample",
+        "rolling",
+        "tail",
+        "take",
+        "transform",
+    ]
+)
+# Valid values of `name` for `groupby.transform(name)`
+# NOTE: do NOT edit this directly. New additions should be inserted
+# into the appropriate list above.
+transform_kernel_whitelist = reduction_kernels | transformation_kernels
diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py
index b886b7e305ed0..1fef65349976b 100644
--- a/pandas/core/groupby/generic.py
+++ b/pandas/core/groupby/generic.py
@@ -573,13 +573,19 @@ def _transform_general(self, func, *args, **kwargs):
     def transform(self, func, *args, **kwargs):
 
         # optimized transforms
-        func = self._is_cython_func(func) or func
+        func = self._get_cython_func(func) or func
+
         if isinstance(func, str):
-            if func in base.cython_transforms:
-                # cythonized transform
+            if func not in base.transform_kernel_whitelist:
+                msg = "'{func}' is not a valid function name for transform(name)"
+                raise ValueError(msg.format(func=func))
+            if func in base.cythonized_kernels:
+                # cythonized transformation or canned "agg+broadcast"
                 return getattr(self, func)(*args, **kwargs)
             else:
-                # cythonized aggregation and merge
+                # If func is a reduction, we need to broadcast the
+                # result to the whole group. Compute func result
+                # and deal with possible broadcasting below.
                 result = getattr(self, func)(*args, **kwargs)
         else:
             return self._transform_general(func, *args, **kwargs)
@@ -590,7 +596,7 @@ def transform(self, func, *args, **kwargs):
 
         obj = self._obj_with_exclusions
 
-        # nuiscance columns
+        # nuisance columns
         if not result.columns.equals(obj.columns):
             return self._transform_general(func, *args, **kwargs)
 
@@ -853,7 +859,7 @@ def aggregate(self, func_or_funcs=None, *args, **kwargs):
             if relabeling:
                 ret.columns = columns
         else:
-            cyfunc = self._is_cython_func(func_or_funcs)
+            cyfunc = self._get_cython_func(func_or_funcs)
             if cyfunc and not args and not kwargs:
                 return getattr(self, cyfunc)()
 
@@ -1005,15 +1011,19 @@ def _aggregate_named(self, func, *args, **kwargs):
     @Substitution(klass="Series", selected="A.")
     @Appender(_transform_template)
     def transform(self, func, *args, **kwargs):
-        func = self._is_cython_func(func) or func
+        func = self._get_cython_func(func) or func
 
-        # if string function
         if isinstance(func, str):
-            if func in base.cython_transforms:
-                # cythonized transform
+            if func not in base.transform_kernel_whitelist:
+                msg = "'{func}' is not a valid function name for transform(name)"
+                raise ValueError(msg.format(func=func))
+            if func in base.cythonized_kernels:
+                # cythonized transformation or canned "agg+broadcast"
                 return getattr(self, func)(*args, **kwargs)
             else:
-                # cythonized aggregation and merge
+                # If func is a reduction, we need to broadcast the
+                # result to the whole group. Compute func result
+                # and deal with possible broadcasting below.
                 return self._transform_fast(
                     lambda: getattr(self, func)(*args, **kwargs), func
                 )
diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py
index 9aba9723e0546..3d4dbd3f8d887 100644
--- a/pandas/core/groupby/groupby.py
+++ b/pandas/core/groupby/groupby.py
@@ -261,7 +261,7 @@ class providing the base-class of operations.
 
 * f must return a value that either has the same shape as the input
   subframe or can be broadcast to the shape of the input subframe.
-  For example, f returns a scalar it will be broadcast to have the
+  For example, if `f` returns a scalar it will be broadcast to have the
   same shape as the input subframe.
 * if this is a DataFrame, f must support application column-by-column
   in the subframe. If f also supports application to the entire subframe,
diff --git a/pandas/core/resample.py b/pandas/core/resample.py
index fdf7cbd68d8cb..66878c3b1026c 100644
--- a/pandas/core/resample.py
+++ b/pandas/core/resample.py
@@ -1046,7 +1046,7 @@ def _downsample(self, how, **kwargs):
         **kwargs : kw args passed to how function
         """
         self._set_binner()
-        how = self._is_cython_func(how) or how
+        how = self._get_cython_func(how) or how
         ax = self.ax
         obj = self._selected_obj
 
@@ -1194,7 +1194,7 @@ def _downsample(self, how, **kwargs):
         if self.kind == "timestamp":
             return super()._downsample(how, **kwargs)
 
-        how = self._is_cython_func(how) or how
+        how = self._get_cython_func(how) or how
         ax = self.ax
 
         if is_subperiod(ax.freq, self.freq):
diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py
index bdf93756b7559..72e60c5099304 100644
--- a/pandas/tests/groupby/conftest.py
+++ b/pandas/tests/groupby/conftest.py
@@ -2,6 +2,7 @@
 import pytest
 
 from pandas import DataFrame, MultiIndex
+from pandas.core.groupby.base import reduction_kernels
 from pandas.util import testing as tm
 
 
@@ -102,3 +103,10 @@ def three_group():
             "F": np.random.randn(11),
         }
     )
+
+
+@pytest.fixture(params=sorted(reduction_kernels))
+def reduction_func(request):
+    """Parametrized fixture: the string name of one groupby reduction kernel.
+    """
+    return request.param
diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py
index 9a8b7cf18f2c0..d3972e6ba9008 100644
--- a/pandas/tests/groupby/test_transform.py
+++ b/pandas/tests/groupby/test_transform.py
@@ -1003,6 +1003,55 @@ def test_ffill_not_in_axis(func, key, val):
     assert_frame_equal(result, expected)
 
 
+def test_transform_invalid_name_raises():
+    # GH#27486
+    df = DataFrame(dict(a=[0, 1, 1, 2]))
+    g = df.groupby(["a", "b", "b", "c"])
+    with pytest.raises(ValueError, match="not a valid function name"):
+        g.transform("some_arbitrary_name")
+
+    # method exists on the object, but is not a valid transformation/agg
+    assert hasattr(g, "aggregate")  # make sure the method exists
+    with pytest.raises(ValueError, match="not a valid function name"):
+        g.transform("aggregate")
+
+    # Test SeriesGroupBy
+    g = df["a"].groupby(["a", "b", "b", "c"])
+    with pytest.raises(ValueError, match="not a valid function name"):
+        g.transform("some_arbitrary_name")
+
+
+@pytest.mark.parametrize(
+    "obj",
+    [
+        DataFrame(
+            dict(a=[0, 0, 0, 1, 1, 1], b=range(6)), index=["A", "B", "C", "D", "E", "F"]
+        ),
+        Series([0, 0, 0, 1, 1, 1], index=["A", "B", "C", "D", "E", "F"]),
+    ],
+)
+def test_transform_agg_by_name(reduction_func, obj):
+    func = reduction_func
+    g = obj.groupby(np.repeat([0, 1], 3))
+
+    if func == "ngroup":  # GH#27468
+        pytest.xfail("TODO: g.transform('ngroup') doesn't work")
+    if func == "size":  # GH#27469
+        pytest.xfail("TODO: g.transform('size') doesn't work")
+
+    args = {"nth": [0], "quantile": [0.5]}.get(func, [])
+
+    result = g.transform(func, *args)
+
+    # this is the *definition* of a transformation
+    tm.assert_index_equal(result.index, obj.index)
+    if hasattr(obj, "columns"):
+        tm.assert_index_equal(result.columns, obj.columns)
+
+    # verify that values were broadcast across each group
+    assert len(set(DataFrame(result).iloc[-3:, -1])) == 1
+
+
 def test_transform_lambda_with_datetimetz():
     # GH 27496
     df = DataFrame(
diff 
--git a/pandas/tests/groupby/test_whitelist.py b/pandas/tests/groupby/test_whitelist.py
index ee380c6108c38..05d745ccc0e8e 100644
--- a/pandas/tests/groupby/test_whitelist.py
+++ b/pandas/tests/groupby/test_whitelist.py
@@ -9,6 +9,11 @@
 import pytest
 
 from pandas import DataFrame, Index, MultiIndex, Series, date_range
+from pandas.core.groupby.base import (
+    groupby_other_methods,
+    reduction_kernels,
+    transformation_kernels,
+)
 from pandas.util import testing as tm
 
 AGG_FUNCTIONS = [
@@ -376,3 +381,48 @@ def test_groupby_selection_with_methods(df):
     tm.assert_frame_equal(
         g.filter(lambda x: len(x) == 3), g_exp.filter(lambda x: len(x) == 3)
     )
+
+
+def test_all_methods_categorized(mframe):
+    """Every public GroupBy attribute must be listed in exactly one kernel list."""
+    grp = mframe.groupby(mframe.iloc[:, 0])
+    names = {_ for _ in dir(grp) if not _.startswith("_")} - set(mframe.columns)
+    new_names = set(names)
+    new_names -= reduction_kernels
+    new_names -= transformation_kernels
+    new_names -= groupby_other_methods
+
+    assert not (reduction_kernels & transformation_kernels)
+    assert not (reduction_kernels & groupby_other_methods)
+    assert not (transformation_kernels & groupby_other_methods)
+
+    # new public method?
+    if new_names:
+        msg = """
+There are uncategorized methods defined on the Grouper class:
+{names}.
+
+Was a new method recently added?
+
+Every public method on Grouper must appear in exactly one of the
+following three lists defined in pandas.core.groupby.base:
+- `reduction_kernels`
+- `transformation_kernels`
+- `groupby_other_methods`
+see the comments in pandas/core/groupby/base.py for guidance on
+how to fix this test.
+        """
+        # report only the uncategorized names, not every public name
+        raise AssertionError(msg.format(names=new_names))
+
+    # removed a public method?
+    all_categorized = reduction_kernels | transformation_kernels | groupby_other_methods
+    if names != all_categorized:
+        msg = """
+Some methods which are supposed to be on the Grouper class
+are missing:
+{names}.
+
+They're still defined in one of the lists that live in pandas/core/groupby/base.py.
+If you removed a method, you should update them
+"""
+        raise AssertionError(msg.format(names=all_categorized - names))