diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0819c5514904..81ed64675729 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1039,7 +1039,8 @@ class CodegenContext extends Logging { def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { // Create a clear EquivalentExpressions and SubExprEliminationState mapping val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + val localSubExprEliminationExprsForNonSplit = + mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -1047,13 +1048,16 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) - lazy val commonExprVals = commonExprs.map(_.head.genCode(this)) - lazy val nonSplitExprCode = { - commonExprs.zip(commonExprVals).map { case (exprs, eval) => - // Generate the code for this expression tree. - val state = SubExprEliminationState(eval.isNull, eval.value) - exprs.foreach(localSubExprEliminationExprs.put(_, state)) + val nonSplitExprCode = { + commonExprs.map { exprs => + val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { + val eval = exprs.head.genCode(this) + // Generate the code for this expression tree. + val state = SubExprEliminationState(eval.isNull, eval.value) + exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state)) + Seq(eval) + }.head eval.code.toString } } @@ -1068,11 +1072,19 @@ class CodegenContext extends Logging { }.unzip val splitThreshold = SQLConf.get.methodSplitThreshold - val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) { + + val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { - commonExprs.zipWithIndex.map { case (exprs, i) => + val localSubExprEliminationExprs = + mutable.HashMap.empty[Expression, SubExprEliminationState] + + val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) => val expr = exprs.head - val eval = commonExprVals(i) + val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { + Seq(expr.genCode(this)) + }.head + + val value = addMutableState(javaType(expr.dataType), "subExprValue") val isNullLiteral = eval.isNull match { case TrueLiteral | FalseLiteral => true @@ -1090,34 +1102,33 @@ class CodegenContext extends Logging { val inputVars = inputVarsForAllFuncs(i) val argList = inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") - val returnType = javaType(expr.dataType) val fn = s""" - |private $returnType $fnName(${argList.mkString(", ")}) { + |private void $fnName(${argList.mkString(", ")}) { | ${eval.code} | $isNullEvalCode - | return ${eval.value}; + | $value = ${eval.value}; |} """.stripMargin - val value = freshName("subExprValue") - val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType)) + val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType)) exprs.foreach(localSubExprEliminationExprs.put(_, state)) val inputVariables = inputVars.map(_.variableName).mkString(", ") - s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);" + s"${addNewFunction(fnName, fn)}($inputVariables);" } + (splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate) } else { if (Utils.isTesting) { throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH) } else { logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH)) - nonSplitExprCode + (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) } } } else { - nonSplitExprCode + (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) } - SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten) + SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6cdcc88e17e0..3e810a453377 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2882,6 +2882,31 @@ class DataFrameSuite extends QueryTest df2.collect() assert(accum.value == 15) } + + test("SPARK-35560: Remove redundant subexpression evaluation in nested subexpressions") { + Seq(1, Int.MaxValue).foreach { splitThreshold => + withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> splitThreshold.toString) { + val accum = sparkContext.longAccumulator("call") + val simpleUDF = udf((s: String) => { + accum.add(1) + s + }) + + // Common exprs: + // 1. simpleUDF($"id") + // 2. functions.length(simpleUDF($"id")) + // We should only evaluate `simpleUDF($"id")` once, i.e. + // subExpr1 = simpleUDF($"id"); + // subExpr2 = functions.length(subExpr1); + val df = spark.range(5).select( + when(functions.length(simpleUDF($"id")) === 1, lower(simpleUDF($"id"))) + .when(functions.length(simpleUDF($"id")) === 0, upper(simpleUDF($"id"))) + .otherwise(simpleUDF($"id")).as("output")) + df.collect() + assert(accum.value == 5) + } + } + } } case class GroupByKey(a: Int, b: Int)