diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py
index caf5ac5928f..e5030eb634b 100644
--- a/python/cudf/cudf/core/groupby/groupby.py
+++ b/python/cudf/cudf/core/groupby/groupby.py
@@ -1308,12 +1308,9 @@ def _jit_groupby_apply(
         chunk_results = jit_groupby_apply(
             offsets, grouped_values, function, *args
         )
-        result = cudf.Series._from_data(
-            {None: chunk_results}, index=group_names
+        return self._post_process_chunk_results(
+            chunk_results, group_names, group_keys, grouped_values
         )
-        result.index.names = self.grouping.names
-
-        return result
 
     @_cudf_nvtx_annotate
     def _iterative_groupby_apply(
@@ -1341,12 +1338,15 @@ def _post_process_chunk_results(
     ):
         if not len(chunk_results):
             return self.obj.head(0)
-        if cudf.api.types.is_scalar(chunk_results[0]):
-            result = cudf.Series._from_data(
-                {None: chunk_results}, index=group_names
-            )
+        if isinstance(chunk_results, ColumnBase) or cudf.api.types.is_scalar(
+            chunk_results[0]
+        ):
+            data = {None: chunk_results}
+            ty = cudf.Series if self._as_index else cudf.DataFrame
+            result = ty._from_data(data, index=group_names)
             result.index.names = self.grouping.names
             return result
+
         elif isinstance(chunk_results[0], cudf.Series) and isinstance(
             self.obj, cudf.DataFrame
         ):
@@ -1380,6 +1380,10 @@ def _post_process_chunk_results(
                     index_data = group_keys._data.copy(deep=True)
                     index_data[None] = grouped_values.index._column
                     result.index = cudf.MultiIndex._from_data(index_data)
+            elif len(chunk_results) == len(group_names):
+                result = cudf.concat(chunk_results, axis=1).T
+                result.index = group_names
+                result.index.names = self.grouping.names
             else:
                 raise TypeError(
                     "Error handling Groupby apply output with input of "
@@ -1552,7 +1556,6 @@ def mult(df):
             result = result.sort_index()
         if self._as_index is False:
             result = result.reset_index()
-            result[None] = result.pop(0)
         return result
 
     @_cudf_nvtx_annotate
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index bc2aaab1286..8dbd74f4edf 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -871,6 +871,33 @@ def test_groupby_apply_return_df(func):
     assert_groupby_results_equal(expect, got)
 
 
+@pytest.mark.parametrize("as_index", [True, False])
+def test_groupby_apply_return_reindexed_series(as_index):
+    def gdf_func(df):
+        return cudf.Series([df["a"].sum(), df["b"].min(), df["c"].max()])
+
+    def pdf_func(df):
+        return pd.Series([df["a"].sum(), df["b"].min(), df["c"].max()])
+
+    df = cudf.DataFrame(
+        {
+            "key": [0, 0, 1, 1, 2, 2],
+            "a": [1, 2, 3, 4, 5, 6],
+            "b": [7, 8, 9, 10, 11, 12],
+            "c": [13, 14, 15, 16, 17, 18],
+        }
+    )
+    pdf = df.to_pandas()
+
+    kwargs = {}
+    if PANDAS_GE_220:
+        kwargs["include_groups"] = False
+
+    expect = pdf.groupby("key", as_index=as_index).apply(pdf_func, **kwargs)
+    got = df.groupby("key", as_index=as_index).apply(gdf_func, **kwargs)
+    assert_groupby_results_equal(expect, got)
+
+
 @pytest.mark.parametrize("nelem", [2, 3, 100, 500, 1000])
 @pytest.mark.parametrize(
     "func",