diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4b9ae9a7a81e..c39db3511926 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -139,7 +139,10 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. - ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) + ExprCode( + ctx.registerComment(this.toString), + subExprState.eval.isNull, + subExprState.eval.value) }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9831b13ea754..07fd2fe4e76c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -76,24 +76,38 @@ object ExprCode { /** * State used for subexpression elimination. * - * @param isNull A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param value A term for a value of a common sub-expression. Not valid if `isNull` - * is set to `true`. + * @param eval The source code for evaluating the subexpression. + * @param children The sequence of subexpressions as the children expressions. Before + * evaluating this subexpression, we should evaluate all children + * subexpressions first. This is used if we want to selectively evaluate + * particular subexpressions, instead of all at once. In the case, we need + * to make sure we evaluate all children subexpressions too. */ -case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) +case class SubExprEliminationState( + eval: ExprCode, + children: Seq[SubExprEliminationState]) + +object SubExprEliminationState { + def apply(eval: ExprCode): SubExprEliminationState = { + new SubExprEliminationState(eval, Seq.empty) + } + + def apply( + eval: ExprCode, + children: Seq[SubExprEliminationState]): SubExprEliminationState = { + new SubExprEliminationState(eval, children.reverse) + } +} /** * Codes and common subexpressions mapping used for subexpression elimination. * - * @param codes Strings representing the codes that evaluate common subexpressions. * @param states Foreach expression that is participating in subexpression elimination, * the state to use. * @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before * calling common subexpressions. */ case class SubExprCodes( - codes: Seq[String], states: Map[Expression, SubExprEliminationState], exprCodesNeedEvaluate: Seq[ExprCode]) @@ -1030,11 +1044,55 @@ class CodegenContext extends Logging { } /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the code snippets that evaluate those expressions and - * populates the mapping of common subexpressions to the generated code snippets. The generated - * code snippets will be returned and should be inserted into generated codes before these - * common subexpressions actually are used first time. + * Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After + * evaluating a subexpression, this method will clean up the code block to avoid duplicate + * evaluation. + */ + def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = { + val code = new StringBuilder() + + subExprStates.foreach { state => + val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code + code.append(currentCode + "\n") + state.eval.code = EmptyBlock + } + + code.toString() + } + + /** + * Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen. + * + * This finds the common subexpressions, generates the code snippets that evaluate those + * expressions and populates the mapping of common subexpressions to the generated code snippets. + * + * The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which + * contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode` + * includes java source code, result variable name and is-null variable name of the subexpression. + * + * Besides, this also returns a sequences of `ExprCode` which are expression codes that need to + * be evaluated (as their input parameters) before evaluating subexpressions. + * + * To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with + * the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup + * the states to avoid duplicate evaluation. + * + * The details of subexpression generation: + * 1. Gets subexpression set. See `EquivalentExpressions`. + * 2. Generate code of subexpressions as a whole block of code (non-split case) + * 3. Check if the total length of the above block is larger than the split-threshold. If so, + * try to split it in step 4, otherwise returning the non-split code block. + * 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so, + * try to split, otherwise returning the non-split code block. + * 5. For each subexpression, generating a function and put the code into it. To evaluate the + * subexpression, just call the function. + * + * The explanation of subexpression codegen: + * 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each + * subexpression may depends on other subexpressions (children). So when generating code + * for subexpressions, we iterate over each subexpression and put the mapping between + * (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression + * evaluation, we can look for generated subexpressions and do replacement. */ def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping @@ -1049,17 +1107,25 @@ class CodegenContext extends Logging { // elimination. val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) - val nonSplitExprCode = { + val nonSplitCode = { + val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState] commonExprs.map { exprs => - val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { val eval = exprs.head.genCode(this) - // Generate the code for this expression tree. - val state = SubExprEliminationState(eval.isNull, eval.value) + // Collects other subexpressions from the children. + val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] + exprs.head.foreach { + case e if subExprEliminationExprs.contains(e) => + childrenSubExprs += subExprEliminationExprs(e) + case _ => + } + val state = SubExprEliminationState(eval, childrenSubExprs.toSeq) exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state)) + allStates += state Seq(eval) - }.head - eval.code.toString + } } + allStates.toSeq } // For some operators, they do not require all its child's outputs to be evaluated in advance. @@ -1071,14 +1137,13 @@ class CodegenContext extends Logging { (inputVars.toSeq, exprCodes.toSeq) }.unzip - val splitThreshold = SQLConf.get.methodSplitThreshold - - val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) { + val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold + val (subExprsMap, exprCodes) = if (needSplit) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] - val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) => + commonExprs.zipWithIndex.foreach { case (exprs, i) => val expr = exprs.head val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { Seq(expr.genCode(this)) @@ -1111,24 +1176,34 @@ class CodegenContext extends Logging { |} """.stripMargin - val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType)) - exprs.foreach(localSubExprEliminationExprs.put(_, state)) + // Collects other subexpressions from the children. + val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState] + exprs.head.foreach { + case e if localSubExprEliminationExprs.contains(e) => + childrenSubExprs += localSubExprEliminationExprs(e) + case _ => + } + val inputVariables = inputVars.map(_.variableName).mkString(", ") - s"${addNewFunction(fnName, fn)}($inputVariables);" + val code = code"${addNewFunction(fnName, fn)}($inputVariables);" + val state = SubExprEliminationState( + ExprCode(code, isNull, JavaCode.global(value, expr.dataType)), + childrenSubExprs.toSeq) + exprs.foreach(localSubExprEliminationExprs.put(_, state)) } - (splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate) + (localSubExprEliminationExprs, exprCodesNeedEvaluate) } else { if (Utils.isTesting) { throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH) } else { logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH)) - (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprsForNonSplit, Seq.empty) } } } else { - (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) + (localSubExprEliminationExprsForNonSplit, Seq.empty) } - SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten) + SubExprCodes(subExprsMap.toMap, exprCodes.flatten) } /** @@ -1174,10 +1249,12 @@ class CodegenContext extends Logging { // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with // at least two nodes) as the cost of doing it is expected to be low. + val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);" subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState( - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType)) + ExprCode(code"$subExprCode", + JavaCode.isNullGlobal(isNull), + JavaCode.global(value, expr.dataType))) subExprEliminationExprs ++= e.map(_ -> state).toMap } } @@ -1776,9 +1853,8 @@ object CodeGenerator extends Logging { while (stack.nonEmpty) { stack.pop() match { case e if subExprs.contains(e) => - val SubExprEliminationState(isNull, value) = subExprs(e) - collectLocalVariable(value) - collectLocalVariable(isNull) + collectLocalVariable(subExprs(e).eval.value) + collectLocalVariable(subExprs(e).eval.isNull) case ref: BoundReference if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 44b6aa6b6271..b100554cf240 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -463,15 +463,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val add1 = Add(ref, ref) val add2 = Add(add1, add1) val dummy = SubExprEliminationState( - JavaCode.variable("dummy", BooleanType), - JavaCode.variable("dummy", BooleanType)) + ExprCode(EmptyBlock, + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType))) // raw testing of basic functionality { val ctx = new CodegenContext val e = ref.genCode(ctx) // before - ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + ctx.subExprEliminationExprs += ref -> SubExprEliminationState( + ExprCode(EmptyBlock, e.isNull, e.value)) assert(ctx.subExprEliminationExprs.contains(ref)) // call withSubExprEliminationExprs ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index edddfbe712cb..0c657370f2ef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -282,7 +282,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } - val subExprsCode = subExprs.codes.mkString("\n") + val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values) val codeBody = s""" public java.lang.Object generate(Object[] references) { @@ -392,6 +392,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr)) } + test("SPARK-35829: SubExprEliminationState keeps children sub exprs") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(add1, add1) + + val exprs = Seq(add1, add1, add2, add2) + val ctx = new CodegenContext() + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + + val add2State = subExprs.states(add2) + val add1State = subExprs.states(add1) + assert(add2State.children.contains(add1State)) + + subExprs.states.values.foreach { state => + assert(state.eval.code != EmptyBlock) + } + ctx.evaluateSubExprEliminationState(subExprs.states.values) + subExprs.states.values.foreach { state => + assert(state.eval.code == EmptyBlock) + } + } + test("SPARK-35886: PromotePrecision should not overwrite genCode") { val p = PromotePrecision(Literal(Decimal("10.1"))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 1192f02955a5..c97c213cbc21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -258,7 +258,9 @@ case class HashAggregateExec( aggBufferUpdatingExprs: Seq[Seq[Expression]], aggCodeBlocks: Seq[Block], subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = { - val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil } + val exprValsInSubExprs = subExprs.flatMap { case (_, s) => + s.eval.value :: s.eval.isNull :: Nil + } if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) { // `SimpleExprValue`s cannot be used as an input variable for split functions, so // we give up splitting functions if it exists in `subExprs`. @@ -363,7 +365,7 @@ case class HashAggregateExec( bindReferences(updateExprsForOneFunc, inputAttrs) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = subExprs.codes.mkString("\n") + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) @@ -989,7 +991,7 @@ case class HashAggregateExec( bindReferences(updateExprsForOneFunc, inputAttrs) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = subExprs.codes.mkString("\n") + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) @@ -1035,7 +1037,7 @@ case class HashAggregateExec( bindReferences(updateExprsForOneFunc, inputAttrs) } val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten) - val effectiveCodes = subExprs.codes.mkString("\n") + val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values) val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc => ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExprsForOneFunc.map(_.genCode(ctx)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index e3c02908d892..7bd4dc7be129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -72,7 +72,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } - (subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate) + (ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars, + subExprs.exprCodesNeedEvaluate) } else { ("", exprs.map(_.genCode(ctx)), Seq.empty) }