sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
     frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
   override def update(
-      buffer: OpenHashMap[Number, Long],
-      input: InternalRow): OpenHashMap[Number, Long] = {
-    val key = child.eval(input).asInstanceOf[Number]
+      buffer: OpenHashMap[AnyRef, Long],
+      input: InternalRow): OpenHashMap[AnyRef, Long] = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
     val frqValue = frequencyExpression.eval(input)
 
     // Null values are ignored in counts map.
@@ -155,32 +160,32 @@ case class Percentile(
   }
 
   override def merge(
-      buffer: OpenHashMap[Number, Long],
-      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+      buffer: OpenHashMap[AnyRef, Long],
+      other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
     buffer
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
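Note: the scanLeft in this hunk turns sorted per-key counts into cumulative counts. A minimal standalone sketch of the same step, with illustrative values that are not from the patch:

```scala
// Sorted (key, count) pairs, as they come out of the aggregation buffer.
val sortedCounts = Seq((1, 2L), (5, 1L), (9, 3L))

// Seed scanLeft with (firstKey, 0); each step carries the running total.
// Dropping the seed with .tail pairs every key with its cumulative count.
val accumulated = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
  case ((_, total), (key, count)) => (key, total + count)
}.tail
// accumulated == Seq((1, 2L), (5, 3L), (9, 6L))

// Six values in total, so valid positions run from 0 to 5.
val maxPosition = accumulated.last._2 - 1  // 5
```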
@@ -200,7 +205,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
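Note: a quick worked instance of the interpolation above, under assumed inputs (cumulative counts over the two values {1, 2}, asking for the median):

```scala
// maxPosition = 2 - 1 = 1, so the median sits at position 1 * 0.5 = 0.5.
val position = 0.5
val (lower, higher) = (0L, 1L)          // floor and ceil of position
val (lowerKey, higherKey) = (1.0, 2.0)  // keys at those cumulative positions
val median = (higher - position) * lowerKey + (position - lower) * higherKey
// median == 1.5, the expected median of {1, 2}
```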
@@ -238,7 +242,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()
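Note: the reason Percentile.scala needs toDoubleValue at all is that Spark's Decimal does not extend java.lang.Number, so the old asInstanceOf[Number] cast blew up on decimal columns (the ClassCastException in SPARK-19691). A REPL-friendly sketch with a stand-in decimal class (FakeDecimal is hypothetical, only to keep the example self-contained):

```scala
// Stand-in for org.apache.spark.sql.types.Decimal, which is not a Number.
final class FakeDecimal(val value: BigDecimal) {
  def toDouble: Double = value.toDouble
}

// Mirrors the toDoubleValue added in this patch: pattern-match instead of cast.
def toDoubleValue(d: Any): Double = d match {
  case d: FakeDecimal => d.toDouble
  case n: Number => n.doubleValue
}

val decimalKey: AnyRef = new FakeDecimal(BigDecimal(1))
// decimalKey.asInstanceOf[Number]  // the old code path: ClassCastException
println(toDoubleValue(decimalKey))  // 1.0
println(toDoubleValue(Int.box(3)))  // 3.0
```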

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)

[Review thread on the `new Integer(key)` line]

Contributor: Do we need to explicitly type this? I thought Scala boxed automatically.

Member (author): If this explicit boxing is removed, compilation fails with the error below:

    [error] /Users/maropu/IdeaProjects/spark/spark-master/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala:46: the result type of an implicit conversion must be more specific than AnyRef

     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
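Note: the boxing issue from the thread above reproduces with any AnyRef-keyed collection; a minimal sketch, with scala.collection.mutable.HashMap standing in for OpenHashMap:

```scala
import scala.collection.mutable

val counts = mutable.HashMap.empty[AnyRef, Long]
val key: Int = 42

// counts.update(key, 1L)
// does not compile: scala.Int is not an AnyRef, and scalac refuses an
// implicit conversion whose result type is AnyRef itself -- the exact
// "must be more specific than AnyRef" error quoted above.

// Explicit boxing to java.lang.Integer works; the patch spells it
// `new Integer(key)`, and Int.box does the same thing.
counts.update(Int.box(key), 1L)
```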
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(childExpression, percentageExpression)
 
     // Test with rows without frequency
-    val rows = (1 to count).map( x => Seq(x))
-    runTest( agg, rows, expectedPercentiles)
+    val rows = (1 to count).map(x => Seq(x))
+    runTest(agg, rows, expectedPercentiles)
 
     // Test with row with frequency. Second and third columns are frequency in Int and Long
     val countForFrequencyTest = 1000
-    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+    val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
     val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
 
     val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
     val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
-    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
     val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
-    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     // Run test with Flatten data
-    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
-      (1 to current).map( y => current )).map( Seq(_))
+    val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+      (1 to current).map(y => current )).map(Seq(_))
     runTest(agg, flattenRows, expectedPercentilesWithFrquency)
   }
 
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
     }
 
     val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-    for ( dataType <- validDataTypes;
+    for (dataType <- validDataTypes;
       frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
       StringType, DateType, TimestampType,
       CalendarIntervalType, NullType)
 
-    for( dataType <- invalidDataTypes;
+    for(dataType <- invalidDataTypes;
       frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
         s"'`a`' is of ${dataType.simpleString} type."))
     }
 
-    for( dataType <- validDataTypes;
+    for(dataType <- validDataTypes;
       frequencyType <- invalidFrequencyDataTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
       agg.update(buffer, InternalRow(1, -5))
       agg.eval(buffer)
     }
-    assert( caught.getMessage.startsWith("Negative values found in "))
+    assert(caught.getMessage.startsWith("Negative values found in "))
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
     checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }
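Note: for reference, a spark-shell sketch of the regression this test guards against (assumes a running SparkSession named spark):

```scala
// SPARK-19691 reproduction: percentile over a DECIMAL column.
val df = spark.range(1)
  .selectExpr("CAST(id AS DECIMAL) AS x")
  .selectExpr("percentile(x, 0.5)")

// Before this patch: java.lang.ClassCastException
//   (org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number).
// With this patch: a single row containing 0.0.
df.show()
```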