diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index cba635062f..6c3ac7537b 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2925,9 +2925,23 @@ def nunique(self) -> bigframes.series.Series:
         return bigframes.series.Series(block)
 
     def agg(
-        self, func: str | typing.Sequence[str]
+        self,
+        func: str
+        | typing.Sequence[str]
+        | typing.Mapping[blocks.Label, typing.Sequence[str] | str],
     ) -> DataFrame | bigframes.series.Series:
-        if utils.is_list_like(func):
+        if utils.is_dict_like(func):
+            # Must check dict-like first because dictionaries are list-like
+            # according to Pandas.
+            agg_cols = []
+            for col_label, agg_func in func.items():
+                agg_cols.append(self[col_label].agg(agg_func))
+
+            from bigframes.core.reshape import api as reshape
+
+            return reshape.concat(agg_cols, axis=1)
+
+        elif utils.is_list_like(func):
             aggregations = [agg_ops.lookup_agg_func(f) for f in func]
 
             for dtype, agg in itertools.product(self.dtypes, aggregations):
@@ -2941,6 +2955,7 @@ def agg(
                     aggregations,
                 )
             )
+
         else:
             return bigframes.series.Series(
                 self._block.aggregate_all_and_stack(
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index fa451da35f..c80ced45a5 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -5652,3 +5652,29 @@ def test_astype_invalid_type_fail(scalars_dfs):
 
     with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
         bf_df.astype(123)
+
+
+def test_agg_with_dict(scalars_dfs):
+    bf_df, pd_df = scalars_dfs
+    agg_funcs = {
+        "int64_too": ["min", "max"],
+        "int64_col": ["min", "count"],
+    }
+
+    bf_result = bf_df.agg(agg_funcs).to_pandas()
+    pd_result = pd_df.agg(agg_funcs)
+
+    pd.testing.assert_frame_equal(
+        bf_result, pd_result, check_dtype=False, check_index_type=False
+    )
+
+
+def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
+    bf_df, _ = scalars_dfs
+    agg_funcs = {
+        "int64_too": ["min", "max"],
+        "nonexisting_col": ["count"],
+    }
+
+    with pytest.raises(KeyError):
+        bf_df.agg(agg_funcs)