diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index 5b94da4412551..9db688d913462 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -45,7 +45,6 @@ from pyspark.sql.types import (
     BooleanType,
     DoubleType,
-    IntegralType,
     LongType,
     NumericType,
 )

@@ -1421,32 +1420,16 @@ def product(
         def prod(psser: "Series") -> Column:
             spark_type = psser.spark.data_type
             spark_column = psser.spark.column
-
-            if not skipna:
-                spark_column = F.when(spark_column.isNull(), np.nan).otherwise(spark_column)
-
             if isinstance(spark_type, BooleanType):
-                scol = F.min(F.coalesce(spark_column, F.lit(True))).cast(LongType())
-            elif isinstance(spark_type, NumericType):
-                num_zeros = F.sum(F.when(spark_column == 0, 1).otherwise(0))
-                sign = F.when(
-                    F.sum(F.when(spark_column < 0, 1).otherwise(0)) % 2 == 0, 1
-                ).otherwise(-1)
-
-                scol = F.when(num_zeros > 0, 0).otherwise(
-                    sign * F.exp(F.sum(F.log(F.abs(spark_column))))
-                )
-
-                if isinstance(spark_type, IntegralType):
-                    scol = F.round(scol).cast(LongType())
-            else:
+                spark_column = spark_column.cast(LongType())
+            elif not isinstance(spark_type, NumericType):
                 raise TypeError(
                     "Could not convert {} ({}) to numeric".format(
                         spark_type_to_pandas_dtype(spark_type), spark_type.simpleString()
                     )
                 )

-            return F.coalesce(scol, F.lit(1))
+            return SF.product(spark_column, skipna)

         return self._reduce_for_stat_function(
             prod,
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 08a136aa26891..b2525ce9a60ad 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -62,7 +62,6 @@
     StructField,
     StructType,
     StringType,
-    IntegralType,
 )

 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -1320,52 +1319,28 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame
         1  NaN  2.0  0.0
         2  NaN  NaN  NaN
         """

+        if not isinstance(min_count, int):
+            raise TypeError("min_count must be integer")
         self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")

-        groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
-        internal, agg_columns, sdf = self._prepare_reduce(
-            groupkey_names=groupkey_names,
-            accepted_spark_types=(NumericType, BooleanType),
-            bool_to_numeric=True,
-        )
-
-        psdf: DataFrame = DataFrame(internal)
-        if len(psdf._internal.column_labels) > 0:
-
-            stat_exprs = []
-            for label in psdf._internal.column_labels:
-                psser = psdf._psser_for(label)
-                column = psser._dtype_op.nan_to_null(psser).spark.column
-                data_type = psser.spark.data_type
-                aggregating = (
-                    F.product(column).cast("long")
-                    if isinstance(data_type, IntegralType)
-                    else F.product(column)
-                )
-
-                if min_count > 0:
-                    prod_scol = F.when(
-                        F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None)
-                    ).otherwise(aggregating)
-                else:
-                    prod_scol = aggregating
-
-                stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
+        if min_count > 0:

-            sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
+            def prod(col: Column) -> Column:
+                return F.when(
+                    F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
+                ).otherwise(SF.product(col, True))

         else:
-            sdf = sdf.select(*groupkey_names).distinct()

-        internal = internal.copy(
-            spark_frame=sdf,
-            index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
-            data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
-            data_fields=None,
-        )
+            def prod(col: Column) -> Column:
+                return SF.product(col, True)

-        return self._prepare_return(DataFrame(internal))
+        return self._reduce_for_stat_function(
+            prod,
+            accepted_spark_types=(NumericType, BooleanType),
+            bool_to_numeric=True,
+        )

     def all(self, skipna: bool = True) -> FrameLike:
         """
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index f9311296a5724..658d3459b24f3 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -27,6 +27,11 @@
 )


+def product(col: Column, dropna: bool) -> Column:
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))
+
+
 def stddev(col: Column, ddof: int) -> Column:
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py
index 7c252c8356d80..d476302205938 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -200,6 +200,22 @@ def test_stat_functions(self):
         self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())
         self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis())

+    def test_prod_precision(self):
+        pdf = pd.DataFrame(
+            {
+                "a": [np.nan, np.nan, np.nan, np.nan],
+                "b": [1, np.nan, np.nan, -4],
+                "c": [1, -2, 3, -4],
+                "d": [55108, 55108, 55108, 55108],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+
+        self.assert_eq(pdf.prod(), psdf.prod())
+        self.assert_eq(pdf.prod(skipna=False), psdf.prod(skipna=False))
+        self.assert_eq(pdf.prod(min_count=3), psdf.prod(min_count=3))
+        self.assert_eq(pdf.prod(skipna=False, min_count=3), psdf.prod(skipna=False, min_count=3))
+

 if __name__ == "__main__":
"__main__": import unittest diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py index 33ba815564145..a203f77717e9d 100644 --- a/python/pyspark/pandas/tests/test_groupby.py +++ b/python/pyspark/pandas/tests/test_groupby.py @@ -1493,13 +1493,35 @@ def test_nth(self): self.psdf.groupby("B").nth("x") def test_prod(self): + pdf = pd.DataFrame( + { + "A": [1, 2, 1, 2, 1], + "B": [3.1, 4.1, 4.1, 3.1, 0.1], + "C": ["a", "b", "b", "a", "c"], + "D": [True, False, False, True, False], + "E": [-1, -2, 3, -4, -2], + "F": [-1.5, np.nan, -3.2, 0.1, 0], + "G": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + psdf = ps.from_pandas(pdf) + for n in [0, 1, 2, 128, -1, -2, -128]: - self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n)) self._test_stat_func( - lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n) + lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False ) self._test_stat_func( - lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n) + lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n), + check_exact=False, + ) + self._test_stat_func( + lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n), + check_exact=False, + ) + self.assert_eq( + pdf.groupby("A").prod(min_count=n).sort_index(), + psdf.groupby("A").prod(min_count=n).sort_index(), + almost=True, ) def test_cumcount(self): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala index 3af3944fd47d7..3325c8f16a4f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Product.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal} +import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, Exp, Expression, If, ImplicitCastInputTypes, IsNull, Literal, Log} import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType} +import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, DoubleType, IntegralType, LongType, NumericType} /** Multiply numerical values within an aggregation group */ @@ -63,3 +63,114 @@ case class Product(child: Expression) override protected def withNewChildInternal(newChild: Expression): Product = copy(child = newChild) } + +/** + * Product in Pandas' fashion. This expression is dedicated only for Pandas API on Spark. 
+ * It has three main differences from `Product`:
+ * 1. it computes the product of `Fractional` inputs in a more numerically stable way;
+ * 2. it computes the product of `Integral` inputs with `LongType` variables internally;
+ * 3. it accepts NULLs when `ignoreNA` is false.
+ */
+case class PandasProduct(
+    child: Expression,
+    ignoreNA: Boolean)
+  extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {
+
+  override def nullable: Boolean = !ignoreNA
+
+  override def dataType: DataType = child.dataType match {
+    case _: IntegralType => LongType
+    case _ => DoubleType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+  private lazy val product =
+    AttributeReference("product", LongType, nullable = false)()
+  private lazy val logSum =
+    AttributeReference("logSum", DoubleType, nullable = false)()
+  private lazy val positive =
+    AttributeReference("positive", BooleanType, nullable = false)()
+  private lazy val containsZero =
+    AttributeReference("containsZero", BooleanType, nullable = false)()
+  private lazy val containsNull =
+    AttributeReference("containsNull", BooleanType, nullable = false)()
+
+  override lazy val aggBufferAttributes = child.dataType match {
+    case _: IntegralType =>
+      Seq(product, containsNull)
+    case _ =>
+      Seq(logSum, positive, containsZero, containsNull)
+  }
+
+  override lazy val initialValues: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(Literal(1L), Literal(false))
+    case _ =>
+      Seq(Literal(0.0), Literal(true), Literal(false), Literal(false))
+  }
+
+  override lazy val updateExpressions: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(
+        If(IsNull(child), product, product * child),
+        containsNull || IsNull(child)
+      )
+    case _ =>
+      val newLogSum = logSum + Log(Abs(child))
+      val newPositive = If(child < Literal(0.0), !positive, positive)
+      val newContainsZero = containsZero || child <=> Literal(0.0)
+      val newContainsNull = containsNull || IsNull(child)
+      if (ignoreNA) {
+        Seq(
+          If(IsNull(child) || newContainsZero, logSum, newLogSum),
+          newPositive,
+          newContainsZero,
+          newContainsNull
+        )
+      } else {
+        Seq(
+          If(newContainsNull || newContainsZero, logSum, newLogSum),
+          newPositive,
+          newContainsZero,
+          newContainsNull
+        )
+      }
+  }
+
+  override lazy val mergeExpressions: Seq[Expression] = child.dataType match {
+    case _: IntegralType =>
+      Seq(
+        product.left * product.right,
+        containsNull.left || containsNull.right
+      )
+    case _ =>
+      Seq(
+        logSum.left + logSum.right,
+        positive.left === positive.right,
+        containsZero.left || containsZero.right,
+        containsNull.left || containsNull.right
+      )
+  }
+
+  override lazy val evaluateExpression: Expression = child.dataType match {
+    case _: IntegralType =>
+      if (ignoreNA) {
+        product
+      } else {
+        If(containsNull, Literal(null, LongType), product)
+      }
+    case _ =>
+      val product = If(positive, Exp(logSum), -Exp(logSum))
+      if (ignoreNA) {
+        If(containsZero, Literal(0.0), product)
+      } else {
+        If(containsNull, Literal(null, DoubleType),
+          If(containsZero, Literal(0.0), product))
+      }
+  }
+
+  override def prettyName: String = "pandas_product"
+  override protected def withNewChildInternal(newChild: Expression): PandasProduct =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index d43a906067770..70474f4d5c43b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -155,6 +155,10 @@ private[sql] object PythonSQLUtils extends Logging {
     Column(TimestampDiff(unit, start.expr, end.expr))
   }

+  def pandasProduct(e: Column, ignoreNA: Boolean): Column = {
+    Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false))
+  }
+
   def pandasStddev(e: Column, ddof: Int): Column = {
     Column(PandasStddev(e.expr, ddof).toAggregateExpression(false))
   }
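
For reviewers who want a runnable mental model of what the PandasProduct buffer does for fractional inputs, the sketch below re-implements the same bookkeeping (log of absolute values, sign parity, zero and NULL flags) in plain Python. It is only an illustration of the update/evaluate logic, not the Catalyst code above, and the helper name pandas_like_product is invented for this example.

    import math

    def pandas_like_product(values, ignore_na=True):
        # Mirrors the (logSum, positive, containsZero, containsNull) buffer
        # that PandasProduct keeps for non-integral inputs.
        log_sum = 0.0          # running sum of log|x| over non-zero, non-null inputs
        positive = True        # flips on every negative input
        contains_zero = False
        contains_null = False

        for v in values:
            if v is None:
                contains_null = True
                continue
            if v == 0:
                contains_zero = True
                continue
            log_sum += math.log(abs(v))
            if v < 0:
                positive = not positive

        if not ignore_na and contains_null:
            return None        # skipna=False: any NULL makes the result NULL
        if contains_zero:
            return 0.0         # a single zero forces the product to zero
        magnitude = math.exp(log_sum)
        return magnitude if positive else -magnitude

    print(pandas_like_product([1.0, -2.0, 3.0, -4.0]))              # ~24.0
    print(pandas_like_product([1.0, None, -4.0], ignore_na=False))  # None

Accumulating in log space keeps long products of moderate values from overflowing or underflowing a double, which is why the expression sums Log(Abs(child)) instead of multiplying the inputs directly; the integral path instead multiplies into a LongType buffer to preserve exact integer results.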
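On the Python side the public API is unchanged; DataFrame.prod, Series.prod, and GroupBy.prod simply route through the new SF.product helper. A minimal usage sketch (the sample data here is invented, mirroring the cases the new tests cover):

    import numpy as np
    import pandas as pd
    import pyspark.pandas as ps

    pdf = pd.DataFrame(
        {
            "key": [1, 1, 2, 2],
            "a": [1, -2, 3, -4],
            "b": [1.0, np.nan, 3.0, -4.0],
        }
    )
    psdf = ps.from_pandas(pdf)

    psdf[["a", "b"]].prod()                # NaN skipped by default
    psdf[["a", "b"]].prod(skipna=False)    # NaN in "b" propagates, matching pandas
    psdf.groupby("key").prod(min_count=2)  # groups with fewer than 2 valid values yield NaN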