diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 8d10f6cd2952..58f0f2948002 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,29 +408,11 @@ class CodegenContext extends Logging { partitionInitializationStatements.mkString("\n") } - /** - * Holds expressions that are equivalent. Used to perform subexpression elimination - * during codegen. - * - * For expressions that appear more than once, generate additional code to prevent - * recomputing the value. - * - * For example, consider two expression generated from this SQL statement: - * SELECT (col1 + col2), (col1 + col2) / col3. - * - * equivalentExpressions will match the tree containing `col1 + col2` and it will only - * be evaluated once. - */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. private[expressions] var subExprEliminationExprs = Map.empty[ExpressionEquals, SubExprEliminationState] - // The collection of sub-expression result resetting methods that need to be called on each row. - private val subexprFunctions = mutable.ArrayBuffer.empty[String] - val outerClassName = "OuterClass" /** @@ -1021,9 +1003,8 @@ class CodegenContext extends Logging { * Returns the code for subexpression elimination after splitting it if necessary. 
*/ def subexprFunctionsCode: String = { - // Whole-stage codegen's subexpression elimination is handled in another code path - assert(currentVars == null || subexprFunctions.isEmpty) - splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) + val subExprCode = evaluateSubExprEliminationState(subExprEliminationExprs.values) + splitExpressionsWithCurrentInputs(subExprCode, "subexprFunc_split") } /** @@ -1048,20 +1029,17 @@ class CodegenContext extends Logging { * evaluating a subexpression, this method will clean up the code block to avoid duplicate * evaluation. */ - def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { - val code = new StringBuilder() - - subExprStates.foreach { state => - val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code - code.append(currentCode + "\n") + def evaluateSubExprEliminationState( + subExprStates: Iterable[SubExprEliminationState]): Seq[String] = { + subExprStates.flatMap { state => + val currentCode = evaluateSubExprEliminationState(state.children) :+ state.eval.code.toString state.eval.code = EmptyBlock - } - - code.toString() + currentCode + }.toSeq } /** - * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * Checks and sets up the state and codegen for subexpression elimination. * * This finds the common subexpressions, generates the code snippets that evaluate those * expressions and populates the mapping of common subexpressions to the generated code snippets. @@ -1094,10 +1072,10 @@ class CodegenContext extends Logging { * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression * evaluation, we can look for generated subexpressions and do replacement. 
*/ - def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprsForNonSplit = + val localSubExprEliminationExprs = mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. @@ -1110,8 +1088,28 @@ class CodegenContext extends Logging { val nonSplitCode = { val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] commonExprs.map { expr => - withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { val eval = expr.genCode(this) + + val value = addMutableState(javaType(expr.dataType), "subExprValue") + + val isNullLiteral = eval.isNull match { + case TrueLiteral | FalseLiteral => true + case _ => false + } + val (isNull, isNullEvalCode) = if (!isNullLiteral) { + val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") + } else { + (eval.isNull, "") + } + + val code = code""" + |${eval.code} + |$isNullEvalCode + |$value = ${eval.value}; + """.stripMargin + // Collects other subexpressions from the children. 
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] expr.foreach { e => @@ -1120,8 +1118,10 @@ class CodegenContext extends Logging { case _ => } } - val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) - localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state) + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + localSubExprEliminationExprs.put(ExpressionEquals(expr), state) allStates += state Seq(eval) } @@ -1141,38 +1141,18 @@ class CodegenContext extends Logging { val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - val localSubExprEliminationExprs = - mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState] commonExprs.zipWithIndex.foreach { case (expr, i) => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { - Seq(expr.genCode(this)) - }.head - - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - val isNullLiteral = eval.isNull match { - case TrueLiteral | FalseLiteral => true - case _ => false - } - val (isNull, isNullEvalCode) = if (!isNullLiteral) { - val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - (JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};") - } else { - (eval.isNull, "") - } - // Generate the code for this expression tree and wrap it in a function. 
val fnName = freshName("subExpr") val inputVars = inputVarsForAllFuncs(i) val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") + val subExprState = localSubExprEliminationExprs.remove(ExpressionEquals(expr)).get val fn = s""" |private void $fnName(${argList.mkString(", ")}) { - | ${eval.code} - | $isNullEvalCode - | $value = ${eval.value}; + | ${subExprState.eval.code} |} """.stripMargin @@ -1188,7 +1168,7 @@ class CodegenContext extends Logging { val inputVariables = inputVars.map(_.variableName).mkString(", ") val code = code"${addNewFunction(fnName, fn)}($inputVariables);" val state = SubExprEliminationState( - ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + subExprState.eval.copy(code = code), childrenSubExprs.toSeq) localSubExprEliminationExprs.put(ExpressionEquals(expr), state) } @@ -1201,67 +1181,16 @@ class CodegenContext extends Logging { throw new IllegalStateException(errMsg) } else { logInfo(errMsg) - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } } } else { - (localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprs, Seq.empty) } + subExprsMap.foreach(subExprEliminationExprs += _) SubExprCodes(subExprsMap.toMap, exprCodes.flatten) } - /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. - */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree(_)) - - // Get all the expressions that appear at least twice and set up the state for subexpression - // elimination. 
- val commonExprs = equivalentExpressions.getCommonSubexpressions - commonExprs.foreach { expr => - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - - // Generate the code for this expression tree and wrap it in a function. - val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} - """.stripMargin - - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - subexprFunctions += subExprCode - val state = SubExprEliminationState( - ExprCode(code"$subExprCode", - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType))) - subExprEliminationExprs += ExpressionEquals(expr) -> state - } - } - /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression * elimination will be performed. 
Subexpression elimination assumes that the code for each diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index f369635a3267..49c87b9e224f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -278,11 +278,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ExprCode(TrueLiteral, oneVar), ExprCode(TrueLiteral, twoVar)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - ctx.withSubExprEliminationExprs(subExprs.states) { - exprs.map(_.genCode(ctx)) - } - val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values) + ctx.subexpressionElimination(exprs) + val subExprsCode = ctx.subexprFunctionsCode val codeBody = s""" public java.lang.Object generate(Object[] references) { @@ -408,7 +405,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val exprs = Seq(add1, add1, add2, add2) val ctx = new CodegenContext() - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val subExprs = ctx.subexpressionElimination(exprs) val add2State = subExprs.states(ExpressionEquals(add2)) val add1State = subExprs.states(ExpressionEquals(add1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala index 1377a9842231..f03c0cd60ae5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala @@ -207,12 +207,14 @@ trait 
AggregateCodegenSupport val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = if (conf.subexpressionEliminationEnabled) { + ctx.subexpressionElimination(boundUpdateExprs.flatten).states + } else { + Map.empty[ExpressionEquals, SubExprEliminationState] + } + val effectiveCodes = ctx.subexprFunctionsCode val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggNames = functions.map(_.prettyName) @@ -256,11 +258,11 @@ trait AggregateCodegenSupport boundUpdateExprs: Seq[Seq[Expression]], aggNames: Seq[String], aggCodeBlocks: Seq[Block], - subExprs: SubExprCodes): String = { + subExprs: Map[ExpressionEquals, SubExprEliminationState]): String = { val aggCodes = if (conf.codegenSplitAggregateFunc && aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) { val maybeSplitCodes = splitAggregateExpressions( - ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states) + ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs) maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code)) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b942907b6752..d0601c9bfbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -728,12 +728,14 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => 
bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = if (conf.subexpressionEliminationEnabled) { + ctx.subexpressionElimination(boundUpdateExprs.flatten).states + } else { + Map.empty[ExpressionEquals, SubExprEliminationState] + } + val effectiveCodes = ctx.subexprFunctionsCode val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggCodeBlocks = updateExprs.indices.map { i => @@ -774,12 +776,14 @@ case class HashAggregateExec( val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => bindReferences(updateExprsForOneFunc, inputAttrs) } - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) + val subExprs = if (conf.subexpressionEliminationEnabled) { + ctx.subexpressionElimination(boundUpdateExprs.flatten).states + } else { + Map.empty[ExpressionEquals, SubExprEliminationState] + } + val effectiveCodes = ctx.subexprFunctionsCode val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => - ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExprsForOneFunc.map(_.genCode(ctx)) - } + boundUpdateExprsForOneFunc.map(_.genCode(ctx)) } val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 2fd799355070..bd28a89edade 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -67,17 +67,13 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { - // subexpression elimination - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) - val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { - exprs.map(_.genCode(ctx)) - } - (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars, - subExprs.exprCodesNeedEvaluate) + val localValInputs = if (conf.subexpressionEliminationEnabled) { + ctx.subexpressionElimination(exprs).exprCodesNeedEvaluate } else { - ("", exprs.map(_.genCode(ctx)), Seq.empty) + Seq.empty } + val resultVars = exprs.map(_.genCode(ctx)) + val subExprsCode = ctx.subexprFunctionsCode // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)