diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 796043fff665..d37d81753f0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -115,7 +115,13 @@ package object dsl {
     def getField(fieldName: String): UnresolvedExtractValue =
       UnresolvedExtractValue(expr, Literal(fieldName))
 
-    def cast(to: DataType): Expression = Cast(expr, to)
+    def cast(to: DataType): Expression = {
+      if (expr.resolved && expr.dataType.sameType(to)) {
+        expr
+      } else {
+        Cast(expr, to)
+      }
+    }
 
     def asc: SortOrder = SortOrder(expr, Ascending)
     def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 95fad412002e..4c1bfcfdf7f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1612,6 +1612,48 @@ object CodeGenerator extends Logging {
     }
   }
 
+  /**
+   * Extracts all the input variables from references and subexpression elimination states
+   * for a given `expr`. This result will be used to split the generated code of
+   * expressions into multiple functions.
+   */
+  def getLocalInputVariableValues(
+      ctx: CodegenContext,
+      expr: Expression,
+      subExprs: Map[Expression, SubExprEliminationState]): Set[VariableValue] = {
+    val argSet = mutable.Set[VariableValue]()
+    if (ctx.INPUT_ROW != null) {
+      argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
+    }
+
+    // Collects local variables from a given `expr` tree
+    val collectLocalVariable = (ev: ExprValue) => ev match {
+      case vv: VariableValue => argSet += vv
+      case _ =>
+    }
+
+    val stack = mutable.Stack[Expression](expr)
+    while (stack.nonEmpty) {
+      stack.pop() match {
+        case e if subExprs.contains(e) =>
+          val SubExprEliminationState(isNull, value) = subExprs(e)
+          collectLocalVariable(value)
+          collectLocalVariable(isNull)
+
+        case ref: BoundReference if ctx.currentVars != null &&
+            ctx.currentVars(ref.ordinal) != null =>
+          val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal)
+          collectLocalVariable(value)
+          collectLocalVariable(isNull)
+
+        case e =>
+          stack.pushAll(e.children)
+      }
+    }
+
+    argSet.toSet
+  }
+
   /**
    * Returns the name used in accessor and setter for a Java primitive type.
    */
@@ -1719,6 +1761,15 @@
     1 + params.map(paramLengthForExpr).sum
   }
 
+  def calculateParamLengthFromExprValues(params: Seq[ExprValue]): Int = {
+    def paramLengthForExpr(input: ExprValue): Int = input.javaType match {
+      case java.lang.Long.TYPE | java.lang.Double.TYPE => 2
+      case _ => 1
+    }
+    // Initial value is 1 for `this`.
+    1 + params.map(paramLengthForExpr).sum
+  }
+
   /**
    * In Java, a method descriptor is valid only if it represents method parameters with a total
    * length less than a pre-defined constant.
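Reviewer note: the slot arithmetic above mirrors the JVM method-descriptor rule (JVMS §4.3.3): a descriptor may encode at most 255 parameter slots, `long` and `double` take two slots each, and `this` takes one. A self-contained sketch of the same check; the names here (`ParamSlotSketch`, `MaxJvmMethodParamsLength`) are illustrative, not Spark APIs:

```scala
object ParamSlotSketch {
  // JVMS 4.3.3: at most 255 parameter slots per method; long/double use two.
  val MaxJvmMethodParamsLength = 255

  // Mirrors calculateParamLengthFromExprValues, but over plain Java classes.
  def paramLength(javaTypes: Seq[Class[_]]): Int =
    1 + javaTypes.map {  // the initial 1 accounts for `this`
      case c if c == java.lang.Long.TYPE || c == java.lang.Double.TYPE => 2
      case _ => 1
    }.sum

  def isValidParamLength(len: Int): Boolean = len <= MaxJvmMethodParamsLength
}

// e.g. (long, boolean, double) -> 1 + 2 + 1 + 2 = 6 slots, far below 255.
assert(ParamSlotSketch.paramLength(
  Seq(java.lang.Long.TYPE, java.lang.Boolean.TYPE, java.lang.Double.TYPE)) == 6)
```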
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 3bb3c602f775..d9393b9df6bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -143,7 +143,10 @@ trait Block extends TreeNode[Block] with JavaCode {
     case _ => code.trim
   }
 
-  def length: Int = toString.length
+  def length: Int = {
+    // Returns a code length without comments
+    CodeFormatter.stripExtraNewLinesAndComments(toString).length
+  }
 
   def isEmpty: Boolean = toString.isEmpty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 293d28e93039..f54d5f167856 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -354,12 +354,14 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
-    val value = eval.isNull match {
-      case TrueLiteral => FalseLiteral
-      case FalseLiteral => TrueLiteral
-      case v => JavaCode.isNullExpression(s"!$v")
+    val (value, newCode) = eval.isNull match {
+      case TrueLiteral => (FalseLiteral, EmptyBlock)
+      case FalseLiteral => (TrueLiteral, EmptyBlock)
+      case v =>
+        val value = ctx.freshName("value")
+        (JavaCode.variable(value, BooleanType), code"boolean $value = !$v;")
     }
-    ExprCode(code = eval.code, isNull = FalseLiteral, value = value)
+    ExprCode(code = eval.code + newCode, isNull = FalseLiteral, value = value)
   }
 
   override def sql: String = s"(${child.sql} IS NOT NULL)"
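Reviewer note: the IsNotNull change matters for this patch because `splitAggregateExpressions` (added in HashAggregateExec below) refuses to split when a subexpression's result is a `SimpleExprValue`: a bare fragment like `!isNull_0` has no name that can appear in a split method's parameter list, while a `VariableValue` does. A hedged sketch of the distinction using the `JavaCode` factory methods touched here (variable names are made up):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.BooleanType

// Before this patch IsNotNull produced the right-hand kind; now it
// materializes a fresh local so downstream splitting can forward it.
val namedLocal = JavaCode.variable("value_0", BooleanType)    // VariableValue
val bareFragment = JavaCode.isNullExpression("!isNull_0")     // SimpleExprValue

// Only a named local can become a method argument of a split function.
def canForwardAsArgument(ev: ExprValue): Boolean = ev.isInstanceOf[VariableValue]

assert(canForwardAsArgument(namedLocal))
assert(!canForwardAsArgument(bareFragment))
```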
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 52990cb6a244..09ac23711739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1047,6 +1047,15 @@
       .booleanConf
       .createWithDefault(false)
 
+  val CODEGEN_SPLIT_AGGREGATE_FUNC =
+    buildConf("spark.sql.codegen.aggregate.splitAggregateFunc.enabled")
+      .internal()
+      .doc("When true, the code generator would split aggregate code into individual methods " +
+        "instead of a single big method. This can be used to avoid an oversized function that " +
+        "can miss the opportunity of JIT optimization.")
+      .booleanConf
+      .createWithDefault(true)
+
   val MAX_NESTED_VIEW_DEPTH =
     buildConf("spark.sql.view.maxNestedViewDepth")
       .internal()
@@ -2310,6 +2319,8 @@
   def cartesianProductExecBufferSpillThreshold: Int =
     getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
 
+  def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC)
+
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
   def starSchemaDetection: Boolean = getConf(STARSCHEMA_DETECTION)
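Reviewer note: for reference, the new flag (and the existing `spark.sql.codegen.methodSplitThreshold` it interacts with) can be toggled per session. A usage sketch, assuming a running `SparkSession` named `spark`:

```scala
// The flag is internal and on by default; shown here only for debugging.
spark.conf.set("spark.sql.codegen.aggregate.splitAggregateFunc.enabled", "true")
// Lowering the split threshold forces even tiny aggregate code to be split,
// which makes the effect easy to observe in the generated code.
spark.conf.set("spark.sql.codegen.methodSplitThreshold", "1")
spark.sql("SELECT SUM(v), AVG(v) FROM VALUES (1), (2) t(v)").show()
```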
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4a95f7638133..9242583d3671 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.aggregate
 
 import java.util.concurrent.TimeUnit._
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager}
 import org.apache.spark.rdd.RDD
@@ -174,8 +176,9 @@ case class HashAggregateExec(
     }
   }
 
-  // The variables used as aggregation buffer. Only used for aggregation without keys.
-  private var bufVars: Seq[ExprCode] = _
+  // These variables are used as aggregation buffers; each aggregate function has one or more
+  // ExprCodes to initialize its buffer slots. Only used for aggregation without keys.
+  private var bufVars: Seq[Seq[ExprCode]] = _
 
   private def doProduceWithoutKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
@@ -184,27 +187,30 @@
 
     // generate variables for aggregation buffer
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val initExpr = functions.flatMap(f => f.initialValues)
-    bufVars = initExpr.map { e =>
-      val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
-      val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
-      // The initial expression should not access any column
-      val ev = e.genCode(ctx)
-      val initVars = code"""
-        | $isNull = ${ev.isNull};
-        | $value = ${ev.value};
-      """.stripMargin
-      ExprCode(
-        ev.code + initVars,
-        JavaCode.isNullGlobal(isNull),
-        JavaCode.global(value, e.dataType))
+    val initExpr = functions.map(f => f.initialValues)
+    bufVars = initExpr.map { exprs =>
+      exprs.map { e =>
+        val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
+        val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
+        // The initial expression should not access any column
+        val ev = e.genCode(ctx)
+        val initVars = code"""
+          |$isNull = ${ev.isNull};
+          |$value = ${ev.value};
+        """.stripMargin
+        ExprCode(
+          ev.code + initVars,
+          JavaCode.isNullGlobal(isNull),
+          JavaCode.global(value, e.dataType))
+      }
     }
-    val initBufVar = evaluateVariables(bufVars)
+    val flatBufVars = bufVars.flatten
+    val initBufVar = evaluateVariables(flatBufVars)
 
     // generate variables for output
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
-      ctx.currentVars = bufVars
+      ctx.currentVars = flatBufVars
       val aggResults = bindReferences(
         functions.map(_.evaluateExpression),
         aggregateBufferAttributes).map(_.genCode(ctx))
@@ -218,7 +224,7 @@
       """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
-      (bufVars, "")
+      (flatBufVars, "")
     } else {
       // no aggregate function, the result should be literals
       val resultVars = resultExpressions.map(_.genCode(ctx))
@@ -255,11 +261,85 @@
     """.stripMargin
   }
 
+  private def isValidParamLength(paramLength: Int): Boolean = {
+    // This config is only for testing
+    sqlContext.getConf("spark.sql.HashAggregateExec.validParamLength", null) match {
+      case null | "" => CodeGenerator.isValidParamLength(paramLength)
+      case validLength => paramLength <= validLength.toInt
+    }
+  }
+
+  // Splits aggregate code into small functions because most JVM implementations
+  // cannot compile functions that are too long. Returns None if we are not able to
+  // split the given code.
+  //
+  // Note: The difference from `CodeGenerator.splitExpressions` is that we define an individual
+  // function for each aggregation function (e.g., SUM and AVG). For example, in a query
+  // `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)`, we define two functions
+  // for `SUM(a)` and `AVG(a)`.
+  private def splitAggregateExpressions(
+      ctx: CodegenContext,
+      aggNames: Seq[String],
+      aggBufferUpdatingExprs: Seq[Seq[Expression]],
+      aggCodeBlocks: Seq[Block],
+      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+    val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
+    if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
+      // `SimpleExprValue`s cannot be used as input variables for split functions, so
+      // we give up splitting functions if any exist in `subExprs`.
+      None
+    } else {
+      val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
+        val inputVarsForOneFunc = aggExprsForOneFunc.map(
+          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq
+        val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
+
+        // Checks that the parameter length for `aggExprsForOneFunc` does not go over the JVM limit
+        if (isValidParamLength(paramLength)) {
+          Some(inputVarsForOneFunc)
+        } else {
+          None
+        }
+      }
+
+      // Checks if all the aggregate code can be split into pieces.
+      // If the parameter length of at least one `aggExprsForOneFunc` goes over the limit,
+      // we totally give up splitting aggregate code.
+      if (inputVars.forall(_.isDefined)) {
+        val splitCodes = inputVars.flatten.zipWithIndex.map { case (args, i) =>
+          val doAggFunc = ctx.freshName(s"doAggregate_${aggNames(i)}")
+          val argList = args.map(v => s"${v.javaType.getName} ${v.variableName}").mkString(", ")
+          val doAggFuncName = ctx.addNewFunction(doAggFunc,
+            s"""
+               |private void $doAggFunc($argList) throws java.io.IOException {
+               |  ${aggCodeBlocks(i)}
+               |}
+             """.stripMargin)
+
+          val inputVariables = args.map(_.variableName).mkString(", ")
+          s"$doAggFuncName($inputVariables);"
+        }
+        Some(splitCodes.mkString("\n").trim)
+      } else {
+        val errMsg = "Failed to split aggregate code into small functions because the parameter " +
+          "length of at least one split function went over the JVM limit: " +
+          CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+        if (Utils.isTesting) {
+          throw new IllegalStateException(errMsg)
+        } else {
+          logInfo(errMsg)
+          None
+        }
+      }
+    }
+  }
+
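Reviewer note: to make the intended output concrete, for `SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)` the method above emits one private method per aggregate function plus call sites in place of the inlined blocks. Roughly the following shape; this is a hedged illustration only, real identifiers come from `ctx.freshName` and will carry different suffixes:

```scala
// Illustration only; not part of the patch. The generated Java looks like:
val generatedShape =
  """
    |private void doAggregate_sum(boolean agg_exprIsNull_0, long agg_expr_0)
    |    throws java.io.IOException {
    |  // evaluate SUM(a) and write its single buffer slot back
    |}
    |
    |private void doAggregate_avg(boolean agg_exprIsNull_0, long agg_expr_0)
    |    throws java.io.IOException {
    |  // evaluate AVG(a) and write its (sum, count) buffer slots back
    |}
    |
    |// emitted where the two inlined blocks used to be:
    |doAggregate_sum(agg_exprIsNull_0, agg_expr_0);
    |doAggregate_avg(agg_exprIsNull_0, agg_expr_0);
  """.stripMargin
```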
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val updateExpr = aggregateExpressions.flatMap { e =>
+    // To individually generate code for each aggregate function, an element in `updateExprs` holds
+    // all the expressions for the buffer of an aggregation function.
+    val updateExprs = aggregateExpressions.map { e =>
       e.mode match {
         case Partial | Complete =>
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
@@ -267,28 +347,56 @@
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-    ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
-    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+    ctx.currentVars = bufVars.flatten ++ input
+    val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+      bindReferences(updateExprsForOneFunc, inputAttrs)
+    }
+    val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
     val effectiveCodes = subExprs.codes.mkString("\n")
-    val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
-      boundUpdateExpr.map(_.genCode(ctx))
+    val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+      ctx.withSubExprEliminationExprs(subExprs.states) {
+        boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+      }
     }
-    // aggregate buffer should be updated atomic
-    val updates = aggVals.zipWithIndex.map { case (ev, i) =>
-      s"""
-         | ${bufVars(i).isNull} = ${ev.isNull};
-         | ${bufVars(i).value} = ${ev.value};
+
+    val aggNames = functions.map(_.prettyName)
+    val aggCodeBlocks = bufferEvals.zipWithIndex.map { case (bufferEvalsForOneFunc, i) =>
+      val bufVarsForOneFunc = bufVars(i)
+      // All the update code for aggregation buffers should be placed at the end
+      // of each aggregation function code.
+      val updates = bufferEvalsForOneFunc.zip(bufVarsForOneFunc).map { case (ev, bufVar) =>
+        s"""
+           |${bufVar.isNull} = ${ev.isNull};
+           |${bufVar.value} = ${ev.value};
+         """.stripMargin
+      }
+      code"""
+         |// do aggregate for ${aggNames(i)}
+         |// evaluate aggregate function
+         |${evaluateVariables(bufferEvalsForOneFunc)}
+         |// update aggregation buffers
+         |${updates.mkString("\n").trim}
       """.stripMargin
     }
+
+    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
+        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val maybeSplitCode = splitAggregateExpressions(
+        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+      maybeSplitCode.getOrElse {
+        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
+      }
+    } else {
+      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
+    }
+
     s"""
-       | // do aggregate
-       | // common sub-expressions
-       | $effectiveCodes
-       | // evaluate aggregate function
-       | ${evaluateVariables(aggVals)}
-       | // update aggregation buffer
-       | ${updates.mkString("\n").trim}
+       |// do aggregate
+       |// common sub-expressions
+       |$effectiveCodes
+       |// evaluate aggregate functions and update aggregation buffers
+       |$codeToEvalAggFunc
      """.stripMargin
   }
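Reviewer note: a worked example of the regrouping above. For `SELECT SUM(a), AVG(a)`, the old flat `updateExpr` had three elements; `updateExprs` now groups them per function, which is what lets each function become its own split method. The elements below are schematic strings, not real Catalyst expressions; only the shape matters:

```scala
val updateExprsShape: Seq[Seq[String]] = Seq(
  Seq("sum + a"),                // SUM(a): one buffer slot
  Seq("sum + a", "count + 1"))   // AVG(a): two buffer slots (sum, count)

assert(updateExprsShape.flatten.length == 3)  // == the old flat updateExpr
assert(updateExprsShape.map(_.length) == Seq(1, 2))
```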
@@ -745,8 +853,10 @@
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
 
-    // only have DeclarativeAggregate
-    val updateExpr = aggregateExpressions.flatMap { e =>
+    // To individually generate code for each aggregate function, an element in `updateExprs` holds
+    // all the expressions for the buffer of an aggregation function.
+    val updateExprs = aggregateExpressions.map { e =>
+      // only have DeclarativeAggregate
       e.mode match {
         case Partial | Complete =>
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
@@ -824,25 +934,70 @@
     // generating input columns, we use `currentVars`.
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
 
+    val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
+    // Computes start offsets for each aggregation function code
+    // in the underlying buffer row.
+    val bufferStartOffsets = {
+      val offsets = mutable.ArrayBuffer[Int]()
+      var curOffset = 0
+      updateExprs.foreach { exprsForOneFunc =>
+        offsets += curOffset
+        curOffset += exprsForOneFunc.length
+      }
+      offsets.toArray
+    }
+
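Reviewer note: the offsets are just running sums of the per-function slot counts. For `SUM(v)` followed by `AVG(v)`, `bufferStartOffsets` is `[0, 1]`: SUM writes buffer column 0 and AVG writes columns 1 and 2. A self-contained check of the same computation:

```scala
import scala.collection.mutable

val slotCountsPerFunc = Seq(1, 2)  // SUM -> 1 slot, AVG -> (sum, count)
val offsets = {
  val buf = mutable.ArrayBuffer[Int]()
  var cur = 0
  slotCountsPerFunc.foreach { n => buf += cur; cur += n }
  buf.toArray
}
assert(offsets.sameElements(Array(0, 1)))
```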
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
-      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+      val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+        bindReferences(updateExprsForOneFunc, inputAttr)
+      }
+      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
       val effectiveCodes = subExprs.codes.mkString("\n")
-      val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
-        boundUpdateExpr.map(_.genCode(ctx))
+      val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+        ctx.withSubExprEliminationExprs(subExprs.states) {
+          boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+        }
       }
-      val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
-        val dt = updateExpr(i).dataType
-        CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
+
+      val aggCodeBlocks = updateExprs.indices.map { i =>
+        val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i)
+        val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
+        val bufferOffset = bufferStartOffsets(i)
+
+        // All the update code for aggregation buffers should be placed at the end
+        // of each aggregation function code.
+        val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
+          val updateExpr = boundUpdateExprsForOneFunc(j)
+          val dt = updateExpr.dataType
+          val nullable = updateExpr.nullable
+          CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, ev, nullable)
+        }
+        code"""
+           |// evaluate aggregate function for ${aggNames(i)}
+           |${evaluateVariables(rowBufferEvalsForOneFunc)}
+           |// update unsafe row buffer
+           |${updateRowBuffers.mkString("\n").trim}
+         """.stripMargin
       }
+
+      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
+          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+        val maybeSplitCode = splitAggregateExpressions(
+          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+        maybeSplitCode.getOrElse {
+          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
+        }
+      } else {
+        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
+      }
+
       s"""
          |// common sub-expressions
          |$effectiveCodes
-         |// evaluate aggregate function
-         |${evaluateVariables(unsafeRowBufferEvals)}
-         |// update unsafe row buffer
-         |${updateUnsafeRowBuffer.mkString("\n").trim}
+         |// evaluate aggregate functions and update aggregation buffers
+         |$codeToEvalAggFunc
        """.stripMargin
     }
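Reviewer note: the threshold comparison `aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold` relies on the javaCode.scala change above: `Block.length` now strips comments first, so the `// do aggregate for ...` markers added to each block do not push the code over the split threshold. A small check of that behavior (the string is a made-up block):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

val block = "// evaluate aggregate function for sum\nagg_bufValue_0 = agg_value_1;\n"
assert(CodeFormatter.stripExtraNewLinesAndComments(block).length < block.length)
```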
@@ -850,16 +1005,48 @@
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
-          val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
-          val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+          val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+            bindReferences(updateExprsForOneFunc, inputAttr)
+          }
+          val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
           val effectiveCodes = subExprs.codes.mkString("\n")
-          val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
-            boundUpdateExpr.map(_.genCode(ctx))
+          val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
+            ctx.withSubExprEliminationExprs(subExprs.states) {
+              boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+            }
           }
-          val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
-            val dt = updateExpr(i).dataType
-            CodeGenerator.updateColumn(
-              fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = true)
+
+          val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) =>
+            val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
+            val bufferOffset = bufferStartOffsets(i)
+            // All the update code for aggregation buffers should be placed at the end
+            // of each aggregation function code.
+            val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { case (ev, j) =>
+              val updateExpr = boundUpdateExprsForOneFunc(j)
+              val dt = updateExpr.dataType
+              val nullable = updateExpr.nullable
+              CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, ev, nullable,
+                isVectorized = true)
+            }
+            code"""
+               |// evaluate aggregate function for ${aggNames(i)}
+               |${evaluateVariables(fastRowEvalsForOneFunc)}
+               |// update fast row
+               |${updateRowBuffer.mkString("\n").trim}
+             """.stripMargin
+          }
+
+          val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
+              aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+            val maybeSplitCode = splitAggregateExpressions(
+              ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+            maybeSplitCode.getOrElse {
+              aggCodeBlocks.fold(EmptyBlock)(_ + _).code
+            }
+          } else {
+            aggCodeBlocks.fold(EmptyBlock)(_ + _).code
           }
 
           // If vectorized fast hash map is on, we first generate code to update row
@@ -869,10 +1056,8 @@
             |if ($fastRowBuffer != null) {
             |  // common sub-expressions
             |  $effectiveCodes
-            |  // evaluate aggregate function
-            |  ${evaluateVariables(fastRowEvals)}
-            |  // update fast row
-            |  ${updateFastRow.mkString("\n").trim}
+            |  // evaluate aggregate functions and update aggregation buffers
+            |  $codeToEvalAggFunc
             |} else {
             |  $updateRowInRegularHashMap
             |}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 0ea16a1a15d6..d8727d5b584f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -398,4 +398,25 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
     }.isDefined,
       "LocalTableScanExec should be within a WholeStageCodegen domain.")
   }
+
+  test("Give up splitting aggregate code if a parameter length goes over the limit") {
+    withSQLConf(
+        SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC.key -> "true",
+        SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1",
+        "spark.sql.HashAggregateExec.validParamLength" -> "0") {
+      withTable("t") {
+        val expectedErrMsg = "Failed to split aggregate code into small functions"
+        Seq(
+          // Test case without keys
+          "SELECT AVG(v) FROM VALUES(1) t(v)",
+          // Test case with keys
+          "SELECT k, AVG(v) FROM VALUES((1, 1)) t(k, v) GROUP BY k").foreach { query =>
+          val errMsg = intercept[IllegalStateException] {
+            sql(query).collect
+          }.getMessage
+          assert(errMsg.contains(expectedErrMsg))
+        }
+      }
+    }
+  }
 }
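Reviewer note: the test above only exercises the give-up path. To see the split actually happen, one can dump the whole-stage generated code; a sketch using Spark's debug helpers, assuming a spark-shell style session:

```scala
import org.apache.spark.sql.execution.debug._

spark.conf.set("spark.sql.codegen.aggregate.splitAggregateFunc.enabled", "true")
spark.conf.set("spark.sql.codegen.methodSplitThreshold", "1")
// With the flag on, the dumped class should contain doAggregate_sum /
// doAggregate_avg style methods instead of one inlined block.
spark.sql("SELECT SUM(v), AVG(v) FROM VALUES (1), (2) t(v)").debugCodegen()
```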