Skip to content

BUG: groupby.transform(name) validates name is an aggregation #27597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Performance improvements
Bug fixes
~~~~~~~~~

- Previously, when :meth:`pandas.core.groupby.GroupBy.transform` was passed the name of a transformation (e.g. ``rank``, ``ffill``, etc.), it broadcast the results, producing erroneous values. Now only the results of reductions are broadcast, not the results of transformations. The return value in these cases is the same as calling the named function directly (e.g. ``g.rank()``) (:issue:`14274`) (:issue:`19354`) (:issue:`22509`).

Categorical
^^^^^^^^^^^
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/groupby/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def _gotitem(self, key, ndim, subset=None):
[
"backfill",
"bfill",
"corrwith",
"cumcount",
"cummax",
"cummin",
Expand Down Expand Up @@ -173,6 +172,8 @@ def _gotitem(self, key, ndim, subset=None):
# are neither a transformation nor a reduction
"corr",
"cov",
# corrwith does not preserve shape, depending on `other`
"corrwith",
"describe",
"dtypes",
"expanding",
Expand All @@ -197,4 +198,4 @@ def _gotitem(self, key, ndim, subset=None):
# Valid values of `name` for `groupby.transform(name)`
# NOTE: do NOT edit this directly. New additions should be inserted
# into the appropriate list above.
transform_kernel_whitelist = reduction_kernels | transformation_kernels
groupby_transform_whitelist = reduction_kernels | transformation_kernels
15 changes: 11 additions & 4 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
if not (func in base.groupby_transform_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
Expand Down Expand Up @@ -615,7 +615,11 @@ def _transform_fast(self, result, obj, func_nm):
ids, _, ngroup = self.grouper.group_info
output = []
for i, _ in enumerate(result.columns):
res = algorithms.take_1d(result.iloc[:, i].values, ids)
if func_nm in base.reduction_kernels:
# only broadcast results if we performed a reduction
res = algorithms.take_1d(result.iloc[:, i]._values, ids)
else:
res = result.iloc[:, i].values
if cast:
res = self._try_cast(res, obj.iloc[:, i])
output.append(res)
Expand Down Expand Up @@ -1014,7 +1018,7 @@ def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
if not (func in base.groupby_transform_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
Expand Down Expand Up @@ -1072,7 +1076,10 @@ def _transform_fast(self, func, func_nm):

ids, _, ngroup = self.grouper.group_info
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(func()._values, ids)
out = func()
if func_nm in base.reduction_kernels:
# only broadcast results if we performed a reduction
out = algorithms.take_1d(out._values, ids)
if cast:
out = self._try_cast(out, self.obj)
return Series(out, index=self.obj.index, name=self.obj.name)
Expand Down
9 changes: 8 additions & 1 deletion pandas/tests/groupby/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from pandas import DataFrame, MultiIndex
from pandas.core.groupby.base import reduction_kernels
from pandas.core.groupby.base import reduction_kernels, transformation_kernels
from pandas.util import testing as tm


Expand Down Expand Up @@ -105,6 +105,13 @@ def three_group():
)


@pytest.fixture(params=sorted(transformation_kernels))
def transformation_func(request):
    """Return the string name of one groupby transformation function.

    The fixture is parametrized over the sorted contents of
    ``transformation_kernels``, so a test requesting it runs once per name.
    """
    return request.param


@pytest.fixture(params=sorted(reduction_kernels))
def reduction_func(request):
"""yields the string names of all groupby reduction functions, one at a time.
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,36 @@ def test_transform_agg_by_name(reduction_func, obj):
assert len(set(DataFrame(result).iloc[-3:, -1])) == 1


@pytest.mark.parametrize(
    "obj",
    [
        DataFrame(
            dict(a=[0, 0, 0, 1, 1, 1], b=range(6)), index=["A", "B", "C", "D", "E", "F"]
        ),
        Series([0, 0, 0, 1, 1, 1], index=["A", "B", "C", "D", "E", "F"]),
    ],
)
def test_transform_transformation_by_name(transformation_func, obj):
    """Check ``g.transform(name)`` against ``g.name()`` for transformations."""
    kernel = transformation_func

    # Known-broken kernels (BUG): mark as expected failures up front.
    if kernel.startswith("cum"):
        pytest.xfail("skip cum* function tests")
    if kernel == "tshift":
        pytest.xfail("tshift")

    # Two groups of three rows each.
    grouped = obj.groupby(np.repeat([0, 1], 3))

    # `fillna` is the only kernel here that is given a positional argument.
    extra_args = [0] if kernel == "fillna" else []

    # For transformations, `g.transform(name)` should return the same result
    # as calling `g.name()` directly.
    result = grouped.transform(kernel, *extra_args)
    expected = getattr(grouped, kernel)(*extra_args)

    if isinstance(obj, DataFrame):
        assert_equal = tm.assert_frame_equal
    else:
        assert_equal = tm.assert_series_equal
    assert_equal(result, expected)


def test_transform_lambda_with_datetimetz():
# GH 27496
df = DataFrame(
Expand Down