@@ -1039,21 +1039,25 @@ class CodegenContext extends Logging {
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
val localSubExprEliminationExprsForNonSplit =
mutable.HashMap.empty[Expression, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
lazy val commonExprVals = commonExprs.map(_.head.genCode(this))

lazy val nonSplitExprCode = {
commonExprs.zip(commonExprVals).map { case (exprs, eval) =>
// Generate the code for this expression tree.
val state = SubExprEliminationState(eval.isNull, eval.value)
exprs.foreach(localSubExprEliminationExprs.put(_, state))
val nonSplitExprCode = {
commonExprs.map { exprs =>
val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
Contributor

what does this recursive call of withSubExprEliminationExprs give us?

Member Author

withSubExprEliminationExprs is called only once for each set of common expressions, so I don't think it is actually a recursive call.

withSubExprEliminationExprs takes a map of subexpression elimination states and uses it to replace common expressions while generating code for the expressions in the closure. It returns the evaluated expression code (value/isNull/code).

Take these two subexpressions as an example:

  1. simpleUDF($"id")
  2. functions.length(simpleUDF($"id"))

1st round of withSubExprEliminationExprs:

The map is empty.
Generate code for simpleUDF($"id").
Put it into the map => (simpleUDF($"id") -> generated code)

2nd round of withSubExprEliminationExprs:

Generate code for functions.length(simpleUDF($"id")).
Look up the map and replace the common expression simpleUDF($"id") with its generated code.
Put it into the map => (simpleUDF($"id") -> generated code, functions.length(simpleUDF($"id")) -> generated code)

The map is used later for subexpression elimination.
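To make the two rounds concrete, here is a minimal standalone sketch. It does not use Spark: `Expr`, `Col`, `Call`, `genCode` and `eliminate` are hypothetical stand-ins for Catalyst's expression tree and for the codegen loop, and the map holds a variable name instead of a SubExprEliminationState.

```scala
import scala.collection.mutable

object NestedSubExprSketch {
  // Toy expression tree; a hypothetical stand-in for Catalyst's Expression.
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class Call(fn: String, arg: Expr) extends Expr

  // Generate a Java-like expression string, substituting any subtree
  // that already has an entry in `state` with its variable name.
  def genCode(e: Expr, state: Map[Expr, String]): String = state.get(e) match {
    case Some(variable) => variable
    case None => e match {
      case Col(name) => name
      case Call(fn, arg) => s"$fn(${genCode(arg, state)})"
    }
  }

  // One round per common expression: generate code with the current map,
  // then record the result so later rounds can reuse it.
  def eliminate(commonExprs: Seq[Expr]): Seq[String] = {
    val state = mutable.HashMap.empty[Expr, String]
    commonExprs.zipWithIndex.map { case (expr, i) =>
      val variable = s"subExpr${i + 1}"
      val code = s"$variable = ${genCode(expr, state.toMap)};"
      state.put(expr, variable)
      code
    }
  }

  val udf: Expr = Call("simpleUDF", Col("id"))
  val len: Expr = Call("length", udf)

  // Round 1 sees an empty map; round 2 finds simpleUDF(id) in the map.
  val lines: Seq[String] = eliminate(Seq(udf, len))
}
```

In this sketch the second line refers to `subExpr1` rather than repeating the `simpleUDF` call, which is the effect the map lookup is meant to have. It assumes the inner expression is processed before the outer one.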

val eval = exprs.head.genCode(this)
// Generate the code for this expression tree.
val state = SubExprEliminationState(eval.isNull, eval.value)
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
Seq(eval)
}.head
eval.code.toString
}
}
@@ -1068,11 +1072,19 @@ class CodegenContext extends Logging {
}.unzip

val splitThreshold = SQLConf.get.methodSplitThreshold
val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) {

val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) {
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
commonExprs.zipWithIndex.map { case (exprs, i) =>
val localSubExprEliminationExprs =
mutable.HashMap.empty[Expression, SubExprEliminationState]

val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) =>
val expr = exprs.head
val eval = commonExprVals(i)
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
Seq(expr.genCode(this))
}.head

val value = addMutableState(javaType(expr.dataType), "subExprValue")
Contributor

why does this have to be a mutable state now?

Member Author

Let me use the example from the description to explain. For the two subexpressions:

  1. simpleUDF($"id")
  2. functions.length(simpleUDF($"id"))

Previously we evaluated them independently, i.e.,

String subExpr1 = simpleUDF($"id");
Int subExpr2 = functions.length(simpleUDF($"id"));

Now we remove redundant evaluation of nested subexpressions:

String subExpr1 = simpleUDF($"id");
Int subExpr2 = functions.length(subExpr1);

If we need to split the functions, the split function that evaluates functions.length needs access to subExpr1. We have two choices. One is to add subExpr1 to the parameter list of the split function for functions.length. The other is to use mutable state.

Adding it to the parameter list would complicate how we compute the parameter length. That is, we would need to track the relations between nested subexpressions to get the correct parameters. That does not seem worth doing to me.

For now I chose the simpler approach, which is to use mutable state.
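As a rough illustration of the mutable-state option, the sketch below emits the shape of Java source I have in mind. It is standalone and not the actual CodegenContext output: `SubExpr`, `fieldDecl`, `splitFunction`, the field names and the Java types are all hypothetical.

```scala
object SplitFunctionSketch {
  // Hypothetical description of one common subexpression after codegen.
  case class SubExpr(javaType: String, field: String, body: String)

  // Class-level field that holds the result (the "mutable state").
  def fieldDecl(e: SubExpr): String = s"private ${e.javaType} ${e.field};"

  // A void split function that stores its result in the field instead of returning it.
  def splitFunction(fnName: String, inputs: Seq[String], e: SubExpr): String =
    s"""private void $fnName(${inputs.mkString(", ")}) {
       |  ${e.field} = ${e.body};
       |}""".stripMargin

  val sub1 = SubExpr("UTF8String", "subExprValue_0", "simpleUDF(id)")
  // The nested subexpression reads the field written by the first function.
  val sub2 = SubExpr("int", "subExprValue_1", "length(subExprValue_0)")

  val generated: String = Seq(
    fieldDecl(sub1),
    fieldDecl(sub2),
    splitFunction("subExpr_0", Seq("long id"), sub1),
    splitFunction("subExpr_1", Seq("long id"), sub2)
  ).mkString("\n")
}
```

Note that the parameter list of `subExpr_1` is unchanged in this sketch: it reaches the earlier result through the field, so the parameter-length computation does not need to know about nested subexpressions.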


val isNullLiteral = eval.isNull match {
case TrueLiteral | FalseLiteral => true
@@ -1090,34 +1102,33 @@ class CodegenContext extends Logging {
val inputVars = inputVarsForAllFuncs(i)
val argList =
inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}")
val returnType = javaType(expr.dataType)
val fn =
s"""
|private $returnType $fnName(${argList.mkString(", ")}) {
|private void $fnName(${argList.mkString(", ")}) {
| ${eval.code}
| $isNullEvalCode
| return ${eval.value};
| $value = ${eval.value};
|}
""".stripMargin

val value = freshName("subExprValue")
val state = SubExprEliminationState(isNull, JavaCode.variable(value, expr.dataType))
val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType))
exprs.foreach(localSubExprEliminationExprs.put(_, state))
val inputVariables = inputVars.map(_.variableName).mkString(", ")
s"$returnType $value = ${addNewFunction(fnName, fn)}($inputVariables);"
s"${addNewFunction(fnName, fn)}($inputVariables);"
}
(splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate)
} else {
if (Utils.isTesting) {
throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH)
} else {
logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH))
nonSplitExprCode
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
}
}
} else {
nonSplitExprCode
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten)
SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten)
}

/**
25 changes: 25 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2882,6 +2882,31 @@ class DataFrameSuite extends QueryTest
df2.collect()
assert(accum.value == 15)
}

test("SPARK-35560: Remove redundant subexpression evaluation in nested subexpressions") {
Seq(1, Int.MaxValue).foreach { splitThreshold =>
withSQLConf(SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> splitThreshold.toString) {
val accum = sparkContext.longAccumulator("call")
val simpleUDF = udf((s: String) => {
accum.add(1)
s
})

// Common exprs:
// 1. simpleUDF($"id")
// 2. functions.length(simpleUDF($"id"))
Member

Q: What if a tree has more deeply nested common exprs? Does the current logic still work well? For example, I am thinking of something like this:

        // subExpr1 = simpleUDF($"id");
        // subExpr2 = functions.length(subExpr1);
        // subExpr3 = functions.xxxx(subExpr2);
        // subExpr4 = ...

Member Author

Yes, this is exactly the case this logic is meant to handle. The generated code for each earlier common expression is put into the map. When generating code for later common expressions, the code generator looks up the map and replaces any semantically equal expression with the value of the already generated code.

Member

Nice
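For the deeper chain asked about above, a small standalone model may help show the difference in the number of evaluations. It is only a sketch: `Expr`, `Leaf`, `Apply`, `eval` and the function names `f1` to `f4` are hypothetical, and a plain mutable map stands in for the subexpression elimination map.

```scala
import scala.collection.mutable

object DeepNestingSketch {
  sealed trait Expr
  case class Leaf(value: String) extends Expr
  case class Apply(fn: String, child: Expr) extends Expr

  // Evaluates `e`, consulting `cache` first; `calls` records every function application.
  def eval(e: Expr, cache: mutable.Map[Expr, String], calls: mutable.Buffer[String]): String =
    cache.getOrElse(e, e match {
      case Leaf(v) => v
      case Apply(fn, child) =>
        val arg = eval(child, cache, calls)
        calls += fn
        s"$fn($arg)"
    })

  // subExpr1 = f1(id), subExpr2 = f2(subExpr1), ..., innermost first.
  val chain: Seq[Expr] =
    Seq("f1", "f2", "f3", "f4").scanLeft(Leaf("id"): Expr)((child, fn) => Apply(fn, child)).tail

  // Shared map: each result is recorded before the next, deeper expression is evaluated.
  def withSharedMap(): Seq[String] = {
    val cache = mutable.Map.empty[Expr, String]
    val calls = mutable.ArrayBuffer.empty[String]
    chain.foreach(e => cache.put(e, eval(e, cache, calls)))
    calls.toSeq
  }

  // Independent evaluation: every expression starts from an empty map.
  def independently(): Seq[String] = {
    val calls = mutable.ArrayBuffer.empty[String]
    chain.foreach(e => eval(e, mutable.Map.empty[Expr, String], calls))
    calls.toSeq
  }
}
```

With the shared map each of the four functions is applied once. Evaluated independently, the innermost `f1` is applied four times and the chain costs ten applications in total.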

// We should only evaluate `simpleUDF($"id")` once, i.e.
// subExpr1 = simpleUDF($"id");
// subExpr2 = functions.length(subExpr1);
val df = spark.range(5).select(
when(functions.length(simpleUDF($"id")) === 1, lower(simpleUDF($"id")))
.when(functions.length(simpleUDF($"id")) === 0, upper(simpleUDF($"id")))
.otherwise(simpleUDF($"id")).as("output"))
df.collect()
assert(accum.value == 5)
}
}
}
}

case class GroupByKey(a: Int, b: Int)