23 changes: 3 additions & 20 deletions python/pyspark/pandas/generic.py
@@ -45,7 +45,6 @@
from pyspark.sql.types import (
BooleanType,
DoubleType,
IntegralType,
LongType,
NumericType,
)
@@ -1421,32 +1420,16 @@ def product(
def prod(psser: "Series") -> Column:
spark_type = psser.spark.data_type
spark_column = psser.spark.column

if not skipna:
spark_column = F.when(spark_column.isNull(), np.nan).otherwise(spark_column)

if isinstance(spark_type, BooleanType):
scol = F.min(F.coalesce(spark_column, F.lit(True))).cast(LongType())
elif isinstance(spark_type, NumericType):
num_zeros = F.sum(F.when(spark_column == 0, 1).otherwise(0))
sign = F.when(
F.sum(F.when(spark_column < 0, 1).otherwise(0)) % 2 == 0, 1
).otherwise(-1)

scol = F.when(num_zeros > 0, 0).otherwise(
sign * F.exp(F.sum(F.log(F.abs(spark_column))))
)

if isinstance(spark_type, IntegralType):
scol = F.round(scol).cast(LongType())
else:
spark_column = spark_column.cast(LongType())
elif not isinstance(spark_type, NumericType):
raise TypeError(
"Could not convert {} ({}) to numeric".format(
spark_type_to_pandas_dtype(spark_type), spark_type.simpleString()
)
)

return F.coalesce(scol, F.lit(1))
return SF.product(spark_column, skipna)

return self._reduce_for_stat_function(
prod,
53 changes: 14 additions & 39 deletions python/pyspark/pandas/groupby.py
@@ -62,7 +62,6 @@
StructField,
StructType,
StringType,
IntegralType,
)

from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
@@ -1320,52 +1319,28 @@ def prod(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> Frame
1 NaN 2.0 0.0
2 NaN NaN NaN
"""
if not isinstance(min_count, int):
raise TypeError("min_count must be integer")

self._validate_agg_columns(numeric_only=numeric_only, function_name="prod")

groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(self._groupkeys))]
internal, agg_columns, sdf = self._prepare_reduce(
groupkey_names=groupkey_names,
accepted_spark_types=(NumericType, BooleanType),
bool_to_numeric=True,
)

psdf: DataFrame = DataFrame(internal)
if len(psdf._internal.column_labels) > 0:

stat_exprs = []
for label in psdf._internal.column_labels:
psser = psdf._psser_for(label)
column = psser._dtype_op.nan_to_null(psser).spark.column
data_type = psser.spark.data_type
aggregating = (
F.product(column).cast("long")
if isinstance(data_type, IntegralType)
else F.product(column)
)

if min_count > 0:
prod_scol = F.when(
F.count(F.when(~F.isnull(column), F.lit(0))) < min_count, F.lit(None)
).otherwise(aggregating)
else:
prod_scol = aggregating

stat_exprs.append(prod_scol.alias(psser._internal.data_spark_column_names[0]))
if min_count > 0:

sdf = sdf.groupby(*groupkey_names).agg(*stat_exprs)
def prod(col: Column) -> Column:
return F.when(
F.count(F.when(~F.isnull(col), F.lit(0))) < min_count, F.lit(None)
).otherwise(SF.product(col, True))

else:
sdf = sdf.select(*groupkey_names).distinct()

internal = internal.copy(
spark_frame=sdf,
index_spark_columns=[scol_for(sdf, col) for col in groupkey_names],
data_spark_columns=[scol_for(sdf, col) for col in internal.data_spark_column_names],
data_fields=None,
)
def prod(col: Column) -> Column:
return SF.product(col, True)

return self._prepare_return(DataFrame(internal))
return self._reduce_for_stat_function(
prod,
accepted_spark_types=(NumericType, BooleanType),
bool_to_numeric=True,
)

def all(self, skipna: bool = True) -> FrameLike:
"""
5 changes: 5 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -27,6 +27,11 @@
)


def product(col: Column, dropna: bool) -> Column:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna))


def stddev(col: Column, ddof: int) -> Column:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof))
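
Note: a hedged usage sketch of the new `SF.product` wrapper (assuming this patch is applied; the session, data, and column names are illustrative, not from the patch):

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1,), (-2,), (3,), (-4,)], ["v"])

# Aggregate the column with pandas product semantics; the second argument (dropna=True) skips NULLs.
result = sdf.agg(SF.product(F.col("v"), True).alias("prod")).head()[0]
print(result)  # expected 24: integral inputs are multiplied in an exact LongType accumulator
```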
16 changes: 16 additions & 0 deletions python/pyspark/pandas/tests/test_generic_functions.py
@@ -200,6 +200,22 @@ def test_stat_functions(self):
self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())
self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis())

def test_prod_precision(self):
Reviewer note: this test would fail with the original implementation due to precision loss.

The new implementation returns the exact result for integral inputs whenever the product lies within [Long.MinValue, Long.MaxValue].
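
For illustration, a minimal plain-Python sketch (not part of the patch) of that precision loss; the values mirror column `d` in the test below:

```python
import math

values = [55108, 55108, 55108, 55108]

# Old Python-side formulation: sign * exp(sum(log(abs(x)))) in double precision;
# near 9.2e18 a double cannot represent every integer, so the result drifts.
approx = math.exp(sum(math.log(v) for v in values))

# New behaviour for integral inputs: exact multiplication in a 64-bit accumulator.
exact = 1
for v in values:
    exact *= v

print(int(approx) == exact)   # False: the rounded float differs from the true product
print(exact < 2 ** 63)        # True: the exact product still fits in a signed long
```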

pdf = pd.DataFrame(
{
"a": [np.nan, np.nan, np.nan, np.nan],
"b": [1, np.nan, np.nan, -4],
"c": [1, -2, 3, -4],
"d": [55108, 55108, 55108, 55108],
}
)
psdf = ps.from_pandas(pdf)

self.assert_eq(pdf.prod(), psdf.prod())
self.assert_eq(pdf.prod(skipna=False), psdf.prod(skipna=False))
self.assert_eq(pdf.prod(min_count=3), psdf.prod(min_count=3))
self.assert_eq(pdf.prod(skipna=False, min_count=3), psdf.prod(skipna=False, min_count=3))


if __name__ == "__main__":
import unittest
28 changes: 25 additions & 3 deletions python/pyspark/pandas/tests/test_groupby.py
@@ -1493,13 +1493,35 @@ def test_nth(self):
self.psdf.groupby("B").nth("x")

def test_prod(self):
pdf = pd.DataFrame(
{
"A": [1, 2, 1, 2, 1],
"B": [3.1, 4.1, 4.1, 3.1, 0.1],
"C": ["a", "b", "b", "a", "c"],
"D": [True, False, False, True, False],
"E": [-1, -2, 3, -4, -2],
"F": [-1.5, np.nan, -3.2, 0.1, 0],
"G": [np.nan, np.nan, np.nan, np.nan, np.nan],
}
)
psdf = ps.from_pandas(pdf)

for n in [0, 1, 2, 128, -1, -2, -128]:
self._test_stat_func(lambda groupby_obj: groupby_obj.prod(min_count=n))
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n)
lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False
)
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n)
lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n),
check_exact=False,
)
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
check_exact=False,
)
self.assert_eq(
pdf.groupby("A").prod(min_count=n).sort_index(),
psdf.groupby("A").prod(min_count=n).sort_index(),
almost=True,
)

def test_cumcount(self):
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.catalyst.expressions.{Abs, AttributeReference, Exp, Expression, If, ImplicitCastInputTypes, IsNull, Literal, Log}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, DataType, DoubleType}
import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, DoubleType, IntegralType, LongType, NumericType}


/** Multiply numerical values within an aggregation group */
@@ -63,3 +63,114 @@ case class Product(child: Expression)
override protected def withNewChildInternal(newChild: Expression): Product =
copy(child = newChild)
}

/**
 * Product in pandas' fashion. This expression is dedicated to the pandas API on Spark.
 * It differs from `Product` in three ways:
 * 1. it computes the product of `Fractional` inputs in a more numerically stable way;
 * 2. it computes the product of `Integral` inputs with LongType variables internally;
 * 3. it accepts NULLs when `ignoreNA` is false.
*/
case class PandasProduct(
child: Expression,
ignoreNA: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] {

override def nullable: Boolean = !ignoreNA

override def dataType: DataType = child.dataType match {
case _: IntegralType => LongType
case _ => DoubleType
}

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

private lazy val product =
AttributeReference("product", LongType, nullable = false)()
private lazy val logSum =
AttributeReference("logSum", DoubleType, nullable = false)()
private lazy val positive =
AttributeReference("positive", BooleanType, nullable = false)()
private lazy val containsZero =
AttributeReference("containsZero", BooleanType, nullable = false)()
private lazy val containsNull =
AttributeReference("containsNull", BooleanType, nullable = false)()

override lazy val aggBufferAttributes = child.dataType match {
case _: IntegralType =>
Seq(product, containsNull)
case _ =>
Seq(logSum, positive, containsZero, containsNull)
}

override lazy val initialValues: Seq[Expression] = child.dataType match {
case _: IntegralType =>
Seq(Literal(1L), Literal(false))
case _ =>
Seq(Literal(0.0), Literal(true), Literal(false), Literal(false))
}

override lazy val updateExpressions: Seq[Expression] = child.dataType match {
case _: IntegralType =>
Seq(
If(IsNull(child), product, product * child),
containsNull || IsNull(child)
)
case _ =>
val newLogSum = logSum + Log(Abs(child))
val newPositive = If(child < Literal(0.0), !positive, positive)
val newContainsZero = containsZero || child <=> Literal(0.0)
val newContainsNull = containsNull || IsNull(child)
if (ignoreNA) {
Seq(
If(IsNull(child) || newContainsZero, logSum, newLogSum),
newPositive,
newContainsZero,
newContainsNull
)
} else {
Seq(
If(newContainsNull || newContainsZero, logSum, newLogSum),
newPositive,
newContainsZero,
newContainsNull
)
}
}

override lazy val mergeExpressions: Seq[Expression] = child.dataType match {
case _: IntegralType =>
Seq(
product.left * product.right,
containsNull.left || containsNull.right
)
case _ =>
Seq(
logSum.left + logSum.right,
positive.left === positive.right,
containsZero.left || containsZero.right,
containsNull.left || containsNull.right
)
}

override lazy val evaluateExpression: Expression = child.dataType match {
case _: IntegralType =>
if (ignoreNA) {
product
} else {
If(containsNull, Literal(null, LongType), product)
}
case _ =>
val product = If(positive, Exp(logSum), -Exp(logSum))
if (ignoreNA) {
If(containsZero, Literal(0.0), product)
} else {
If(containsNull, Literal(null, DoubleType),
If(containsZero, Literal(0.0), product))
}
}

override def prettyName: String = "pandas_product"
override protected def withNewChildInternal(newChild: Expression): PandasProduct =
copy(child = newChild)
}
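
For readers less familiar with the `DeclarativeAggregate` above, the following plain-Python sketch (illustrative only, not part of the patch) mirrors the bookkeeping the fractional branch performs: a running log-sum of absolute values plus sign, zero, and null flags, combined only at evaluation time:

```python
import math

def pandas_product_fractional(values, ignore_na=True):
    """Sketch of PandasProduct's fractional branch: accumulate log(abs(x))
    and track sign/zero/null flags instead of multiplying doubles directly."""
    log_sum, positive, contains_zero, contains_null = 0.0, True, False, False
    for v in values:
        is_null = v is None or (isinstance(v, float) and math.isnan(v))
        contains_null = contains_null or is_null
        if is_null:
            continue
        if v == 0.0:
            contains_zero = True
            continue
        log_sum += math.log(abs(v))
        if v < 0.0:
            positive = not positive
    if contains_null and not ignore_na:
        return None            # NULL poisons the result when ignoreNA is False
    if contains_zero:
        return 0.0             # any zero forces the product to zero
    return math.exp(log_sum) if positive else -math.exp(log_sum)

print(pandas_product_fractional([1.0, -2.0, 3.0, -4.0]))               # ~24.0: an even number of negatives keeps the sign positive
print(pandas_product_fractional([-1.5, None, -3.2], ignore_na=False))  # None
```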
@@ -155,6 +155,10 @@ private[sql] object PythonSQLUtils extends Logging {
Column(TimestampDiff(unit, start.expr, end.expr))
}

def pandasProduct(e: Column, ignoreNA: Boolean): Column = {
Column(PandasProduct(e.expr, ignoreNA).toAggregateExpression(false))
}

def pandasStddev(e: Column, ddof: Int): Column = {
Column(PandasStddev(e.expr, ddof).toAggregateExpression(false))
}