
Commit 1ebd41b

yhuai authored and rxin committed
[SPARK-9240] [SQL] Hybrid aggregate operator using unsafe row
This PR adds a base aggregation iterator, `AggregationIterator`, which is used to create `SortBasedAggregationIterator` (for sort-based aggregation) and `UnsafeHybridAggregationIterator` (which first tries hash-based aggregation and falls back to sort-based aggregation, using an external sorter, if we cannot allocate memory for the map). With these two iterators, we will not need the existing iterators, so I am removing those. Also, we can use a single physical `Aggregate` operator that internally determines which iterator to use.

https://issues.apache.org/jira/browse/SPARK-9240

Author: Yin Huai <yhuai@databricks.com>

Closes #7813 from yhuai/AggregateOperator and squashes the following commits:

e317e2b [Yin Huai] Remove unnecessary change.
74d93c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into AggregateOperator
ba6afbc [Yin Huai] Add a little bit more comments.
c9cf3b6 [Yin Huai] update
0f1b06f [Yin Huai] Remove unnecessary code.
21fd15f [Yin Huai] Remove unnecessary change.
964f88b [Yin Huai] Implement fallback strategy.
b1ea5cf [Yin Huai] wip
7fcbd87 [Yin Huai] Add a flag to control what iterator to use.
533d5b2 [Yin Huai] Prepare for fallback!
33b7022 [Yin Huai] wip
bd9282b [Yin Huai] UDAFs now support UnsafeRow.
f52ee53 [Yin Huai] wip
3171f44 [Yin Huai] wip
d2c45a0 [Yin Huai] wip
f60cc83 [Yin Huai] Also check input schema.
af32210 [Yin Huai] Check iter.hasNext before we create an iterator because the constructor of the iterator will read at least one row from a non-empty input iter.
299008c [Yin Huai] First round cleanup.
3915bac [Yin Huai] Create a base iterator class for aggregation iterators and add the initial version of the hybrid iterator.
1 parent 98d6d9c commit 1ebd41b
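
Editor's note: the fallback strategy the commit message describes (try hash-based aggregation, fall back to sort-based aggregation when the map cannot grow) can be illustrated with a short, self-contained Scala sketch. Every name below (SimpleHybridAgg, OutOfBufferSpace, maxKeys) is hypothetical, not part of this patch, and the real UnsafeHybridAggregationIterator spills its partially built hash map through an external sorter rather than rescanning the input as this simplified version does.

object SimpleHybridAgg {
  // Stand-in signal for "cannot allocate memory for the map".
  final class OutOfBufferSpace extends RuntimeException

  // Hash-based aggregation: one running sum per key; maxKeys simulates the
  // memory limit of the in-memory aggregation map.
  private def hashAggregate(rows: Iterator[(String, Long)], maxKeys: Int): Map[String, Long] = {
    val map = scala.collection.mutable.HashMap.empty[String, Long]
    for ((k, v) <- rows) {
      if (!map.contains(k) && map.size >= maxKeys) throw new OutOfBufferSpace
      map(k) = map.getOrElse(k, 0L) + v
    }
    map.toMap
  }

  // Sort-based aggregation: after sorting by key, all rows of a group are
  // adjacent, so a single running buffer per group is enough.
  private def sortAggregate(rows: Seq[(String, Long)]): Map[String, Long] = {
    val out = scala.collection.mutable.HashMap.empty[String, Long]
    var current: Option[String] = None
    var sum = 0L
    for ((k, v) <- rows.sortBy(_._1)) {
      if (current.contains(k)) sum += v
      else {
        current.foreach { k0 => out(k0) = sum }
        current = Some(k)
        sum = v
      }
    }
    current.foreach { k0 => out(k0) = sum }
    out.toMap
  }

  // Try hash-based first; fall back to sort-based if the map cannot grow.
  def aggregate(rows: Seq[(String, Long)], maxKeys: Int): Map[String, Long] =
    try hashAggregate(rows.iterator, maxKeys)
    catch { case _: OutOfBufferSpace => sortAggregate(rows) }
}

For example, aggregate(Seq(("a", 1L), ("b", 2L), ("a", 3L)), maxKeys = 1) starts hashing, hits the one-key limit when "b" arrives, and finishes through the sort-based path, returning Map("a" -> 4, "b" -> 2).
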

File tree

13 files changed, +1697 −973 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 12 additions & 7 deletions
@@ -110,7 +110,11 @@ abstract class AggregateFunction2
   * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
   * will be 2.
   */
-  var mutableBufferOffset: Int = 0
+  protected var mutableBufferOffset: Int = 0
+
+  def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
+    mutableBufferOffset = newMutableBufferOffset
+  }

  /**
   * The offset of this function's start buffer value in the
@@ -126,7 +130,11 @@ abstract class AggregateFunction2
   * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
   * will be 3 (position 0 is used for the value of `key`).
   */
-  var inputBufferOffset: Int = 0
+  protected var inputBufferOffset: Int = 0
+
+  def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
+    inputBufferOffset = newInputBufferOffset
+  }

  /** The schema of the aggregation buffer. */
  def bufferSchema: StructType
@@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
  override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

  override def initialize(buffer: MutableRow): Unit = {
-    var i = 0
-    while (i < bufferAttributes.size) {
-      buffer(i + mutableBufferOffset) = initialValues(i).eval()
-      i += 1
-    }
+    throw new UnsupportedOperationException(
+      "AlgebraicAggregate's initialize should not be called directly")
  }

  override final def update(buffer: MutableRow, input: InternalRow): Unit = {
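
Editor's note: this hunk turns both offsets into protected vars with explicit setters, so callers can no longer assign them directly. A minimal sketch of the intended usage pattern follows; AggFunc, bufferWidth, and assignOffsets are hypothetical stand-ins for illustration, not part of the patch.

object OffsetAssignmentSketch {
  // Hypothetical stand-in for AggregateFunction2, reduced to the relevant parts.
  trait AggFunc {
    protected var mutableBufferOffset: Int = 0
    def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
      mutableBufferOffset = newMutableBufferOffset
    }
    // Number of buffer slots this function needs, e.g. 2 for avg (sum and count).
    def bufferWidth: Int
  }

  // Lay the functions' buffers out side by side in one shared buffer row:
  // each function starts where the previous one ends. With two avg functions,
  // the first gets offset 0 and the second gets offset 2, matching the example
  // in the doc comment above.
  def assignOffsets(functions: Seq[AggFunc], startingOffset: Int = 0): Unit = {
    var offset = startingOffset
    functions.foreach { f =>
      f.withNewMutableBufferOffset(offset)
      offset += f.bufferWidth
    }
  }
}
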
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the data types
+ * of the grouping expressions and aggregate functions, it determines if it uses
+ * sort-based aggregation and hybrid (hash-based with sort-based as the fallback) to
+ * process input rows.
+ */
+case class Aggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  private[this] val allAggregateExpressions =
+    nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  private[this] val hasNonAlgebricAggregateFunctions =
+    !allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
+
+  // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
+  // grouping key and aggregation buffer is supported; and (3) all
+  // aggregate functions are algebraic.
+  private[this] val supportsHybridIterator: Boolean = {
+    val aggregationBufferSchema: StructType =
+      StructType.fromAttributes(
+        allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+    val groupKeySchema: StructType =
+      StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
+
+    val schemaSupportsUnsafe: Boolean =
+      UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+        UnsafeProjection.canSupport(groupKeySchema)
+
+    // TODO: Use the hybrid iterator for non-algebric aggregate functions.
+    sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions
+  }
+
+  // We need to use sorted input if we have grouping expressions, and
+  // we cannot use the hybrid iterator or the hybrid is disabled.
+  private[this] val requiresSortedInput: Boolean = {
+    groupingExpressions.nonEmpty && !supportsHybridIterator
+  }
+
+  override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions
+
+  // If result expressions' data types are all fixed length, we generate unsafe rows
+  // (We have this requirement instead of check the result of UnsafeProjection.canSupport
+  // is because we use a mutable projection to generate the result).
+  override def outputsUnsafeRows: Boolean = {
+    // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
+    // TODO: Supports generating UnsafeRows. We can just re-enable the line above and fix
+    // any issue we get.
+    false
+  }
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (requiresSortedInput) {
+      // TODO: We should not sort the input rows if they are just in reversed order.
+      groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+    } else {
+      Seq.fill(children.size)(Nil)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiresSortedInput) {
+      // It is possible that the child.outputOrdering starts with the required
+      // ordering expressions (e.g. we require [a] as the sort expression and the
+      // child's outputOrdering is [a, b]). We can only guarantee the output rows
+      // are sorted by values of groupingExpressions.
+      groupingExpressions.map(SortOrder(_, Ascending))
+    } else {
+      Nil
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      // Because the constructor of an aggregation iterator will read at least the first row,
+      // we need to get the value of iter.hasNext first.
+      val hasInput = iter.hasNext
+      val useHybridIterator =
+        hasInput &&
+          supportsHybridIterator &&
+          groupingExpressions.nonEmpty
+      if (useHybridIterator) {
+        UnsafeHybridAggregationIterator.createFromInputIterator(
+          groupingExpressions,
+          nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
+          completeAggregateExpressions,
+          completeAggregateAttributes,
+          initialInputBufferOffset,
+          resultExpressions,
+          newMutableProjection _,
+          child.output,
+          iter,
+          outputsUnsafeRows)
+      } else {
+        if (!hasInput && groupingExpressions.nonEmpty) {
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
+          Iterator[InternalRow]()
+        } else {
+          val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+            groupingExpressions,
+            nonCompleteAggregateExpressions,
+            nonCompleteAggregateAttributes,
+            completeAggregateExpressions,
+            completeAggregateAttributes,
+            initialInputBufferOffset,
+            resultExpressions,
+            newMutableProjection _,
+            newProjection _,
+            child.output,
+            iter,
+            outputsUnsafeRows)
+          if (!hasInput && groupingExpressions.isEmpty) {
+            // There is no input and there is no grouping expressions.
+            // We need to output a single row as the output.
+            Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+          } else {
+            outputIter
+          }
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) {
+      classOf[UnsafeHybridAggregationIterator].getSimpleName
+    } else {
+      classOf[SortBasedAggregationIterator].getSimpleName
+    }
+
+    s"""NewAggregate with $iterator ${groupingExpressions} ${allAggregateExpressions}"""
+  }
+}
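
Editor's note: the branching in doExecute above can be summarized as a small decision function. The sketch below is illustrative only: Choice and its cases are made-up labels, and the real code constructs iterators instead of returning tags. It captures the four outcomes, and note that in the patch the single-row case is still produced by the sort-based iterator, via outputForEmptyGroupingKeyWithoutInput().

object IteratorChoiceSketch {
  sealed trait Choice
  case object HybridIterator extends Choice    // hash-based, falls back to sort-based
  case object SortBasedIterator extends Choice // consumes sorted input
  case object EmptyResult extends Choice       // grouped aggregate over empty input
  case object SingleRowResult extends Choice   // global aggregate over empty input, e.g. COUNT(*) = 0

  def chooseOutput(hasInput: Boolean, supportsHybrid: Boolean, hasGrouping: Boolean): Choice =
    if (hasInput && supportsHybrid && hasGrouping) HybridIterator
    else if (!hasInput && hasGrouping) EmptyResult
    else if (!hasInput) SingleRowResult
    else SortBasedIterator
}
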
