diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index 096d1b35a8620..d4421ca20a9bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -22,9 +22,10 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 
 /**
  * This function counts the approximate number of distinct values (ndv) in
@@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with ExpectsInputTypes {
-
-  def this(child: Expression, endpointsExpression: Expression) = {
-    this(
-      child = child,
-      endpointsExpression = endpointsExpression,
-      relativeSD = 0.05,
-      mutableAggBufferOffset = 0,
-      inputAggBufferOffset = 0)
-  }
+  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals(
   private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length
 
   /** Allocate enough words to store all registers. */
-  override lazy val aggBufferAttributes: Seq[AttributeReference] = {
-    Seq.tabulate(totalNumWords) { i =>
-      AttributeReference(s"MS[$i]", LongType)()
-    }
+  override def createAggregationBuffer(): Array[Long] = {
+    Array.fill(totalNumWords)(0L)
   }
 
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
-  /** Fill all words with zeros. */
-  override def initialize(buffer: InternalRow): Unit = {
-    var word = 0
-    while (word < totalNumWords) {
-      buffer.setLong(mutableAggBufferOffset + word, 0)
-      word += 1
-    }
-  }
-
-  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+  override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
     val value = child.eval(input)
     // Ignore empty rows
     if (value != null) {
@@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals(
       // endpoints are sorted into ascending order already
       if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
         // ignore if the value is out of the whole range
-        return
+        return buffer
      }
 
       val hllppIndex = findHllppIndex(doubleValue)
-      val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
-      hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
+      val offset = hllppIndex * numWordsPerHllpp
+      hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType)
     }
+    buffer
   }
 
   // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
@@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals(
     }
   }
 
-  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+  override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = {
     for (i <- hllppArray.indices) {
       hllppArray(i).merge(
-        buffer1 = buffer1,
-        buffer2 = buffer2,
-        offset1 = mutableAggBufferOffset + i * numWordsPerHllpp,
-        offset2 = inputAggBufferOffset + i * numWordsPerHllpp)
+        buffer1 = LongArrayInternalRow(buffer1),
+        buffer2 = LongArrayInternalRow(buffer2),
+        offset1 = i * numWordsPerHllpp,
+        offset2 = i * numWordsPerHllpp)
     }
+    buffer1
   }
 
-  override def eval(buffer: InternalRow): Any = {
+  override def eval(buffer: Array[Long]): Any = {
     val ndvArray = hllppResults(buffer)
     // If the endpoints contains multiple elements with the same value,
     // we set ndv=1 for intervals between these elements.
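// [Editor's note, not part of the patch] The body of findHllppIndex is elided between the hunks
// above. As context for the offset arithmetic in this change, here is a minimal standalone
// sketch of that interval lookup. It assumes the caller has already discarded values outside
// [endpoints.head, endpoints.last], as update() does, and that a value equal to a (possibly
// duplicated) endpoint belongs to the interval ending at its first occurrence, matching the
// checkHllppIndex expectations in the test suite below.
def findIntervalIndexSketch(endpoints: Array[Double], value: Double): Int = {
  // binarySearch returns a matching index, or -(insertionPoint) - 1 when the value is absent.
  val pos = java.util.Arrays.binarySearch(endpoints, value)
  if (pos >= 0) {
    // The value sits on an endpoint: walk left over duplicates so equal endpoints
    // resolve to a single interval index.
    var i = pos
    while (i > 0 && endpoints(i - 1) == value) i -= 1
    math.max(i - 1, 0)
  } else {
    // The value lies strictly between two endpoints: insertionPoint - 1 is the interval index.
    -pos - 2
  }
}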
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals(
     new GenericArrayData(ndvArray)
   }
 
-  def hllppResults(buffer: InternalRow): Array[Long] = {
+  def hllppResults(buffer: Array[Long]): Array[Long] = {
     val ndvArray = new Array[Long](hllppArray.length)
     for (i <- ndvArray.indices) {
-      ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp)
+      ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp)
     }
     ndvArray
   }
 
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+  }
 
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(inputAggBufferOffset = newInputAggBufferOffset)
+  }
 
   override def children: Seq[Expression] = Seq(child, endpointsExpression)
 
@@ -239,4 +219,35 @@ case class ApproxCountDistinctForIntervals(
   override def dataType: DataType = ArrayType(LongType)
 
   override def prettyName: String = "approx_count_distinct_for_intervals"
+
+  // Pack the Array[Long] buffer into a byte array, 8 bytes per word, for shipping
+  // between the partial and final aggregation stages.
+  override def serialize(obj: Array[Long]): Array[Byte] = {
+    val byteArray = new Array[Byte](obj.length * 8)
+    var i = 0
+    while (i < obj.length) {
+      Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i))
+      i += 1
+    }
+    byteArray
+  }
+
+  override def deserialize(bytes: Array[Byte]): Array[Long] = {
+    assert(bytes.length % 8 == 0)
+    val length = bytes.length / 8
+    val longArray = new Array[Long](length)
+    var i = 0
+    while (i < length) {
+      longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8)
+      i += 1
+    }
+    longArray
+  }
+
+  // Wraps the Array[Long] buffer as an InternalRow so HyperLogLogPlusPlusHelper,
+  // which reads and writes longs through the InternalRow interface, can operate on it.
+  private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow {
+    override def getLong(offset: Int): Long = array(offset)
+    override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d6c38c3608bf8..73f18d4feef3f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
       MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
     wrongColumnTypes.foreach { dataType =>
-      val wrongColumn = new ApproxCountDistinctForIntervals(
+      val wrongColumn = ApproxCountDistinctForIntervals(
         AttributeReference("a", dataType)(),
         endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
       assert(
@@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       })
    }
 
-    var wrongEndpoints = new ApproxCountDistinctForIntervals(
+    var wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = Literal(0.5d))
     assert(
       wrongEndpoints.checkInputDataTypes() match {
@@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
        case TypeCheckFailure(msg) if msg.contains("requires array type") => true
        case _ => false
      })
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The endpoints provided must be constant literals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
@@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
   private def createEstimator[T](
       endpoints: Array[T],
       dt: DataType,
-      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
+      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = {
     val input = new SpecificInternalRow(Seq(dt))
     val aggFunc = ApproxCountDistinctForIntervals(
       BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
-    val buffer = createBuffer(aggFunc)
-    (aggFunc, input, buffer)
-  }
-
-  private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
-    val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
-    aggFunc.initialize(buffer)
-    buffer
+    (aggFunc, input, aggFunc.createAggregationBuffer())
   }
 
   test("merging ApproxCountDistinctForIntervals instances") {
     val (aggFunc, input, buffer1a) =
       createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType)
-    val buffer1b = createBuffer(aggFunc)
-    val buffer2 = createBuffer(aggFunc)
+    val buffer1b = aggFunc.createAggregationBuffer()
+    val buffer2 = aggFunc.createAggregationBuffer()
 
     // Add the lower half to `buffer1a`.
     var i = 0
@@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     }
 
     // Check if the buffers are equal.
-    assert(buffer2 == buffer1a, "Buffers should be equal")
+    assert(buffer2.sameElements(buffer1a), "Buffers should be equal")
   }
 
   test("test findHllppIndex(value) for values in the range") {
@@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2)
   }
 
+  test("round trip serialization") {
+    val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType)
+    val longArray = (1L to 100L).toArray
+    val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray))
+    assert(roundtrip.sameElements(longArray))
+  }
+
   test("basic operations: update, merge, eval...") {
     val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0)
     val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
new file mode 100644
index 0000000000000..c7d86bc955d67
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
+  // histogram usually contains hundreds of buckets. So we need to test
+  // ApproxCountDistinctForIntervals with a large number of endpoints
+  // (the number of endpoints == the number of buckets + 1).
+  test("test ApproxCountDistinctForIntervals with a large number of endpoints") {
+    val table = "approx_count_distinct_for_intervals_tbl"
+    withTempView(table) {
+      (1 to 100000).toDF("col").createOrReplaceTempView(table)
+      // quantiles at 0, 0.001, 0.002 ... 0.999, 1
+      val endpoints = (0 to 1000).map(_ * 100000 / 1000)
+
+      // Since approx_count_distinct_for_intervals is not a public function, here we do
+      // the computation by constructing a logical plan.
+      val relation = spark.table(table).logicalPlan
+      val attr = relation.output.find(_.name == "col").get
+      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
+      val aggExpr = aggFunc.toAggregateExpression()
+      val namedExpr = Alias(aggExpr, aggExpr.toString)()
+      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
+        .executedPlan.executeTake(1).head
+      val ndvArray = ndvsRow.getArray(0).toLongArray()
+      assert(endpoints.length == ndvArray.length + 1)
+
+      // Each bucket has 100 distinct values.
+      val expectedNdv = 100
+      for (i <- ndvArray.indices) {
+        val ndv = ndvArray(i)
+        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
+        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
+      }
+    }
+  }
+}
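// [Editor's note, not part of the patch] A self-contained sketch of the serialize/deserialize
// round trip added above, using java.nio.ByteBuffer in place of Spark's internal Platform
// helper (ByteBuffer defaults to big-endian while Platform uses the native byte order; either
// is fine as long as both directions agree). The object and method names are illustrative only.
import java.nio.ByteBuffer

object LongArrayCodecSketch {
  // Pack each 64-bit word of the aggregation buffer into 8 consecutive bytes.
  def packWords(words: Array[Long]): Array[Byte] = {
    val buf = ByteBuffer.allocate(words.length * 8)
    words.foreach(w => buf.putLong(w))
    buf.array()
  }

  // Reverse the packing; the length check mirrors the assert in deserialize().
  def unpackWords(bytes: Array[Byte]): Array[Long] = {
    require(bytes.length % 8 == 0, "serialized buffer must be a whole number of 8-byte words")
    val buf = ByteBuffer.wrap(bytes)
    Array.fill(bytes.length / 8)(buf.getLong())
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the "round trip serialization" test added to the suite.
    val words = (1L to 100L).toArray
    assert(unpackWords(packWords(words)).sameElements(words))
  }
}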