ApproxCountDistinctForIntervals.scala
@@ -22,9 +22,10 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform

/**
* This function counts the approximate number of distinct values (ndv) in
@@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals(
relativeSD: Double = 0.05,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with ExpectsInputTypes {

def this(child: Expression, endpointsExpression: Expression) = {
this(
child = child,
endpointsExpression = endpointsExpression,
relativeSD = 0.05,
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)
}
extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {

def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
this(
@@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals(
private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length

/** Allocate enough words to store all registers. */
override lazy val aggBufferAttributes: Seq[AttributeReference] = {
Seq.tabulate(totalNumWords) { i =>
AttributeReference(s"MS[$i]", LongType)()
}
override def createAggregationBuffer(): Array[Long] = {
Array.fill(totalNumWords)(0L)
}

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

/** Fill all words with zeros. */
override def initialize(buffer: InternalRow): Unit = {
var word = 0
while (word < totalNumWords) {
buffer.setLong(mutableAggBufferOffset + word, 0)
word += 1
}
}

override def update(buffer: InternalRow, input: InternalRow): Unit = {
override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
val value = child.eval(input)
// Ignore empty rows
if (value != null) {
@@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals(
// endpoints are sorted into ascending order already
if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
// ignore if the value is out of the whole range
return
return buffer
}

val hllppIndex = findHllppIndex(doubleValue)
val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
val offset = hllppIndex * numWordsPerHllpp
hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType)
}
buffer
}

// Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
@@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals(
}
}

override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = {
for (i <- hllppArray.indices) {
hllppArray(i).merge(
buffer1 = buffer1,
buffer2 = buffer2,
offset1 = mutableAggBufferOffset + i * numWordsPerHllpp,
offset2 = inputAggBufferOffset + i * numWordsPerHllpp)
buffer1 = LongArrayInternalRow(buffer1),
buffer2 = LongArrayInternalRow(buffer2),
offset1 = i * numWordsPerHllpp,
offset2 = i * numWordsPerHllpp)
}
buffer1
}

override def eval(buffer: InternalRow): Any = {
override def eval(buffer: Array[Long]): Any = {
val ndvArray = hllppResults(buffer)
// If the endpoints contains multiple elements with the same value,
// we set ndv=1 for intervals between these elements.
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals(
new GenericArrayData(ndvArray)
}

def hllppResults(buffer: InternalRow): Array[Long] = {
def hllppResults(buffer: Array[Long]): Array[Long] = {
val ndvArray = new Array[Long](hllppArray.length)
for (i <- ndvArray.indices) {
ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp)
ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp)
}
ndvArray
}

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int)
: ApproxCountDistinctForIntervals = {
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
}

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int)
: ApproxCountDistinctForIntervals = {
copy(inputAggBufferOffset = newInputAggBufferOffset)
}

override def children: Seq[Expression] = Seq(child, endpointsExpression)

@@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals(
override def dataType: DataType = ArrayType(LongType)

override def prettyName: String = "approx_count_distinct_for_intervals"

override def serialize(obj: Array[Long]): Array[Byte] = {
val byteArray = new Array[Byte](obj.length * 8)
var i = 0
while (i < obj.length) {
Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i))
i += 1
}
byteArray
}

override def deserialize(bytes: Array[Byte]): Array[Long] = {
assert(bytes.length % 8 == 0)
val length = bytes.length / 8
val longArray = new Array[Long](length)
var i = 0
while (i < length) {
longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8)
i += 1
}
longArray
}

private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow {
override def getLong(offset: Int): Long = array(offset)
override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
}
}
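
The serialize and deserialize methods added above flatten the Array[Long] aggregation buffer into bytes with Platform.putLong and Platform.getLong. As a rough cross-check of that layout, a hypothetical equivalent using java.nio.ByteBuffer might look like the sketch below. This is an illustration, not part of the patch: Platform writes in native byte order while ByteBuffer defaults to big-endian, so each codec round-trips correctly on its own but the two are not byte-for-byte compatible.

import java.nio.ByteBuffer

// Hypothetical sketch only; not part of this patch.
object LongArrayCodecSketch {
  def serialize(longs: Array[Long]): Array[Byte] = {
    val buf = ByteBuffer.allocate(longs.length * 8)
    longs.foreach(buf.putLong)          // write each buffer word in order
    buf.array()
  }

  def deserialize(bytes: Array[Byte]): Array[Long] = {
    require(bytes.length % 8 == 0)      // must be a whole number of longs
    val buf = ByteBuffer.wrap(bytes)
    Array.fill(bytes.length / 8)(buf.getLong())
  }
}
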
ApproxCountDistinctForIntervalsSuite.scala
@@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
wrongColumnTypes.foreach { dataType =>
val wrongColumn = new ApproxCountDistinctForIntervals(
val wrongColumn = ApproxCountDistinctForIntervals(
AttributeReference("a", dataType)(),
endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
assert(
@@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
})
}

var wrongEndpoints = new ApproxCountDistinctForIntervals(
var wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = Literal(0.5d))
assert(
@@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
case _ => false
})

wrongEndpoints = new ApproxCountDistinctForIntervals(
wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The endpoints provided must be constant literals"))

wrongEndpoints = new ApproxCountDistinctForIntervals(
wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))

wrongEndpoints = new ApproxCountDistinctForIntervals(
wrongEndpoints = ApproxCountDistinctForIntervals(
AttributeReference("a", DoubleType)(),
endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
assert(wrongEndpoints.checkInputDataTypes() ==
@@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
private def createEstimator[T](
endpoints: Array[T],
dt: DataType,
rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = {
val input = new SpecificInternalRow(Seq(dt))
val aggFunc = ApproxCountDistinctForIntervals(
BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
val buffer = createBuffer(aggFunc)
(aggFunc, input, buffer)
}

private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
aggFunc.initialize(buffer)
buffer
(aggFunc, input, aggFunc.createAggregationBuffer())
}

test("merging ApproxCountDistinctForIntervals instances") {
val (aggFunc, input, buffer1a) =
createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType)
val buffer1b = createBuffer(aggFunc)
val buffer2 = createBuffer(aggFunc)
val buffer1b = aggFunc.createAggregationBuffer()
val buffer2 = aggFunc.createAggregationBuffer()

// Add the lower half to `buffer1a`.
var i = 0
@@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
}

// Check if the buffers are equal.
assert(buffer2 == buffer1a, "Buffers should be equal")
assert(buffer2.sameElements(buffer1a), "Buffers should be equal")
}

test("test findHllppIndex(value) for values in the range") {
@@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2)
}
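
  // Illustration (not part of this patch): the lookup rule exercised by the
  // checkHllppIndex calls above can be sketched as a linear scan over the
  // sorted endpoints. The real findHllppIndex searches the endpoints array
  // directly; this only spells out the intended semantics for in-range values,
  // including duplicate endpoints (value 7 with endpoints 1, 3, 5, 7, 7, 9
  // maps to interval index 2).
  private def intervalIndexSketch(endpoints: Array[Double], value: Double): Int = {
    if (value <= endpoints.head) 0             // left edge belongs to the first interval
    else endpoints.indexWhere(_ >= value) - 1  // first endpoint >= value closes the interval
  }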

test("round trip serialization") {
val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType)
val longArray = (1L to 100L).toArray
val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray))
assert(roundtrip.sameElements(longArray))
}

test("basic operations: update, merge, eval...") {
val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0)
val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33)
ApproxCountDistinctForIntervalsQuerySuite.scala
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._

// ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
// histogram usually contains hundreds of buckets. So we need to test
// ApproxCountDistinctForIntervals with large number of endpoints
// (the number of endpoints == the number of buckets + 1).
test("test ApproxCountDistinctForIntervals with large number of endpoints") {
val table = "approx_count_distinct_for_intervals_tbl"
withTable(table) {
(1 to 100000).toDF("col").createOrReplaceTempView(table)
// percentiles of 0, 0.001, 0.002 ... 0.999, 1
val endpoints = (0 to 1000).map(_ * 100000 / 1000)

// Since approx_count_distinct_for_intervals is not a public function, here we do
// the computation by constructing logical plan.
val relation = spark.table(table).logicalPlan
val attr = relation.output.find(_.name == "col").get
val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
val aggExpr = aggFunc.toAggregateExpression()
val namedExpr = Alias(aggExpr, aggExpr.toString)()
val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
.executedPlan.executeTake(1).head
val ndvArray = ndvsRow.getArray(0).toLongArray()
assert(endpoints.length == ndvArray.length + 1)

// Each bucket has 100 distinct values.
val expectedNdv = 100
for (i <- ndvArray.indices) {
val ndv = ndvArray(i)
val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
}
}
}
}
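
The query suite above drives the aggregate with roughly a thousand endpoints, the shape used for equi-height histogram generation: n + 1 endpoints define n intervals, and the aggregate returns one approximate ndv per interval (the suite asserts endpoints.length == ndvArray.length + 1). Below is a small, hypothetical sketch of pairing that output back with its endpoints; BucketSketch and toBuckets are illustrative names, not Spark APIs.

// Hypothetical illustration; not part of this patch.
case class BucketSketch(lo: Double, hi: Double, approxNdv: Long)

def toBuckets(endpoints: Array[Double], ndvs: Array[Long]): Seq[BucketSketch] = {
  require(endpoints.length == ndvs.length + 1, "need one more endpoint than intervals")
  endpoints.sliding(2).toSeq.zip(ndvs).map {
    case (Array(lo, hi), ndv) => BucketSketch(lo, hi, ndv)
  }
}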