Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 26 additions & 82 deletions python/pyspark/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
if not isinstance(min_count, int):
raise TypeError("min_count must be integer")

if min_count > 0:

def first(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(F.first(col, ignorenulls=True))

else:

def first(col: Column) -> Column:
return F.first(col, ignorenulls=True)

return self._reduce_for_stat_function(
first,
lambda col: F.first(col, ignorenulls=True),
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
min_count=min_count,
)

def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
Expand Down Expand Up @@ -549,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
if not isinstance(min_count, int):
raise TypeError("min_count must be integer")

if min_count > 0:

def last(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(F.last(col, ignorenulls=True))

else:

def last(col: Column) -> Column:
return F.last(col, ignorenulls=True)

return self._reduce_for_stat_function(
last,
lambda col: F.last(col, ignorenulls=True),
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
min_count=min_count,
)

def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
Expand Down Expand Up @@ -624,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
if not isinstance(min_count, int):
raise TypeError("min_count must be integer")

if min_count > 0:

def max(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(F.max(col))

else:

def max(col: Column) -> Column:
return F.max(col)

return self._reduce_for_stat_function(
max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
F.max,
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
min_count=min_count,
)

def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
Expand Down Expand Up @@ -802,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike:
if not isinstance(min_count, int):
raise TypeError("min_count must be integer")

if min_count > 0:

def min(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(F.min(col))

else:

def min(col: Column) -> Column:
return F.min(col)

return self._reduce_for_stat_function(
min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None
F.min,
accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
min_count=min_count,
)

# TODO: sync the doc.
Expand Down Expand Up @@ -944,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:
f"numeric_only=False, skip unsupported columns: {unsupported}"
)

if min_count > 0:

def sum(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(F.sum(col))

else:

def sum(col: Column) -> Column:
return F.sum(col)

return self._reduce_for_stat_function(
sum, accepted_spark_types=(NumericType,), bool_to_numeric=True
F.sum,
accepted_spark_types=(NumericType, BooleanType),
bool_to_numeric=True,
min_count=min_count,
)

# TODO: sync the doc.
Expand Down Expand Up @@ -1324,22 +1273,11 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameLike:

self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")

if min_count > 0:

def prod(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(SF.product(col, True))

else:

def prod(col: Column) -> Column:
return SF.product(col, True)

return self._reduce_for_stat_function(
prod,
lambda col: SF.product(col, True),
accepted_spark_types=(NumericType, BooleanType),
bool_to_numeric=True,
min_count=min_count,
)

def all(self, skipna: bool = True) -> FrameLike:
Expand Down Expand Up @@ -3596,6 +3534,7 @@ def _reduce_for_stat_function(
sfun: Callable[[Column], Column],
accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None,
bool_to_numeric: bool = False,
**kwargs: Any,
) -> FrameLike:
"""Apply an aggregate function `sfun` per column and reduce to a FrameLike.

Expand All @@ -3615,14 +3554,19 @@ def _reduce_for_stat_function(
psdf: DataFrame = DataFrame(internal)

if len(psdf._internal.column_labels) > 0:
min_count = kwargs.get("min_count", 0)
stat_exprs = []
for label in psdf._internal.column_labels:
psser = psdf._psser_for(label)
stat_exprs.append(
sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias(
psser._internal.data_spark_column_names[0]
input_scol = psser._dtype_op.nan_to_null(psser).spark.column
output_scol = sfun(input_scol)

if min_count > 0:
output_scol = F.when(
F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol
)
)

stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0]))
sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
else:
sdf = sdf.select(*groupkey_names).distinct()
Expand Down