4 changes: 2 additions & 2 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2524,14 +2524,14 @@ test_that("describe() and summary() on a DataFrame", {

stats2 <- summary(df)
expect_equal(collect(stats2)[5, "summary"], "25%")
expect_equal(collect(stats2)[5, "age"], "30.0")
expect_equal(collect(stats2)[5, "age"], "30")

stats3 <- summary(df, "min", "max", "55.1%")

expect_equal(collect(stats3)[1, "summary"], "min")
expect_equal(collect(stats3)[2, "summary"], "max")
expect_equal(collect(stats3)[3, "summary"], "55.1%")
expect_equal(collect(stats3)[3, "age"], "30.0")
expect_equal(collect(stats3)[3, "age"], "30")

# SPARK-16425: SparkR summary() fails on column of type logical
df <- withColumn(df, "boolean", df$age == 30)
1 change: 1 addition & 0 deletions docs/sql-programming-guide.md
@@ -1553,6 +1553,7 @@ options.
## Upgrading From Spark SQL 2.2 to 2.3

- Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
- The `percentile_approx` function previously accepted numeric type input and returned results of double type. Now it also supports date type and timestamp type as input. The result type now matches the input type, which is more reasonable for percentiles (a sketch follows below).

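A minimal sketch of the behavior change described in the entry above, not part of the PR itself. All names (app name, view name, column name) are illustrative, and the commented schema lines paraphrase `printSchema` output:

```scala
// Sketch: percentile_approx now returns the input column's type instead of always double.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("percentile-approx-sketch").getOrCreate()
import spark.implicits._

(1 to 1000).toDF("col").createOrReplaceTempView("t")

val result = spark.sql("SELECT percentile_approx(col, 0.5) AS p50 FROM t")
result.printSchema()
// Spark 2.2:  |-- p50: double (nullable = true)
// Spark 2.3:  |-- p50: integer (nullable = true)   <- matches the INT input column
```
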
## Upgrading From Spark SQL 2.1 to 2.2

10 changes: 5 additions & 5 deletions python/pyspark/sql/dataframe.py
@@ -1037,9 +1037,9 @@ def summary(self, *statistics):
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| min| 2|Alice|
| 25%| 5.0| null|
| 50%| 5.0| null|
| 75%| 5.0| null|
| 25%| 5| null|
| 50%| 5| null|
| 75%| 5| null|
| max| 5| Bob|
+-------+------------------+-----+

@@ -1049,8 +1049,8 @@ def summary(self, *statistics):
+-------+---+-----+
| count| 2| 2|
| min| 2|Alice|
| 25%|5.0| null|
| 75%|5.0| null|
| 25%| 5| null|
| 75%| 5| null|
| max| 5| Bob|
+-------+---+-----+

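The Scala side of `summary()` renders the same way after this change; a rough sketch mirroring the doctest above (data and names are illustrative, and the stringified output is my reading of the R/Scala test changes elsewhere in this PR):

```scala
// Mirrors the PySpark doctest above on the Scala side.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("summary-sketch").getOrCreate()
import spark.implicits._

val df = Seq((2, "Alice"), (5, "Bob")).toDF("age", "name")
df.summary("count", "min", "25%", "75%", "max").show()
// The 25% and 75% rows of the integer `age` column now print as "5" rather than "5.0",
// because the approximate percentile keeps the input column's type before it is cast to string.
```
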
ApproximatePercentile.scala
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]

override def inputTypes: Seq[AbstractDataType] = {
Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
// Support NumericType, DateType and TimestampType since their internal types are all numeric,
// and can be easily cast to double for processing.
Seq(TypeCollection(NumericType, DateType, TimestampType),
TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
}

// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
@@ -123,7 +126,15 @@ case class ApproximatePercentile(
val value = child.eval(inputRow)
// Ignore empty rows, for example: percentile_approx(null)
if (value != null) {
buffer.add(value.asInstanceOf[Double])
// Convert the value to a double value
val doubleValue = child.dataType match {
case DateType => value.asInstanceOf[Int].toDouble
case TimestampType => value.asInstanceOf[Long].toDouble
case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
Review comment (Member): The same here.

case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
}
buffer.add(doubleValue)
}
buffer
}
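The branches above rely on Catalyst's internal encodings: a DateType value is stored as an Int count of days since the Unix epoch, and a TimestampType value as a Long count of microseconds, which is why a plain `toDouble` is sufficient. Below is a small sketch of that assumption, using the same `DateTimeUtils` helpers that the new ApproximatePercentileQuerySuite test further down relies on; the commented values show the intent, with exact rendering subject to time-zone handling:

```scala
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.util.DateTimeUtils

// DateType: Int days since 1970-01-01.
val day1: Date = DateTimeUtils.toJavaDate(1)            // one day after the epoch
val backToDays: Int = DateTimeUtils.fromJavaDate(day1)  // 1

// TimestampType: Long microseconds since the epoch.
val ts: Timestamp = DateTimeUtils.toJavaTimestamp(1000000L)   // one second after the epoch
val backToMicros: Long = DateTimeUtils.fromJavaTimestamp(ts)  // 1000000
```
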
@@ -134,7 +145,20 @@
}

override def eval(buffer: PercentileDigest): Any = {
val result = buffer.getPercentiles(percentages)
val doubleResult = buffer.getPercentiles(percentages)
val result = child.dataType match {
case DateType => doubleResult.map(_.toInt)
case TimestampType => doubleResult.map(_.toLong)
case ByteType => doubleResult.map(_.toByte)
case ShortType => doubleResult.map(_.toShort)
case IntegerType => doubleResult.map(_.toInt)
case LongType => doubleResult.map(_.toLong)
case FloatType => doubleResult.map(_.toFloat)
case DoubleType => doubleResult
case _: DecimalType => doubleResult.map(Decimal(_))
Review comment (Member): Add

        case other: DataType =>
          throw new UnsupportedOperationException(s"Unexpected data type $other")

case other: DataType =>
throw new UnsupportedOperationException(s"Unexpected data type $other")
}
if (result.length == 0) {
null
} else if (returnPercentileArray) {
@@ -155,8 +179,9 @@ case class ApproximatePercentile(
// Returns null for empty inputs
override def nullable: Boolean = true

// The result type is the same as the input type.
override def dataType: DataType = {
if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
}

override def prettyName: String = "percentile_approx"
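Putting the `eval` and `dataType` hunks above together: percentiles are computed on doubles inside the digest and then narrowed back to the column's type, so a DATE column now yields DATE results instead of being rejected at analysis time (previously the first argument had to be coercible to DoubleType). A rough end-to-end sketch, with illustrative names and commented output paraphrased:

```scala
// End-to-end sketch with a DATE column.
import java.sql.Date

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("percentile-date-sketch").getOrCreate()
import spark.implicits._

Seq(Date.valueOf("2017-01-01"), Date.valueOf("2017-06-15"), Date.valueOf("2017-12-31"))
  .toDF("d")
  .createOrReplaceTempView("dates")

val res = spark.sql("SELECT percentile_approx(d, 0.5) AS median_d FROM dates")
res.printSchema()
// |-- median_d: date (nullable = true)
// Before this change, DATE input failed percentile_approx's input type check.
res.show()
```
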
ApproximatePercentileSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
percentageExpression = percentageExpression,
accuracyExpression = Literal(100))

val result = wrongPercentage.checkInputDataTypes()
Review comment (@wzhfy, PR author, Sep 22, 2017): This is duplicated by line 274.

assert(
wrongPercentage.checkInputDataTypes() match {
case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
@@ -281,7 +280,6 @@

test("class ApproximatePercentile, automatically add type casting for parameters") {
val testRelation = LocalRelation('a.int)
val analyzer = SimpleAnalyzer

// Compatible accuracy types: Long type and decimal type
val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
@@ -299,7 +297,7 @@
analyzed match {
case Alias(agg: ApproximatePercentile, _) =>
assert(agg.resolved)
assert(agg.child.dataType == DoubleType)
assert(agg.child.dataType == IntegerType)
assert(agg.percentageExpression.dataType == DoubleType ||
agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
assert(agg.accuracyExpression.dataType == IntegerType)
ApproximatePercentileQuerySuite.scala
@@ -17,8 +17,11 @@

package org.apache.spark.sql

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.test.SharedSQLContext

/**
@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
}
}

test("percentile_approx, different column types") {
withTempView(table) {
val intSeq = 1 to 1000
val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
(new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
}
data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
checkAnswer(
spark.sql(
s"""SELECT
| percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
| percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
| percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
|FROM $table
""".stripMargin),
Row(
Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
.map(i => new java.math.BigDecimal(i)),
Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
)
}
}

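A note on the expected values in the test above, stated as an assumption rather than something the PR spells out: a `java.math.BigDecimal` column defaults to Spark's `DecimalType(38, 18)`, which is why the expected decimal strings carry 18 fractional digits, while the expected dates and timestamps are built from the days/microseconds encodings sketched earlier. A quick check, reusing the `spark` session from the earlier sketches:

```scala
// Assumption check: java.math.BigDecimal maps to decimal(38,18) by default.
import spark.implicits._

Seq((new java.math.BigDecimal(1), 0)).toDF("c", "i").printSchema()
// |-- c: decimal(38,18) (nullable = true)
// |-- i: integer (nullable = false)
```
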
test("percentile_approx, multiple records with the minimum value in a partition") {
withTempView(table) {
spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")
@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
val accuracies = Array(1, 10, 100, 1000, 10000)
val errors = accuracies.map { accuracy =>
val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
val approximatePercentile = df.collect().head.getDouble(0)
val approximatePercentile = df.collect().head.getInt(0)
val error = Math.abs(approximatePercentile - expectedPercentile)
error
}
DataFrameSuite.scala
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("mean", null, "33.0", "178.0"),
Row("stddev", null, "19.148542155126762", "11.547005383792516"),
Row("min", "Alice", "16", "164"),
Row("25%", null, "24.0", "176.0"),
Row("50%", null, "24.0", "176.0"),
Row("75%", null, "32.0", "180.0"),
Row("25%", null, "24", "176"),
Row("50%", null, "24", "176"),
Row("75%", null, "32", "180"),
Row("max", "David", "60", "192"))

val emptySummaryResult = Seq(