Skip to content

Commit 964f88b

Browse files
committed
Implement fallback strategy.
1 parent b1ea5cf commit 964f88b

File tree

9 files changed

+174
-176
lines changed

9 files changed

+174
-176
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,11 @@ abstract class AggregateFunction2
110110
* buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)`
111111
* will be 2.
112112
*/
113-
var mutableBufferOffset: Int = 0
113+
protected var mutableBufferOffset: Int = 0
114+
115+
def withNewMutableBufferOffset(newMutableBufferOffset: Int): Unit = {
116+
mutableBufferOffset = newMutableBufferOffset
117+
}
114118

115119
/**
116120
* The offset of this function's start buffer value in the
@@ -126,7 +130,11 @@ abstract class AggregateFunction2
126130
* buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
127131
* will be 3 (position 0 is used for the value of `key`).
128132
*/
129-
var inputBufferOffset: Int = 0
133+
protected var inputBufferOffset: Int = 0
134+
135+
def withNewInputBufferOffset(newInputBufferOffset: Int): Unit = {
136+
inputBufferOffset = newInputBufferOffset
137+
}
130138

131139
/** The schema of the aggregation buffer. */
132140
def bufferSchema: StructType
@@ -195,11 +203,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w
195203
override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
196204

197205
override def initialize(buffer: MutableRow): Unit = {
198-
var i = 0
199-
while (i < bufferAttributes.size) {
200-
buffer(i + mutableBufferOffset) = initialValues(i).eval()
201-
i += 1
202-
}
206+
throw new UnsupportedOperationException(
207+
"AlgebraicAggregate's initialize should not be called directly")
203208
}
204209

205210
override final def update(buffer: MutableRow, input: InternalRow): Unit = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
137137
}
138138

139139
override def keyExpressions: Seq[Expression] = expressions
140+
141+
override def toString: String = s"${super.toString} numPartitions=$numPartitions"
140142
}
141143

142144
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ case class Aggregate(
4747
private[this] val allAggregateExpressions =
4848
nonCompleteAggregateExpressions ++ completeAggregateExpressions
4949

50+
private[this] val hasNonAlgebricAggregateFunctions =
51+
!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
52+
5053
// Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
5154
// grouping key and aggregation buffer are supported; and (3) all
5255
// aggregate functions are algebraic.
@@ -56,14 +59,13 @@ case class Aggregate(
5659
allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
5760
val groupKeySchema: StructType =
5861
StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
59-
val resultSchema: StructType =
60-
StructType.fromAttributes(resultExpressions.map(_.toAttribute))
62+
6163
val schemaSupportsUnsafe: Boolean =
6264
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
63-
UnsafeProjection.canSupport(groupKeySchema) &&
64-
UnsafeProjection.canSupport(resultSchema)
65+
UnsafeProjection.canSupport(groupKeySchema)
6566

66-
sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe
67+
// TODO: Use the hybrid iterator for non-algebraic aggregate functions.
68+
sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && !hasNonAlgebricAggregateFunctions
6769
}
6870

6971
private[this] val hybridAggregateEnabled = sqlContext.conf.useHybridAggregate
@@ -74,9 +76,17 @@ case class Aggregate(
7476
groupingExpressions.nonEmpty && (!supportsHybridIterator || !hybridAggregateEnabled)
7577
}
7678

77-
override def canProcessUnsafeRows: Boolean = false // true
79+
override def canProcessUnsafeRows: Boolean = !hasNonAlgebricAggregateFunctions
7880

79-
override def outputsUnsafeRows: Boolean = supportsHybridIterator
81+
// If result expressions' data types are all fixed length, we generate unsafe rows
82+
// (We have this requirement, instead of checking the result of UnsafeProjection.canSupport,
83+
// because we use a mutable projection to generate the result).
84+
override def outputsUnsafeRows: Boolean = {
85+
// resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
86+
// TODO: Support generating UnsafeRows. We can just re-enable the line above and fix
87+
// any issue we get.
88+
false
89+
}
8090

8191
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
8292

@@ -131,7 +141,7 @@ case class Aggregate(
131141
newMutableProjection _,
132142
child.output,
133143
iter,
134-
child.outputsUnsafeRows)
144+
outputsUnsafeRows)
135145
} else {
136146
if (!hasInput && groupingExpressions.nonEmpty) {
137147
// This is a grouped aggregate and the input iterator is empty,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ abstract class AggregationIterator(
9090
case _ =>
9191
// We only need to set inputBufferOffset for aggregate functions with mode
9292
// PartialMerge and Final.
93-
func.inputBufferOffset = inputBufferOffset
93+
func.withNewInputBufferOffset(inputBufferOffset)
9494
inputBufferOffset += func.bufferSchema.length
9595
func
9696
}
9797
// Set mutableBufferOffset for this function. It is important that setting
9898
// mutableBufferOffset happens after all potential bindReference operations
9999
// because bindReference will create a new instance of the function.
100-
funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset
100+
funcWithBoundReferences.withNewMutableBufferOffset(mutableBufferOffset)
101101
mutableBufferOffset += funcWithBoundReferences.bufferSchema.length
102102
functions(i) = funcWithBoundReferences
103103
i += 1
@@ -347,7 +347,8 @@ abstract class AggregationIterator(
347347
}
348348
val algebraicEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)()
349349
val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes
350-
val aggregateResult: MutableRow = new GenericMutableRow(aggregateResultSchema.length)
350+
// TODO: Use unsafe row.
351+
val aggregateResult = new GenericMutableRow(aggregateResultSchema.length)
351352
val resultProjection =
352353
newMutableProjection(
353354
resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)()

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ class SortBasedAggregationIterator(
5959
val bufferRowSize: Int = bufferSchema.length
6060

6161
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
62-
val buffer = if (outputsUnsafeRows) {
62+
val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isFixedLength)
63+
64+
val buffer = if (useUnsafeBuffer) {
6365
val unsafeProjection =
6466
UnsafeProjection.create(bufferSchema.map(_.dataType))
6567
unsafeProjection.apply(genericMutableBuffer)
@@ -70,7 +72,6 @@ class SortBasedAggregationIterator(
7072
buffer
7173
}
7274

73-
7475
///////////////////////////////////////////////////////////////////////////
7576
// Mutable states for sort based aggregation.
7677
///////////////////////////////////////////////////////////////////////////
@@ -96,8 +97,6 @@ class SortBasedAggregationIterator(
9697
// Now, we will start to find all rows belonging to this group.
9798
// We create a variable to track if we see the next group.
9899
var findNextPartition = false
99-
println("sort key first in group " + toSafeKey(nextGroupingKey) + " value " + toSafeValue(firstRowInNextGroup) + " buffer " + toSafeBuffer(sortBasedAggregationBuffer))
100-
101100
// firstRowInNextGroup is the first row of this group. We first process it.
102101
processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
103102

@@ -112,7 +111,6 @@ class SortBasedAggregationIterator(
112111
// Check if the current row belongs to the current group.
113112
if (currentGroupingKey == groupingKey) {
114113
processRow(sortBasedAggregationBuffer, currentRow)
115-
println("sort key " + toSafeKey(groupingKey) + " value " + toSafeValue(currentRow) + " buffer " + toSafeBuffer(sortBasedAggregationBuffer))
116114

117115
hasNext = inputKVIterator.next()
118116
} else {
@@ -143,7 +141,6 @@ class SortBasedAggregationIterator(
143141
processCurrentSortedGroup()
144142
// Generate output row for the current group.
145143
val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
146-
println("sort result " + toSafeResult(outputRow))
147144
// Initialize buffer values for the next group.
148145
initializeBuffer(sortBasedAggregationBuffer)
149146

@@ -154,18 +151,12 @@ class SortBasedAggregationIterator(
154151
}
155152
}
156153

157-
val toSafeKey = FromUnsafeProjection(groupingKeyAttributes.map(_.dataType).toArray)
158-
val toSafeValue = FromUnsafeProjection(valueAttributes.map(_.dataType).toArray)
159-
val toSafeBuffer = FromUnsafeProjection(allAggregateFunctions.flatMap(_.bufferAttributes).map(_.dataType).toArray)
160-
161154
protected def initialize(): Unit = {
162155
if (inputKVIterator.next()) {
163156
initializeBuffer(sortBasedAggregationBuffer)
164-
println("first " + toSafeKey(inputKVIterator.getKey()) + " value " + toSafeValue(inputKVIterator.getValue()) + " buffer " + toSafeBuffer(sortBasedAggregationBuffer))
165157

166158
nextGroupingKey = inputKVIterator.getKey().copy()
167159
firstRowInNextGroup = inputKVIterator.getValue().copy()
168-
println("first " + toSafeKey(nextGroupingKey) + " value " + toSafeValue(firstRowInNextGroup) + " buffer " + toSafeBuffer(sortBasedAggregationBuffer))
169160

170161
sortedInputHasNewGroup = true
171162
} else {
@@ -183,6 +174,7 @@ class SortBasedAggregationIterator(
183174
}
184175

185176
object SortBasedAggregationIterator {
177+
// scalastyle:off
186178
def createFromInputIterator(
187179
groupingExprs: Seq[NamedExpression],
188180
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
@@ -196,7 +188,7 @@ object SortBasedAggregationIterator {
196188
inputAttributes: Seq[Attribute],
197189
inputIter: Iterator[InternalRow],
198190
outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
199-
val kvIterator = if (outputsUnsafeRows) {
191+
val kvIterator = if (UnsafeProjection.canSupport(groupingExprs)) {
200192
AggregationIterator.unsafeKVIterator(
201193
groupingExprs,
202194
inputAttributes,
@@ -244,4 +236,5 @@ object SortBasedAggregationIterator {
244236
newMutableProjection,
245237
outputsUnsafeRows)
246238
}
239+
// scalastyle:on
247240
}

0 commit comments

Comments
 (0)