Skip to content

Commit 4191821

Browse files
feat: Support len() on Groupby objects (#2183)
1 parent c331dfe commit 4191821

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ def __iter__(self) -> Iterable[Tuple[blocks.Label, df.DataFrame]]:
177177
filtered_df = df.DataFrame(filtered_block)
178178
yield group_keys, filtered_df
179179

180+
def __len__(self) -> int:
181+
return len(self.agg([]))
182+
180183
def size(self) -> typing.Union[df.DataFrame, series.Series]:
181184
agg_block, _ = self._block.aggregate_size(
182185
by_column_ids=self._by_col_ids,

bigframes/core/groupby/series_group_by.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ def __iter__(self) -> Iterable[Tuple[blocks.Label, series.Series]]:
108108
filtered_series.name = self._value_name
109109
yield group_keys, filtered_series
110110

111+
def __len__(self) -> int:
112+
return len(self.agg([]))
113+
111114
def all(self) -> series.Series:
112115
return self._aggregate(agg_ops.all_op)
113116

@@ -275,9 +278,9 @@ def agg(self, func=None) -> typing.Union[df.DataFrame, series.Series]:
275278
if column_names:
276279
agg_block = agg_block.with_column_labels(column_names)
277280

278-
if len(aggregations) > 1:
279-
return df.DataFrame(agg_block)
280-
return series.Series(agg_block)
281+
if len(aggregations) == 1:
282+
return series.Series(agg_block)
283+
return df.DataFrame(agg_block)
281284

282285
aggregate = agg
283286

tests/system/small/test_groupby.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ def test_dataframe_groupby_head(scalars_df_index, scalars_pandas_df_index):
6161
pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
6262

6363

64+
def test_dataframe_groupby_len(scalars_df_index, scalars_pandas_df_index):
65+
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
66+
67+
bf_result = len(scalars_df_index[col_names].groupby("bool_col"))
68+
pd_result = len(scalars_pandas_df_index[col_names].groupby("bool_col"))
69+
70+
assert bf_result == pd_result
71+
72+
6473
def test_dataframe_groupby_median(scalars_df_index, scalars_pandas_df_index):
6574
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
6675
bf_result = (
@@ -668,6 +677,13 @@ def test_dataframe_groupby_last(
668677
# ==============
669678

670679

680+
def test_series_groupby_len(scalars_df_index, scalars_pandas_df_index):
681+
bf_result = len(scalars_df_index.groupby("bool_col")["int64_col"])
682+
pd_result = len(scalars_pandas_df_index.groupby("bool_col")["int64_col"])
683+
684+
assert bf_result == pd_result
685+
686+
671687
@pytest.mark.parametrize(
672688
("agg"),
673689
[

0 commit comments

Comments
 (0)