diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index b10f201e79318..e305df7d525fa 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -1304,6 +1304,7 @@ Groupby/resample/rolling - Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` would raise incorrectly when grouper had ``axis=1`` for ``"ngroup"`` argument (:issue:`45986`) - Bug in :meth:`.DataFrameGroupBy.describe` produced incorrect results when data had duplicate columns (:issue:`50806`) - Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`) +- Bug in :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, and :meth:`.Resampler.agg` would ignore arguments when passed a list of functions (:issue:`50863`) - Reshaping @@ -1317,6 +1318,7 @@ Reshaping - Clarified error message in :func:`merge` when passing invalid ``validate`` option (:issue:`49417`) - Bug in :meth:`DataFrame.explode` raising ``ValueError`` on multiple columns with ``NaN`` values or empty lists (:issue:`46084`) - Bug in :meth:`DataFrame.transpose` with ``IntervalDtype`` column with ``timedelta64[ns]`` endpoints (:issue:`44917`) +- Bug in :meth:`DataFrame.agg` and :meth:`Series.agg` would ignore arguments when passed a list of functions (:issue:`50863`) - Sparse diff --git a/pandas/core/apply.py b/pandas/core/apply.py index c28da1bc758cd..f29a6ce4c0b82 100644 --- a/pandas/core/apply.py +++ b/pandas/core/apply.py @@ -332,19 +332,28 @@ def agg_list_like(self) -> DataFrame | Series: for a in arg: colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj) - new_res = colg.aggregate(a) + if isinstance(colg, (ABCSeries, ABCDataFrame)): + new_res = colg.aggregate( + a, self.axis, *self.args, **self.kwargs + ) + else: + new_res = colg.aggregate(a, *self.args, **self.kwargs) results.append(new_res) # make sure we find a good name name = com.get_callable_name(a) or a keys.append(name) - # multiples 
else: indices = [] for index, col in enumerate(selected_obj): colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index]) - new_res = colg.aggregate(arg) + if isinstance(colg, (ABCSeries, ABCDataFrame)): + new_res = colg.aggregate( + arg, self.axis, *self.args, **self.kwargs + ) + else: + new_res = colg.aggregate(arg, *self.args, **self.kwargs) results.append(new_res) indices.append(index) keys = selected_obj.columns.take(indices) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 7745de87633eb..244afa61701d8 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -243,7 +243,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) # Catch instances of lists / tuples # but not the class list / tuple itself. func = maybe_mangle_lambdas(func) - ret = self._aggregate_multiple_funcs(func) + ret = self._aggregate_multiple_funcs(func, *args, **kwargs) if relabeling: # columns is not narrowed by mypy from relabeling flag assert columns is not None # for mypy @@ -275,7 +275,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs) agg = aggregate - def _aggregate_multiple_funcs(self, arg) -> DataFrame: + def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame: if isinstance(arg, dict): if self.as_index: # GH 15931 @@ -300,7 +300,7 @@ def _aggregate_multiple_funcs(self, arg) -> DataFrame: for idx, (name, func) in enumerate(arg): key = base.OutputKey(label=name, position=idx) - results[key] = self.aggregate(func) + results[key] = self.aggregate(func, *args, **kwargs) if any(isinstance(x, DataFrame) for x in results.values()): from pandas import concat diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py index 7aaad4d2ad081..a0e667dc8f243 100644 --- a/pandas/tests/apply/test_frame_apply.py +++ b/pandas/tests/apply/test_frame_apply.py @@ -1623,3 +1623,25 @@ def 
test_any_apply_keyword_non_zero_axis_regression(): result = df.apply("any", 1) tm.assert_series_equal(result, expected) + + +def test_agg_list_like_func_with_args(): + # GH 50624 + df = DataFrame({"x": [1, 2, 3]}) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + df.agg([foo1, foo2], 0, 3, b=3, c=4) + + result = df.agg([foo1, foo2], 0, 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/apply/test_series_apply.py b/pandas/tests/apply/test_series_apply.py index 53dee6e15c3e0..30f040b4197eb 100644 --- a/pandas/tests/apply/test_series_apply.py +++ b/pandas/tests/apply/test_series_apply.py @@ -107,6 +107,26 @@ def f(x, a=0, b=0, c=0): tm.assert_series_equal(result, expected) +def test_agg_list_like_func_with_args(): + # GH 50624 + + s = Series([1, 2, 3]) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + s.agg([foo1, foo2], 0, 3, b=3, c=4) + + result = s.agg([foo1, foo2], 0, 3, c=4) + expected = DataFrame({"foo1": [8, 9, 10], "foo2": [8, 9, 10]}) + tm.assert_frame_equal(result, expected) + + def test_series_map_box_timestamps(): # GH#2689, GH#2627 ser = Series(pd.date_range("1/1/2000", periods=10)) diff --git a/pandas/tests/groupby/aggregate/test_aggregate.py b/pandas/tests/groupby/aggregate/test_aggregate.py index e7be78be55620..22c9bbd74395d 100644 --- a/pandas/tests/groupby/aggregate/test_aggregate.py +++ b/pandas/tests/groupby/aggregate/test_aggregate.py @@ -1468,3 +1468,49 @@ def test_agg_of_mode_list(test, constant): expected = expected.set_index(0) tm.assert_frame_equal(result, expected) + + +def 
test__dataframe_groupy_agg_list_like_func_with_args(): + # GH 50624 + df = DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + gb = df.groupby("y") + + def foo1(x, a=1, c=0): + return x.sum() + a + c + + def foo2(x, b=2, c=0): + return x.sum() + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + gb.agg([foo1, foo2], 3, b=3, c=4) + + result = gb.agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + index=Index(["a", "b", "c"], name="y"), + columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + +def test__series_groupy_agg_list_like_func_with_args(): + # GH 50624 + s = Series([1, 2, 3]) + sgb = s.groupby(s) + + def foo1(x, a=1, c=0): + return x.sum() + a + c + + def foo2(x, b=2, c=0): + return x.sum() + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + sgb.agg([foo1, foo2], 3, b=3, c=4) + + result = sgb.agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"] + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/resample/test_resample_api.py b/pandas/tests/resample/test_resample_api.py index e6e924793389d..0b8dc8f3e8ac4 100644 --- a/pandas/tests/resample/test_resample_api.py +++ b/pandas/tests/resample/test_resample_api.py @@ -633,6 +633,31 @@ def test_try_aggregate_non_existing_column(): df.resample("30T").agg({"x": ["mean"], "y": ["median"], "z": ["sum"]}) +def test_agg_list_like_func_with_args(): + # 50624 + df = DataFrame( + {"x": [1, 2, 3]}, index=date_range("2020-01-01", periods=3, freq="D") + ) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + df.resample("D").agg([foo1, foo2], 3, b=3, c=4) + + result = 
df.resample("D").agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + index=date_range("2020-01-01", periods=3, freq="D"), + columns=pd.MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + def test_selection_api_validation(): # GH 13500 index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D")