@@ -128,6 +128,16 @@ case class HashAggregateExec(
128128 // all the mode of aggregate expressions
129129 private val modes = aggregateExpressions.map(_.mode).distinct
130130
// Test-only switch for a final (or complete) aggregate that runs with the number-of-rows-based
// fallback configured through `testFallbackStartsAt`. In that setup the same grouping key may
// exist in both the fast hash map and the regular hash map, so the aggregation buffers from
// both maps have to be merged together to avoid a correctness issue.
//
// This scenario only happens in unit tests using the number-of-rows-based fallback.
// With the size-based fallback used in production, the two maps should never share a key.
private val isTestFinalAggregateWithFallback: Boolean =
  testFallbackStartsAt.isDefined && modes.exists(mode => mode == Final || mode == Complete)
140+
131141 override def usedInputs : AttributeSet = inputSet
132142
133143 override def supportCodegen : Boolean = {
@@ -537,6 +547,34 @@ case class HashAggregateExec(
537547 }
538548 }
539549
/**
 * Called by the generated Java class to finish merging the fast hash map into the regular
 * hash map, so that every grouping key ends up with a single aggregation buffer held by the
 * regular map. This is used for testing final aggregate only.
 *
 * @param fastHashMapRowIter iterator over the (grouping key, aggregation buffer) entries of
 *                           the fast hash map; it is consumed by this method and is closed
 *                           before the method returns, whether or not the merge succeeds
 * @param regularHashMap regular hash map whose buffers receive the merged values
 * @throws IllegalStateException if the regular hash map cannot provide a buffer for a key
 */
def mergeFastHashMapForTest(
    fastHashMapRowIter: KVIterator[UnsafeRow, UnsafeRow],
    regularHashMap: UnsafeFixedWidthAggregationMap): Unit = {

  // Create a MutableProjection to merge the buffers of the same key together.
  val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
  val mergeProjection = MutableProjection.create(
    mergeExpr,
    aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
  val joinedRow = new JoinedRow()

  try {
    while (fastHashMapRowIter.next()) {
      val key = fastHashMapRowIter.getKey
      val fastMapBuffer = fastHashMapRowIter.getValue
      val regularMapBuffer = regularHashMap.getAggregationBufferFromUnsafeRow(key)

      // The regular map is documented to hand back null when it cannot allocate memory for
      // a new group; fail with a clear message instead of an NPE inside the projection.
      if (regularMapBuffer == null) {
        throw new IllegalStateException(
          "Not enough memory to merge the fast hash map into the regular hash map")
      }

      // Merge the aggregation buffer of the fast hash map into the buffer with the same key
      // of the regular map. The projection writes its result in place into the target.
      mergeProjection.target(regularMapBuffer)
      mergeProjection(joinedRow(regularMapBuffer, fastMapBuffer))
    }
  } finally {
    // Always release the iterator, even when merging a row throws.
    fastHashMapRowIter.close()
  }
}
577+
540578 /**
541579 * Generate the code for output.
542580 * @return function name for the result code.
@@ -647,7 +685,7 @@ case class HashAggregateExec(
647685 (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator .isPrimitiveType(f.dataType) ||
648686 f.dataType.isInstanceOf [DecimalType ] || f.dataType.isInstanceOf [StringType ] ||
649687 f.dataType.isInstanceOf [CalendarIntervalType ]) &&
650- bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge )
688+ bufferSchema.nonEmpty
651689
652690 // For vectorized hash map, We do not support byte array based decimal type for aggregate values
653691 // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
@@ -663,7 +701,7 @@ case class HashAggregateExec(
663701
664702 private def enableTwoLevelHashMap (ctx : CodegenContext ): Unit = {
665703 if (! checkIfFastHashMapSupported(ctx)) {
666- if (modes.forall(mode => mode == Partial || mode == PartialMerge ) && ! Utils .isTesting) {
704+ if (! Utils .isTesting) {
667705 logInfo(s " ${SQLConf .ENABLE_TWOLEVEL_AGG_MAP .key} is set to true, but "
668706 + " current version of codegened fast hashmap does not support this aggregate." )
669707 }
@@ -740,8 +778,13 @@ case class HashAggregateExec(
740778 val finishRegularHashMap = s " $iterTerm = $thisPlan.finishAggregate( " +
741779 s " $hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe); "
742780 val finishHashMap = if (isFastHashMapEnabled) {
781+ val finishFastHashMap = if (isTestFinalAggregateWithFallback) {
782+ s " $thisPlan.mergeFastHashMapForTest( $fastHashMapTerm.rowIterator(), $hashMapTerm); "
783+ } else {
784+ s " $iterTermForFastHashMap = $fastHashMapTerm.rowIterator(); "
785+ }
743786 s """
744- | $iterTermForFastHashMap = $fastHashMapTerm .rowIterator();
787+ | $finishFastHashMap
745788 | $finishRegularHashMap
746789 """ .stripMargin
747790 } else {
@@ -762,7 +805,7 @@ case class HashAggregateExec(
762805 val outputFunc = generateResultFunction(ctx)
763806
764807 def outputFromFastHashMap : String = {
765- if (isFastHashMapEnabled) {
808+ if (isFastHashMapEnabled && ! isTestFinalAggregateWithFallback ) {
766809 if (isVectorizedHashMapEnabled) {
767810 outputFromVectorizedMap
768811 } else {
0 commit comments