Skip to content

Commit

Permalink
sdks/python: enable named aggregation to deferred DataFrame groupby (#…
Browse files Browse the repository at this point in the history
…33672)

* sdks/python: enable named aggregation

In this commit, we enable named aggregation for the
deferred DataFrame groupby operation.

* sdks/python: test named aggregation

In this commit, we test named aggregation on the deferred
DataFrame groupby operation by passing None as the function
and passing multiple keyword-argument aggregation functions
over the test columns.
  • Loading branch information
mohamedawnallah authored Feb 3, 2025
1 parent d63b8fb commit df13ffe
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 15 deletions.
95 changes: 80 additions & 15 deletions sdks/python/apache_beam/dataframe/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -4232,22 +4232,11 @@ def __getitem__(self, name):
projection=name)

@frame_base.with_docs_from(DataFrameGroupBy)
def agg(self, fn, *args, **kwargs):
if _is_associative(fn):
return _liftable_agg(fn)(self, *args, **kwargs)
elif _is_liftable_with_sum(fn):
return _liftable_agg(fn, postagg_meth='sum')(self, *args, **kwargs)
elif _is_unliftable(fn):
return _unliftable_agg(fn)(self, *args, **kwargs)
elif callable(fn):
return DeferredDataFrame(
expressions.ComputedExpression(
'agg',
lambda gb: gb.agg(fn, *args, **kwargs), [self._expr],
requires_partition_by=partitionings.Index(),
preserves_partition_by=partitionings.Singleton()))
def agg(self, fn=None, *args, **kwargs):
if fn is None:
return _agg_with_no_function(self, *args, **kwargs)
else:
raise NotImplementedError(f"GroupBy.agg(func={fn!r})")
return _handle_agg_function(self, fn, "agg", *args, **kwargs)

@property
def ndim(self):
Expand Down Expand Up @@ -4696,6 +4685,82 @@ def _check_str_or_np_builtin(agg_func, func_list):
getattr(agg_func, '__name__', None) in func_list
and agg_func.__module__ in ('numpy', 'builtins'))

def _agg_with_no_function(gb, *args, **kwargs):
"""
Applies aggregation functions to the grouped data based on the provided
keyword arguments and combines the results into a single DataFrame.
Args:
gb: The groupby instance (DeferredGroupBy).
*args: Additional positional arguments passed to the aggregation funcs.
**kwargs: A dictionary where each key is the column name to aggregate,
the value is a tuple containing the input column name and
the aggregation function to apply.
Returns:
DeferredDataFrame: A DataFrame that contains the aggregated results of
all specified columns.
Raises:
ValueError: If no aggregation functions are provided in the `kwargs`.
NotImplementedError: If the aggregation function type is unsupported.
"""
if not kwargs:
raise ValueError("No aggregation functions specified")

# Handle dictionary-like input for aggregation.
result_columns, result_frames = [], []
for col_name, (input_col, agg_fn) in kwargs.items():
frame = _handle_agg_function(
gb[input_col], agg_fn, f"agg_{col_name}", *args
)
result_frames.append(frame)
result_columns.append(col_name)

# Combine all the resulting DeferredDataFrames into a single DataFrame.
return DeferredDataFrame(
expressions.ComputedExpression(
"agg",
lambda *results: pd.concat(results, axis=1, keys=result_columns),
[frame._expr for frame in result_frames],
requires_partition_by=partitionings.Index(),
preserves_partition_by=partitionings.Singleton(),
)
)

def _handle_agg_function(gb, agg_func, agg_name, *args, **kwargs):
    """Dispatch a single aggregation on a deferred groupby by function kind.

    The checks run in a fixed order: associative, liftable-with-sum,
    unliftable, then any other callable. The first match wins.

    Args:
      gb: The groupby instance (DeferredGroupBy).
      agg_func: The aggregation function to apply.
      agg_name: Name given to the ComputedExpression that is built when
        ``agg_func`` is a callable not matched by the earlier checks.
      *args: Additional arguments to pass to the aggregation function.
      **kwargs: Keyword arguments to pass to the aggregation function.

    Returns:
      The result of the matching ``_liftable_agg`` / ``_unliftable_agg``
      wrapper, or a DeferredDataFrame for any other callable.

    Raises:
      NotImplementedError: If ``agg_func`` matches none of the supported kinds.
    """
    if _is_associative(agg_func):
        return _liftable_agg(agg_func)(gb, *args, **kwargs)

    if _is_liftable_with_sum(agg_func):
        return _liftable_agg(agg_func, postagg_meth='sum')(gb, *args, **kwargs)

    if _is_unliftable(agg_func):
        return _unliftable_agg(agg_func)(gb, *args, **kwargs)

    if not callable(agg_func):
        raise NotImplementedError(f"GroupBy.agg(func={agg_func!r})")

    # Any other callable is handed to the underlying groupby's own ``agg``.
    fallback_expr = expressions.ComputedExpression(
        agg_name,
        lambda grouped: grouped.agg(agg_func, *args, **kwargs),
        [gb._expr],
        requires_partition_by=partitionings.Index(),
        preserves_partition_by=partitionings.Singleton())
    return DeferredDataFrame(fallback_expr)

def _is_associative(agg_func):
    """Return whether ``agg_func`` is one of ``LIFTABLE_AGGREGATIONS``.

    The decision is delegated to ``_check_str_or_np_builtin``.
    """
    return _check_str_or_np_builtin(
        agg_func, func_list=LIFTABLE_AGGREGATIONS)
Expand Down
10 changes: 10 additions & 0 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,16 @@ def test_std_all_na(self):
self._run_test(lambda s: s.agg('std'), s)
self._run_test(lambda s: s.std(), s)

def test_df_agg_operations_on_columns(self):
    """Run named aggregation over several columns of ``GROUPBY_DF``.

    Mixes lambda aggregations with string-named ones ('sum', 'count') in a
    single ``groupby('group').agg(...)`` call.
    """
    # Output column name -> (input column, aggregation). The lambdas are kept
    # as lambdas, not replaced by the numpy functions they wrap.
    named_aggs = {
        'mean_foo': ('foo', lambda col: np.mean(col)),
        'median_bar': ('bar', lambda col: np.median(col)),
        'sum_baz': ('baz', 'sum'),
        'count_bool': ('bool', 'count'),
    }
    self._run_test(
        lambda df: df.groupby('group').agg(**named_aggs), GROUPBY_DF)

def test_std_mostly_na_with_ddof(self):
df = pd.DataFrame({
'one': [i if i % 8 == 0 else np.nan for i in range(8)],
Expand Down

0 comments on commit df13ffe

Please sign in to comment.