diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index b2525ce9a60ad..c5dbcb79710a5 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -468,21 +468,10 @@ def first(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fr if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def first(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.first(col, ignorenulls=True)) - - else: - - def first(col: Column) -> Column: - return F.first(col, ignorenulls=True) - return self._reduce_for_stat_function( - first, + lambda col: F.first(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -549,21 +538,10 @@ def last(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fra if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def last(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.last(col, ignorenulls=True)) - - else: - - def last(col: Column) -> Column: - return F.last(col, ignorenulls=True) - return self._reduce_for_stat_function( - last, + lambda col: F.last(col, ignorenulls=True), accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> FrameLike: @@ -624,20 +602,10 @@ def max(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def max(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.max(col)) - - else: - - def max(col: Column) -> Column: - return F.max(col) - return self._reduce_for_stat_function( - max, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.max, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) def mean(self, numeric_only: Optional[bool] = True) -> FrameLike: @@ -802,20 +770,10 @@ def min(self, numeric_only: Optional[bool] = False, min_count: int = -1) -> Fram if not isinstance(min_count, int): raise TypeError("min_count must be integer") - if min_count > 0: - - def min(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.min(col)) - - else: - - def min(col: Column) -> Column: - return F.min(col) - return self._reduce_for_stat_function( - min, accepted_spark_types=(NumericType, BooleanType) if numeric_only else None + F.min, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + min_count=min_count, ) # TODO: sync the doc. @@ -944,20 +902,11 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL f"numeric_only=False, skip unsupported columns: {unsupported}" ) - if min_count > 0: - - def sum(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(F.sum(col)) - - else: - - def sum(col: Column) -> Column: - return F.sum(col) - return self._reduce_for_stat_function( - sum, accepted_spark_types=(NumericType,), bool_to_numeric=True + F.sum, + accepted_spark_types=(NumericType, BooleanType), + bool_to_numeric=True, + min_count=min_count, ) # TODO: sync the doc. @@ -1324,22 +1273,11 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame self._validate_agg_columns(numeric_only=numeric_only, function_name="prod") - if min_count > 0: - - def prod(col: Column) -> Column: - return F.when( - F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None) - ).otherwise(SF.product(col, True)) - - else: - - def prod(col: Column) -> Column: - return SF.product(col, True) - return self._reduce_for_stat_function( - prod, + lambda col: SF.product(col, True), accepted_spark_types=(NumericType, BooleanType), bool_to_numeric=True, + min_count=min_count, ) def all(self, skipna: bool = True) -> FrameLike: @@ -3596,6 +3534,7 @@ def _reduce_for_stat_function( sfun: Callable[[Column], Column], accepted_spark_types: Optional[Tuple[Type[DataType], ...]] = None, bool_to_numeric: bool = False, + **kwargs: Any, ) -> FrameLike: """Apply an aggregate function `sfun` per column and reduce to a FrameLike. @@ -3615,14 +3554,19 @@ def _reduce_for_stat_function( psdf: DataFrame = DataFrame(internal) if len(psdf._internal.column_labels) > 0: + min_count = kwargs.get("min_count", 0) stat_exprs = [] for label in psdf._internal.column_labels: psser = psdf._psser_for(label) - stat_exprs.append( - sfun(psser._dtype_op.nan_to_null(psser).spark.column).alias( - psser._internal.data_spark_column_names[0] + input_scol = psser._dtype_op.nan_to_null(psser).spark.column + output_scol = sfun(input_scol) + + if min_count > 0: + output_scol = F.when( + F.count(F.when(~F.isnull(input_scol), F.lit(0))) >= min_count, output_scol ) - ) + + stat_exprs.append(output_scol.alias(psser._internal.data_spark_column_names[0])) sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs) else: sdf = sdf.select(*groupkey_names).distinct()