Skip to content

Commit 539720d

Browse files
committed
Support two level hash map for final hash aggregation
1 parent 9a6d773 commit 539720d

File tree

1 file changed

+47
-4
lines changed

1 file changed

+47
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ case class HashAggregateExec(
128128
// all the mode of aggregate expressions
129129
private val modes = aggregateExpressions.map(_.mode).distinct
130130

131+
// Used for testing final aggregation with the number-of-rows-based fallback specified by
132+
// `testFallbackStartsAt`. In this scenario, the same key may exist in both the fast and the
133+
// regular hash map, so the aggregation buffers from both maps need to be merged together
134+
// to avoid a correctness issue.
135+
//
136+
// This scenario only happens in unit tests with number-of-rows-based fallback.
137+
// With size-based fallback in production, the same key should not appear in both maps.
//
// The flag is true only when `testFallbackStartsAt` is defined and at least one aggregate
// expression runs in `Final` or `Complete` mode.
138+
private val isTestFinalAggregateWithFallback: Boolean = testFallbackStartsAt.isDefined &&
139+
(modes.contains(Final) || modes.contains(Complete))
140+
131141
// Returns this operator's whole `inputSet`.
// NOTE(review): presumably so that every input attribute is treated as used by this
// operator -- confirm against the overridden definition.
override def usedInputs: AttributeSet = inputSet
132142

133143
override def supportCodegen: Boolean = {
@@ -537,6 +547,34 @@ case class HashAggregateExec(
537547
}
538548
}
539549

550+
/**
551+
* Called by the generated Java class to merge the contents of the fast hash map into the
* regular hash map once aggregation has finished.
552+
* This is used only for testing final aggregation: the generated code calls it in place of
* creating a fast-hash-map iterator when `isTestFinalAggregateWithFallback` is true.
*
* @param fastHashMapRowIter iterator over the (grouping key, aggregation buffer) entries of the
*                           fast hash map; it is fully consumed and then closed by this method
* @param regularHashMap regular hash map whose buffers receive the merged values
553+
*/
554+
def mergeFastHashMapForTest(
555+
fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow],
556+
regularHashMap: UnsafeFixedWidthAggregationMap): Unit = {
557+
558+
// Create a MutableProjection that merges two buffers of the same key. Its input schema is the
// regular-map buffer attributes followed by the input (fast-map) buffer attributes.
559+
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
560+
val mergeProjection = MutableProjection.create(
561+
mergeExpr,
562+
aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
563+
// A single JoinedRow instance is reused for every entry.
val joinedRow = new JoinedRow()
564+
565+
while (fastHashMapRowIter.next()) {
566+
val key = fastHashMapRowIter.getKey
567+
val fastMapBuffer = fastHashMapRowIter.getValue
568+
// NOTE(review): the returned buffer is used without a null check; presumably
// `getAggregationBufferFromUnsafeRow` can return null when memory cannot be acquired --
// confirm whether that matters for this test-only path.
val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key)
569+
570+
// Merge the aggregation buffer from the fast hash map into the buffer with the same key
571+
// in the regular map. The regular-map buffer is both the left input and the projection
// target; the fast-map buffer is the right input, matching the schema order above.
572+
mergeProjection.target(regularMapBuffer)
573+
mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer))
574+
}
575+
// The iterator is closed only after it has been drained.
fastHashMapRowIter.close()
576+
}
577+
540578
/**
541579
* Generate the code for output.
542580
* @return function name for the result code.
@@ -647,7 +685,7 @@ case class HashAggregateExec(
647685
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
648686
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
649687
f.dataType.isInstanceOf[CalendarIntervalType]) &&
650-
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
688+
bufferSchema.nonEmpty
651689

652690
// For the vectorized hash map, we do not support byte array based decimal type for aggregate values
653691
// as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -663,7 +701,7 @@ case class HashAggregateExec(
663701

664702
private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
665703
if (!checkIfFastHashMapSupported(ctx)) {
666-
if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
704+
if (!Utils.isTesting) {
667705
logInfo(s"${SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key} is set to true, but"
668706
+ " current version of codegened fast hashmap does not support this aggregate.")
669707
}
@@ -740,8 +778,13 @@ case class HashAggregateExec(
740778
val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
741779
s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
742780
val finishHashMap = if (isFastHashMapEnabled) {
781+
val finishFastHashMap = if (isTestFinalAggregateWithFallback) {
782+
s"$thisPlan.mergeFastHashMapForTest($fastHashMapTerm.rowIterator(), $hashMapTerm);"
783+
} else {
784+
s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"
785+
}
743786
s"""
744-
|$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
787+
|$finishFastHashMap
745788
|$finishRegularHashMap
746789
""".stripMargin
747790
} else {
@@ -762,7 +805,7 @@ case class HashAggregateExec(
762805
val outputFunc = generateResultFunction(ctx)
763806

764807
def outputFromFastHashMap: String = {
765-
if (isFastHashMapEnabled) {
808+
if (isFastHashMapEnabled && !isTestFinalAggregateWithFallback) {
766809
if (isVectorizedHashMapEnabled) {
767810
outputFromVectorizedMap
768811
} else {

0 commit comments

Comments
 (0)