diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9242583d3671..7ead180d869b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -234,30 +234,30 @@ case class HashAggregateExec( val doAgg = ctx.freshName("doAggregateWithoutKey") val doAggFuncName = ctx.addNewFunction(doAgg, s""" - | private void $doAgg() throws java.io.IOException { - | // initialize aggregation buffer - | $initBufVar + |private void $doAgg() throws java.io.IOException { + | // initialize aggregation buffer + | $initBufVar | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | } + | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + |} """.stripMargin) val numOutput = metricTerm(ctx, "numOutputRows") val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - | while (!$initAgg) { - | $initAgg = true; - | long $beforeAgg = System.nanoTime(); - | $doAggFuncName(); - | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + |while (!$initAgg) { + | $initAgg = true; + | long $beforeAgg = System.nanoTime(); + | $doAggFuncName(); + | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); | - | // output the result - | ${genResult.trim} + | // output the result + | ${genResult.trim} | - | $numOutput.add(1); - | ${consume(ctx, resultVars).trim} - | } + | $numOutput.add(1); + | ${consume(ctx, resultVars).trim} + |} """.stripMargin } @@ -581,12 +581,12 @@ case class HashAggregateExec( val evaluateNondeterministicResults = evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" - $evaluateKeyVars - $evaluateBufferVars - $evaluateAggResults - $evaluateNondeterministicResults - ${consume(ctx, resultVars)} - """ + |$evaluateKeyVars + |$evaluateBufferVars + |$evaluateAggResults + |$evaluateNondeterministicResults + |${consume(ctx, resultVars)} + """.stripMargin } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes. assert(resultExpressions.forall(_.isInstanceOf[Attribute])) @@ -613,10 +613,10 @@ case class HashAggregateExec( resultExpressions, inputAttrs).map(_.genCode(ctx)) s""" - $evaluateKeyVars - $evaluateResultBufferVars - ${consume(ctx, resultVars)} - """ + |$evaluateKeyVars + |$evaluateResultBufferVars + |${consume(ctx, resultVars)} + """.stripMargin } else { // generate result based on grouping key ctx.INPUT_ROW = keyTerm @@ -627,18 +627,18 @@ case class HashAggregateExec( val evaluateNondeterministicResults = evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" - $evaluateNondeterministicResults - ${consume(ctx, resultVars)} - """ + |$evaluateNondeterministicResults + |${consume(ctx, resultVars)} + """.stripMargin } ctx.addNewFunction(funcName, s""" - private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm) - throws java.io.IOException { - $numOutput.add(1); - $body - } - """) + |private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm) + | throws java.io.IOException { + | $numOutput.add(1); + | $body + |} + """.stripMargin) } /** @@ -829,17 +829,16 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") s""" - if (!$initAgg) { - $initAgg = true; - long $beforeAgg = System.nanoTime(); - $doAggFuncName(); - $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); - } - - // output the result - $outputFromFastHashMap - $outputFromRegularHashMap - """ + |if (!$initAgg) { + | $initAgg = true; + | long $beforeAgg = System.nanoTime(); + | $doAggFuncName(); + | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + |} + |// output the result + |$outputFromFastHashMap + |$outputFromRegularHashMap + """.stripMargin } private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { @@ -1098,14 +1097,11 @@ case class HashAggregateExec( // continue to do in-memory aggregation and spilling until all the rows had been processed. // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. s""" - $declareRowBuffer - - $findOrInsertHashMap - - $incCounter - - $updateRowInHashMap - """ + |$declareRowBuffer + |$findOrInsertHashMap + |$incCounter + |$updateRowInHashMap + """.stripMargin } override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields)