-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-35560][SQL] Remove redundant subexpression evaluation in nested subexpressions #32699
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1039,21 +1039,25 @@ class CodegenContext extends Logging { | |
| def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { | ||
| // Create a clear EquivalentExpressions and SubExprEliminationState mapping | ||
| val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions | ||
| val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] | ||
| val localSubExprEliminationExprsForNonSplit = | ||
| mutable.HashMap.empty[Expression, SubExprEliminationState] | ||
|
|
||
| // Add each expression tree and compute the common subexpressions. | ||
| expressions.foreach(equivalentExpressions.addExprTree(_)) | ||
|
|
||
| // Get all the expressions that appear at least twice and set up the state for subexpression | ||
| // elimination. | ||
| val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) | ||
| lazy val commonExprVals = commonExprs.map(_.head.genCode(this)) | ||
|
|
||
| lazy val nonSplitExprCode = { | ||
| commonExprs.zip(commonExprVals).map { case (exprs, eval) => | ||
| // Generate the code for this expression tree. | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| exprs.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| val nonSplitExprCode = { | ||
| commonExprs.map { exprs => | ||
| val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) { | ||
| val eval = exprs.head.genCode(this) | ||
| // Generate the code for this expression tree. | ||
| val state = SubExprEliminationState(eval.isNull, eval.value) | ||
| exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state)) | ||
| Seq(eval) | ||
| }.head | ||
| eval.code.toString | ||
| } | ||
| } | ||
|
|
@@ -1068,11 +1072,19 @@ class CodegenContext extends Logging { | |
| }.unzip | ||
|
|
||
| val splitThreshold = SQLConf.get.methodSplitThreshold | ||
| val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) { | ||
|
|
||
| val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) { | ||
| if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { | ||
| commonExprs.zipWithIndex.map { case (exprs, i) => | ||
| val localSubExprEliminationExprs = | ||
| mutable.HashMap.empty[Expression, SubExprEliminationState] | ||
|
|
||
| val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) => | ||
| val expr = exprs.head | ||
| val eval = commonExprVals(i) | ||
| val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) { | ||
| Seq(expr.genCode(this)) | ||
| }.head | ||
|
|
||
| val value = addMutableState(javaType(expr.dataType), "subExprValue") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does this have to be a mutable state now?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use the example in the description to explain. For the two subexpressions:
Previously we evaluate them independently, i.e., Now we remove redundant evaluation of nested subexpressions: If we need to split the functions, when we evaluate To add it to parameter list will complicate the way we compute parameter length. That's said we need to link nested subexpression relations and get the correct parameters. Seems to me it is not worth doing that. Currently I choose the simpler approach that is to use mutable state. |
||
|
|
||
| val isNullLiteral = eval.isNull match { | ||
| case TrueLiteral | FalseLiteral => true | ||
|
|
@@ -1090,34 +1102,33 @@ class CodegenContext extends Logging { | |
| val inputVars = inputVarsForAllFuncs(i) | ||
| val argList = | ||
| inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}") | ||
| val returnType = javaType(expr.dataType) | ||
| val fn = | ||
| s""" | ||
| |private $returnType $fnName(${argList.mkString(", ")}) { | ||
| |private void $fnName(${argList.mkString(", ")}) { | ||
| | ${eval.code} | ||
| | $isNullEvalCode | ||
| | return ${eval.value}; | ||
| | $value = ${eval.value}; | ||
| |} | ||
| """.stripMargin | ||
|
|
||
| val value = freshName("subExprValue") | ||
| val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType)) | ||
| val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType)) | ||
| exprs.foreach(localSubExprEliminationExprs.put(_, state)) | ||
| val inputVariables = inputVars.map(_.variableName).mkString(", ") | ||
| s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);" | ||
| s"${addNewFunction(fnName, fn)}($inputVariables);" | ||
| } | ||
| (splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate) | ||
| } else { | ||
| if (Utils.isTesting) { | ||
| throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH) | ||
| } else { | ||
| logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH)) | ||
| nonSplitExprCode | ||
| (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) | ||
| } | ||
| } | ||
| } else { | ||
| nonSplitExprCode | ||
| (nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty) | ||
| } | ||
| SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten) | ||
| SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2882,6 +2882,31 @@ class DataFrameSuite extends QueryTest | |
| df2.collect() | ||
| assert(accum.value == 15) | ||
| } | ||
|
|
||
| test("SPARK-35560: Remove redundant subexpression evaluation in nested subexpressions") { | ||
| Seq(1, Int.MaxValue).foreach { splitThreshold => | ||
| withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> splitThreshold.toString) { | ||
| val accum = sparkContext.longAccumulator("call") | ||
| val simpleUDF = udf((s: String) => { | ||
| accum.add(1) | ||
| s | ||
| }) | ||
|
|
||
| // Common exprs: | ||
| // 1. simpleUDF($"id") | ||
| // 2. functions.length(simpleUDF($"id")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: What if a tree has more deeply-nested common exprs? The current logic can work well? e.g., I thought it like this;
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is actually the cases this logic to deal with. Previous common expression gen-ed codes will put into the map. The code generator looks up into the map when generating code for later common expressions to replace the semantic-equal expression with gen-ed code value.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice |
||
| // We should only evaluate `simpleUDF($"id")` once, i.e. | ||
| // subExpr1 = simpleUDF($"id"); | ||
| // subExpr2 = functions.length(subExpr1); | ||
| val df = spark.range(5).select( | ||
| when(functions.length(simpleUDF($"id")) === 1, lower(simpleUDF($"id"))) | ||
| .when(functions.length(simpleUDF($"id")) === 0, upper(simpleUDF($"id"))) | ||
| .otherwise(simpleUDF($"id")).as("output")) | ||
| df.collect() | ||
| assert(accum.value == 5) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| case class GroupByKey(a: Int, b: Int) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does this recursive call of
withSubExprEliminationExprsgive us?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For each set of common expressions,
withSubExprEliminationExprsonly called once so I think it is not actually a recursive call?withSubExprEliminationExprstakes the given map used for subexpression elimination to replace common expression during expression codegen in the closure. It returns evaluated expression code (value/isNull/code).For the two subexpressions as example:
simpleUDF($"id")functions.length(simpleUDF($"id"))1st round
withSubExprEliminationExprs:The map is empty.
Gen code for
simpleUDF($"id").Put it into the map => (
simpleUDF($"id")-> gen-ed code)2nd round
withSubExprEliminationExprs:Gen code for
functions.length(simpleUDF($"id")).Looking at the map and replace common expression
simpleUDF($"id")as gen-ed code.Put it into the map => (
simpleUDF($"id")-> gen-ed code,functions.length(simpleUDF($"id"))-> gen-ed code)The map will be used later for subexpression elimination.