diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 85a7e0819cff..bc5c17a1c2ce 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -2524,14 +2524,14 @@ test_that("describe() and summary() on a DataFrame", {
   stats2 <- summary(df)
   expect_equal(collect(stats2)[5, "summary"], "25%")
-  expect_equal(collect(stats2)[5, "age"], "30.0")
+  expect_equal(collect(stats2)[5, "age"], "30")
 
   stats3 <- summary(df, "min", "max", "55.1%")
   expect_equal(collect(stats3)[1, "summary"], "min")
   expect_equal(collect(stats3)[2, "summary"], "max")
   expect_equal(collect(stats3)[3, "summary"], "55.1%")
-  expect_equal(collect(stats3)[3, "age"], "30.0")
+  expect_equal(collect(stats3)[3, "age"], "30")
 
   # SPARK-16425: SparkR summary() fails on column of type logical
   df <- withColumn(df, "boolean", df$age == 30)
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 5db60cc996e7..a095263bfa61 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1553,6 +1553,7 @@ options.
 ## Upgrading From Spark SQL 2.2 to 2.3
 
   - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+  - The `percentile_approx` function previously accepted numeric type input and output double type results. Now it supports date type, timestamp type and numeric types as input types. The result type is also changed to be the same as the input type, which is more reasonable for percentiles.
 
 ## Upgrading From Spark SQL 2.1 to 2.2
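A minimal spark-shell style sketch of the behavior change called out in the migration note above. The table, column names, and data are made up for illustration; before this patch both results would have been typed as double.

    import java.sql.Date
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("percentile-approx-demo").getOrCreate()
    import spark.implicits._

    Seq((1, Date.valueOf("2017-01-01")),
        (2, Date.valueOf("2017-06-01")),
        (3, Date.valueOf("2017-12-31"))).toDF("id", "d").createOrReplaceTempView("t")

    // With this change the result column keeps the input column's type:
    spark.sql("SELECT percentile_approx(id, 0.5) FROM t").printSchema()   // integer result
    spark.sql("SELECT percentile_approx(d, 0.5) FROM t").printSchema()    // date result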
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 88ac4134a0d0..bedb44417b97 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1037,9 +1037,9 @@ def summary(self, *statistics):
         |   mean|               3.5| null|
         | stddev|2.1213203435596424| null|
         |    min|                 2|Alice|
-        |    25%|               5.0| null|
-        |    50%|               5.0| null|
-        |    75%|               5.0| null|
+        |    25%|                 5| null|
+        |    50%|                 5| null|
+        |    75%|                 5| null|
         |    max|                 5|  Bob|
         +-------+------------------+-----+
 
@@ -1049,8 +1049,8 @@ def summary(self, *statistics):
         +-------+---+-----+
         |  count|  2|    2|
         |    min|  2|Alice|
-        |    25%|5.0| null|
-        |    75%|5.0| null|
+        |    25%|  5| null|
+        |    75%|  5| null|
         |    max|  5|  Bob|
         +-------+---+-----+
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 896c009b3297..7facb9dad9a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -85,7 +85,10 @@ case class ApproximatePercentile(
   private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
+    // Support NumericType, DateType and TimestampType since their internal types are all numeric,
+    // and can be easily cast to double for processing.
+    Seq(TypeCollection(NumericType, DateType, TimestampType),
+      TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
   }
 
   // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
@@ -123,7 +126,15 @@ case class ApproximatePercentile(
     val value = child.eval(inputRow)
     // Ignore empty rows, for example: percentile_approx(null)
     if (value != null) {
-      buffer.add(value.asInstanceOf[Double])
+      // Convert the value to a double value
+      val doubleValue = child.dataType match {
+        case DateType => value.asInstanceOf[Int].toDouble
+        case TimestampType => value.asInstanceOf[Long].toDouble
+        case n: NumericType => n.numeric.toDouble(value.asInstanceOf[n.InternalType])
+        case other: DataType =>
+          throw new UnsupportedOperationException(s"Unexpected data type $other")
+      }
+      buffer.add(doubleValue)
     }
     buffer
   }
@@ -134,7 +145,20 @@ case class ApproximatePercentile(
   }
 
   override def eval(buffer: PercentileDigest): Any = {
-    val result = buffer.getPercentiles(percentages)
+    val doubleResult = buffer.getPercentiles(percentages)
+    val result = child.dataType match {
+      case DateType => doubleResult.map(_.toInt)
+      case TimestampType => doubleResult.map(_.toLong)
+      case ByteType => doubleResult.map(_.toByte)
+      case ShortType => doubleResult.map(_.toShort)
+      case IntegerType => doubleResult.map(_.toInt)
+      case LongType => doubleResult.map(_.toLong)
+      case FloatType => doubleResult.map(_.toFloat)
+      case DoubleType => doubleResult
+      case _: DecimalType => doubleResult.map(Decimal(_))
+      case other: DataType =>
+        throw new UnsupportedOperationException(s"Unexpected data type $other")
+    }
     if (result.length == 0) {
       null
     } else if (returnPercentileArray) {
@@ -155,8 +179,9 @@ case class ApproximatePercentile(
   // Returns null for empty inputs
   override def nullable: Boolean = true
 
+  // The result type is the same as the input type.
   override def dataType: DataType = {
-    if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
+    if (returnPercentileArray) ArrayType(child.dataType, false) else child.dataType
   }
 
   override def prettyName: String = "percentile_approx"
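The two hunks above widen every supported input to a double before it reaches the PercentileDigest and narrow the approximate result back to the column's own type in eval(). A standalone sketch of that round trip follows; it is illustrative only and does not use Spark's internal classes (the digest is replaced by a plain sort, and the type tags only mirror the Catalyst types handled in the patch).

    // Widen to Double on the way in, narrow back to the input type on the way out.
    sealed trait ColType
    case object IntCol    extends ColType  // IntegerType, and DateType's internal Int (days)
    case object LongCol   extends ColType  // LongType, and TimestampType's internal Long (microseconds)
    case object DoubleCol extends ColType

    def toDouble(v: Any, t: ColType): Double = t match {
      case IntCol    => v.asInstanceOf[Int].toDouble
      case LongCol   => v.asInstanceOf[Long].toDouble
      case DoubleCol => v.asInstanceOf[Double]
    }

    def fromDouble(d: Double, t: ColType): Any = t match {
      case IntCol    => d.toInt
      case LongCol   => d.toLong
      case DoubleCol => d
    }

    // A crude exact percentile stands in for the approximate digest here.
    def percentile(values: Seq[Any], t: ColType, p: Double): Any = {
      val sorted = values.map(toDouble(_, t)).sorted
      fromDouble(sorted(((sorted.length - 1) * p).toInt), t)
    }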
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index fcb370ae8460..84b3cc79cef5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericInternalRow, Literal}
@@ -270,7 +270,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
         percentageExpression = percentageExpression,
         accuracyExpression = Literal(100))
 
-      val result = wrongPercentage.checkInputDataTypes()
       assert(
         wrongPercentage.checkInputDataTypes() match {
           case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
@@ -281,7 +280,6 @@ class ApproximatePercentileSuite extends SparkFunSuite {
 
   test("class ApproximatePercentile, automatically add type casting for parameters") {
     val testRelation = LocalRelation('a.int)
-    val analyzer = SimpleAnalyzer
 
     // Compatible accuracy types: Long type and decimal type
    val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
@@ -299,7 +297,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       analyzed match {
         case Alias(agg: ApproximatePercentile, _) =>
           assert(agg.resolved)
-          assert(agg.child.dataType == DoubleType)
+          assert(agg.child.dataType == IntegerType)
           assert(agg.percentageExpression.dataType == DoubleType ||
             agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
           assert(agg.accuracyExpression.dataType == IntegerType)
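The suite changes above exercise two things at the Catalyst level: out-of-range percentages are still rejected, and the input column now keeps its own type after analysis instead of being cast to double. A rough end-to-end sketch of the same behavior, reusing the `spark` session and temp view `t` from the first sketch (the quoted message fragment is the one the test matches on):

    import org.apache.spark.sql.AnalysisException

    // Percentages outside [0.0, 1.0] are rejected during analysis.
    try {
      spark.sql("SELECT percentile_approx(id, 1.5) FROM t").collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)  // mentions "must be between 0.0 and 1.0"
    }

    // The integer input column is no longer cast to double during analysis; its own type is kept,
    // which is what the updated assertion (IntegerType instead of DoubleType) checks.
    spark.sql("SELECT percentile_approx(id, 0.5, 1000) FROM t").printSchema()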
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index 62a75343a094..1aea33766407 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
@@ -67,6 +70,30 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("percentile_approx, different column types") {
+    withTempView(table) {
+      val intSeq = 1 to 1000
+      val data: Seq[(java.math.BigDecimal, Date, Timestamp)] = intSeq.map { i =>
+        (new java.math.BigDecimal(i), DateTimeUtils.toJavaDate(i), DateTimeUtils.toJavaTimestamp(i))
+      }
+      data.toDF("cdecimal", "cdate", "ctimestamp").createOrReplaceTempView(table)
+      checkAnswer(
+        spark.sql(
+          s"""SELECT
+             |  percentile_approx(cdecimal, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(cdate, array(0.25, 0.5, 0.75D)),
+             |  percentile_approx(ctimestamp, array(0.25, 0.5, 0.75D))
+             |FROM $table
+           """.stripMargin),
+        Row(
+          Seq("250.000000000000000000", "500.000000000000000000", "750.000000000000000000")
+            .map(i => new java.math.BigDecimal(i)),
+          Seq(250, 500, 750).map(DateTimeUtils.toJavaDate),
+          Seq(250, 500, 750).map(i => DateTimeUtils.toJavaTimestamp(i.toLong)))
+      )
+    }
+  }
+
   test("percentile_approx, multiple records with the minimum value in a partition") {
     withTempView(table) {
       spark.sparkContext.makeRDD(Seq(1, 1, 2, 1, 1, 3, 1, 1, 4, 1, 1, 5), 4).toDF("col")
@@ -88,7 +115,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
       val accuracies = Array(1, 10, 100, 1000, 10000)
       val errors = accuracies.map { accuracy =>
         val df = spark.sql(s"SELECT percentile_approx(col, 0.25, $accuracy) FROM $table")
-        val approximatePercentile = df.collect().head.getDouble(0)
+        val approximatePercentile = df.collect().head.getInt(0)
         val error = Math.abs(approximatePercentile - expectedPercentile)
         error
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 13341645e8ff..6178661cf7b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -803,9 +803,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row("mean", null, "33.0", "178.0"),
       Row("stddev", null, "19.148542155126762", "11.547005383792516"),
       Row("min", "Alice", "16", "164"),
-      Row("25%", null, "24.0", "176.0"),
-      Row("50%", null, "24.0", "176.0"),
-      Row("75%", null, "32.0", "180.0"),
+      Row("25%", null, "24", "176"),
+      Row("50%", null, "24", "176"),
+      Row("75%", null, "32", "180"),
       Row("max", "David", "60", "192"))
 
     val emptySummaryResult = Seq(
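For completeness, a small sketch of what the updated summary() expectations above correspond to at the API level. The data is made up and the exact rows are only indicative; it assumes a `SparkSession` named `spark` with `spark.implicits._` in scope, as in the first sketch.

    val people = Seq(
      ("Alice", 16, 164),
      ("David", 60, 192),
      ("Nina", 25, 176)).toDF("name", "age", "height")

    // Percentile rows are now rendered in the column's own type (no trailing ".0"
    // for integer columns), which is what the changed string expectations encode.
    people.summary("min", "25%", "50%", "75%", "max").show()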