
Commit f6290ae

Zhenhua Wang authored and cloud-fan committed
[SPARK-22285][SQL] Change implementation of ApproxCountDistinctForIntervals to TypedImperativeAggregate
## What changes were proposed in this pull request?

The current implementation of `ApproxCountDistinctForIntervals` is an `ImperativeAggregate`. The number of `aggBufferAttributes` equals the total number of words in the hllppHelper array, and each hllppHelper uses 52 words under the default relativeSD. Since this aggregate function is used in equi-height histogram generation, and the number of buckets in a histogram is usually in the hundreds, the number of `aggBufferAttributes` can easily reach tens of thousands or more. This leads to a huge method in codegen and causes the error:

```
org.codehaus.janino.JaninoRuntimeException: Code of method "apply(Lorg/apache/spark/sql/catalyst/InternalRow;)Lorg/apache/spark/sql/catalyst/expressions/UnsafeRow;" of class "org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection" grows beyond 64 KB.
```

Besides, huge generated methods also cause a performance regression. In this PR, we change the implementation to `TypedImperativeAggregate`. After the fix, `ApproxCountDistinctForIntervals` can handle thousands of endpoints without throwing a codegen error, and performance in a test case with 500 endpoints improves from 20 sec to 2 sec.

## How was this patch tested?

Tested by an added test case and existing tests.

Author: Zhenhua Wang <wangzhenhua@huawei.com>

Closes #19506 from wzhfy/change_forIntervals_typedAgg.
1 parent 5a5b6b7 commit f6290ae
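
For readers unfamiliar with the interface, here is a minimal, hypothetical `TypedImperativeAggregate` (a toy `CountNonNull` aggregate, not part of this patch) sketching the contract the PR migrates to. The key point: the aggregation buffer is a single opaque object held in one binary buffer slot, so `aggBufferAttributes` no longer grows with the buffer size.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical toy aggregate illustrating the TypedImperativeAggregate
// contract. The buffer is one opaque Scala object (here Array[Long]) stored
// in a single binary slot, which is what keeps the generated projection
// method under Janino's 64 KB limit.
case class CountNonNull(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[Array[Long]] {

  override def createAggregationBuffer(): Array[Long] = Array(0L)

  override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
    if (child.eval(input) != null) buffer(0) += 1
    buffer
  }

  override def merge(buffer: Array[Long], other: Array[Long]): Array[Long] = {
    buffer(0) += other(0)
    buffer
  }

  override def eval(buffer: Array[Long]): Any = buffer(0)

  // The object buffer must round-trip through bytes when shuffled.
  override def serialize(buffer: Array[Long]): Array[Byte] =
    java.nio.ByteBuffer.allocate(8).putLong(buffer(0)).array()

  override def deserialize(bytes: Array[Byte]): Array[Long] =
    Array(java.nio.ByteBuffer.wrap(bytes).getLong)

  override def withNewMutableAggBufferOffset(newOffset: Int): CountNonNull =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): CountNonNull =
    copy(inputAggBufferOffset = newOffset)
  override def children: Seq[Expression] = Seq(child)
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
}
```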

File tree

3 files changed: +130 -62 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala

Lines changed: 52 additions & 45 deletions
```diff
@@ -22,9 +22,10 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
 
 /**
  * This function counts the approximate number of distinct values (ndv) in
@@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with ExpectsInputTypes {
-
-  def this(child: Expression, endpointsExpression: Expression) = {
-    this(
-      child = child,
-      endpointsExpression = endpointsExpression,
-      relativeSD = 0.05,
-      mutableAggBufferOffset = 0,
-      inputAggBufferOffset = 0)
-  }
+  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals(
   private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length
 
   /** Allocate enough words to store all registers. */
-  override lazy val aggBufferAttributes: Seq[AttributeReference] = {
-    Seq.tabulate(totalNumWords) { i =>
-      AttributeReference(s"MS[$i]", LongType)()
-    }
+  override def createAggregationBuffer(): Array[Long] = {
+    Array.fill(totalNumWords)(0L)
   }
 
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  // Note: although this simply copies aggBufferAttributes, this common code can not be placed
-  // in the superclass because that will lead to initialization ordering issues.
-  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
-    aggBufferAttributes.map(_.newInstance())
-
-  /** Fill all words with zeros. */
-  override def initialize(buffer: InternalRow): Unit = {
-    var word = 0
-    while (word < totalNumWords) {
-      buffer.setLong(mutableAggBufferOffset + word, 0)
-      word += 1
-    }
-  }
-
-  override def update(buffer: InternalRow, input: InternalRow): Unit = {
+  override def update(buffer: Array[Long], input: InternalRow): Array[Long] = {
     val value = child.eval(input)
     // Ignore empty rows
     if (value != null) {
@@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals(
       // endpoints are sorted into ascending order already
       if (endpoints.head > doubleValue || endpoints.last < doubleValue) {
         // ignore if the value is out of the whole range
-        return
+        return buffer
       }
 
       val hllppIndex = findHllppIndex(doubleValue)
-      val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp
-      hllppArray(hllppIndex).update(buffer, offset, value, child.dataType)
+      val offset = hllppIndex * numWordsPerHllpp
+      hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType)
     }
+    buffer
   }
 
   // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value.
@@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals(
     }
   }
 
-  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
+  override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = {
     for (i <- hllppArray.indices) {
       hllppArray(i).merge(
-        buffer1 = buffer1,
-        buffer2 = buffer2,
-        offset1 = mutableAggBufferOffset + i * numWordsPerHllpp,
-        offset2 = inputAggBufferOffset + i * numWordsPerHllpp)
+        buffer1 = LongArrayInternalRow(buffer1),
+        buffer2 = LongArrayInternalRow(buffer2),
+        offset1 = i * numWordsPerHllpp,
+        offset2 = i * numWordsPerHllpp)
     }
+    buffer1
   }
 
-  override def eval(buffer: InternalRow): Any = {
+  override def eval(buffer: Array[Long]): Any = {
     val ndvArray = hllppResults(buffer)
     // If the endpoints contains multiple elements with the same value,
     // we set ndv=1 for intervals between these elements.
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals(
     new GenericArrayData(ndvArray)
   }
 
-  def hllppResults(buffer: InternalRow): Array[Long] = {
+  def hllppResults(buffer: Array[Long]): Array[Long] = {
     val ndvArray = new Array[Long](hllppArray.length)
     for (i <- ndvArray.indices) {
-      ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp)
+      ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp)
     }
     ndvArray
   }
 
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+  }
 
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int)
+    : ApproxCountDistinctForIntervals = {
     copy(inputAggBufferOffset = newInputAggBufferOffset)
+  }
 
   override def children: Seq[Expression] = Seq(child, endpointsExpression)
 
@@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals(
   override def dataType: DataType = ArrayType(LongType)
 
   override def prettyName: String = "approx_count_distinct_for_intervals"
+
+  override def serialize(obj: Array[Long]): Array[Byte] = {
+    val byteArray = new Array[Byte](obj.length * 8)
+    var i = 0
+    while (i < obj.length) {
+      Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i))
+      i += 1
+    }
+    byteArray
+  }
+
+  override def deserialize(bytes: Array[Byte]): Array[Long] = {
+    assert(bytes.length % 8 == 0)
+    val length = bytes.length / 8
+    val longArray = new Array[Long](length)
+    var i = 0
+    while (i < length) {
+      longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8)
+      i += 1
+    }
+    longArray
+  }
+
+  private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow {
+    override def getLong(offset: Int): Long = array(offset)
+    override def setLong(offset: Int, value: Long): Unit = { array(offset) = value }
+  }
 }
```
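
To put the buffer sizes in perspective, here is a back-of-the-envelope sketch. The 52-words-per-helper figure is quoted from the PR description and holds only at the default relativeSD of 0.05; the endpoint count is illustrative.

```scala
// Rough buffer-size arithmetic for ~1000 histogram buckets, using the
// 52-words-per-HLL++-helper figure from the PR description.
val numWordsPerHllpp = 52
val numEndpoints = 1001                      // buckets + 1
val numIntervals = numEndpoints - 1          // one HLL++ helper per interval
val totalNumWords = numIntervals * numWordsPerHllpp

// Before: one LongType aggBufferAttribute per word, all flattened into a
// single generated projection method -- 52,000 attributes here.
println(s"aggBufferAttributes (old ImperativeAggregate): $totalNumWords")

// After: one opaque Array[Long] buffer, serialized by serialize() above to
// a single binary value of totalNumWords * 8 bytes (~416 KB here).
println(s"serialized buffer (new TypedImperativeAggregate): ${totalNumWords * 8} bytes")
```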

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala

Lines changed: 17 additions & 17 deletions
```diff
@@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType),
       MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType))))
     wrongColumnTypes.foreach { dataType =>
-      val wrongColumn = new ApproxCountDistinctForIntervals(
+      val wrongColumn = ApproxCountDistinctForIntervals(
         AttributeReference("a", dataType)(),
         endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_))))
       assert(
@@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
         })
     }
 
-    var wrongEndpoints = new ApproxCountDistinctForIntervals(
+    var wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = Literal(0.5d))
     assert(
@@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
         case _ => false
       })
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The endpoints provided must be constant literals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
       TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
 
-    wrongEndpoints = new ApproxCountDistinctForIntervals(
+    wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
@@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
   private def createEstimator[T](
       endpoints: Array[T],
       dt: DataType,
-      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = {
+      rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = {
     val input = new SpecificInternalRow(Seq(dt))
     val aggFunc = ApproxCountDistinctForIntervals(
       BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd)
-    val buffer = createBuffer(aggFunc)
-    (aggFunc, input, buffer)
-  }
-
-  private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = {
-    val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType))
-    aggFunc.initialize(buffer)
-    buffer
+    (aggFunc, input, aggFunc.createAggregationBuffer())
   }
 
   test("merging ApproxCountDistinctForIntervals instances") {
     val (aggFunc, input, buffer1a) =
       createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType)
-    val buffer1b = createBuffer(aggFunc)
-    val buffer2 = createBuffer(aggFunc)
+    val buffer1b = aggFunc.createAggregationBuffer()
+    val buffer2 = aggFunc.createAggregationBuffer()
 
     // Add the lower half to `buffer1a`.
     var i = 0
@@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     }
 
     // Check if the buffers are equal.
-    assert(buffer2 == buffer1a, "Buffers should be equal")
+    assert(buffer2.sameElements(buffer1a), "Buffers should be equal")
   }
 
   test("test findHllppIndex(value) for values in the range") {
@@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
     checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2)
   }
 
+  test("round trip serialization") {
+    val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType)
+    val longArray = (1L to 100L).toArray
+    val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray))
+    assert(roundtrip.sameElements(longArray))
+  }
+
   test("basic operations: update, merge, eval...") {
     val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0)
     val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33)
```
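
One subtlety in the assertion change above (`==` to `sameElements`): the buffers are now plain `Array[Long]`, and Scala arrays inherit Java's reference-based `equals`, so `==` would compare object identity rather than contents. A tiny illustration:

```scala
// Scala arrays compare by reference under ==, so two distinct but
// element-wise equal buffers are not "equal":
val a = Array(1L, 2L, 3L)
val b = Array(1L, 2L, 3L)
assert(!(a == b))          // reference comparison: different objects
assert(a.sameElements(b))  // element-wise comparison, as the suite now uses
```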
sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala

Lines changed: 61 additions & 0 deletions
```diff
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
+  // histogram usually contains hundreds of buckets. So we need to test
+  // ApproxCountDistinctForIntervals with large number of endpoints
+  // (the number of endpoints == the number of buckets + 1).
+  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
+    val table = "approx_count_distinct_for_intervals_tbl"
+    withTable(table) {
+      (1 to 100000).toDF("col").createOrReplaceTempView(table)
+      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
+      val endpoints = (0 to 1000).map(_ * 100000 / 1000)
+
+      // Since approx_count_distinct_for_intervals is not a public function, here we do
+      // the computation by constructing logical plan.
+      val relation = spark.table(table).logicalPlan
+      val attr = relation.output.find(_.name == "col").get
+      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
+      val aggExpr = aggFunc.toAggregateExpression()
+      val namedExpr = Alias(aggExpr, aggExpr.toString)()
+      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
+        .executedPlan.executeTake(1).head
+      val ndvArray = ndvsRow.getArray(0).toLongArray()
+      assert(endpoints.length == ndvArray.length + 1)
+
+      // Each bucket has 100 distinct values.
+      val expectedNdv = 100
+      for (i <- ndvArray.indices) {
+        val ndv = ndvArray(i)
+        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
+        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
+      }
+    }
+  }
+}
```
