Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -408,29 +408,11 @@ class CodegenContext extends Logging {
partitionInitializationStatements.mkString("\n")
}

/**
* Holds expressions that are equivalent. Used to perform subexpression elimination
* during codegen.
*
* For expressions that appear more than once, generate additional code to prevent
* recomputing the value.
*
* For example, consider two expression generated from this SQL statement:
* SELECT (col1 + col2), (col1 + col2) / col3.
*
* equivalentExpressions will match the tree containing `col1 + col2` and it will only
* be evaluated once.
*/
private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions

// Foreach expression that is participating in subexpression elimination, the state to use.
// Visible for testing.
private[expressions] var subExprEliminationExprs =
Map.empty[ExpressionEquals, SubExprEliminationState]

// The collection of sub-expression result resetting methods that need to be called on each row.
private val subexprFunctions = mutable.ArrayBuffer.empty[String]

val outerClassName = "OuterClass"

/**
Expand Down Expand Up @@ -1021,9 +1003,8 @@ class CodegenContext extends Logging {
* Returns the code for subexpression elimination after splitting it if necessary.
*/
def subexprFunctionsCode: String = {
// Whole-stage codegen's subexpression elimination is handled in another code path
assert(currentVars == null || subexprFunctions.isEmpty)
splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW))
val subExprCode = evaluateSubExprEliminationState(subExprEliminationExprs.values)
splitExpressionsWithCurrentInputs(subExprCode, "subexprFunc_split")
}

/**
Expand All @@ -1048,20 +1029,17 @@ class CodegenContext extends Logging {
* evaluating a subexpression, this method will clean up the code block to avoid duplicate
* evaluation.
*/
def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = {
val code = new StringBuilder()

subExprStates.foreach { state =>
val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code
code.append(currentCode + "\n")
def evaluateSubExprEliminationState(
subExprStates: Iterable[SubExprEliminationState]): Seq[String] = {
subExprStates.flatMap { state =>
val currentCode = evaluateSubExprEliminationState(state.children) :+ state.eval.code.toString
state.eval.code = EmptyBlock
}

code.toString()
currentCode
}.toSeq
}

/**
* Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
* Checks and sets up the state and codegen for subexpression elimination.
*
* This finds the common subexpressions, generates the code snippets that evaluate those
* expressions and populates the mapping of common subexpressions to the generated code snippets.
Expand Down Expand Up @@ -1094,10 +1072,10 @@ class CodegenContext extends Logging {
* (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression
* evaluation, we can look for generated subexpressions and do replacement.
*/
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
def subexpressionElimination(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
val localSubExprEliminationExprsForNonSplit =
val localSubExprEliminationExprs =
mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]

// Add each expression tree and compute the common subexpressions.
Expand All @@ -1110,8 +1088,28 @@ class CodegenContext extends Logging {
val nonSplitCode = {
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
commonExprs.map { expr =>
withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
val eval = expr.genCode(this)

val value = addMutableState(javaType(expr.dataType), "subExprValue")

val isNullLiteral = eval.isNull match {
case TrueLiteral | FalseLiteral => true
case _ => false
}
val (isNull, isNullEvalCode) = if (!isNullLiteral) {
val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
(JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};")
} else {
(eval.isNull, "")
}

val code = code"""
|${eval.code}
|$isNullEvalCode
|$value = ${eval.value};
"""

// Collects other subexpressions from the children.
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
expr.foreach { e =>
Expand All @@ -1120,8 +1118,10 @@ class CodegenContext extends Logging {
case _ =>
}
}
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state)
val state = SubExprEliminationState(
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
childrenSubExprs.toSeq)
localSubExprEliminationExprs.put(ExpressionEquals(expr), state)
allStates += state
Seq(eval)
}
Expand All @@ -1141,38 +1141,18 @@ class CodegenContext extends Logging {
val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold
val (subExprsMap, exprCodes) = if (needSplit) {
if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
val localSubExprEliminationExprs =
mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]

commonExprs.zipWithIndex.foreach { case (expr, i) =>
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
Seq(expr.genCode(this))
}.head

val value = addMutableState(javaType(expr.dataType), "subExprValue")

val isNullLiteral = eval.isNull match {
case TrueLiteral | FalseLiteral => true
case _ => false
}
val (isNull, isNullEvalCode) = if (!isNullLiteral) {
val v = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
(JavaCode.isNullGlobal(v), s"$v = ${eval.isNull};")
} else {
(eval.isNull, "")
}

// Generate the code for this expression tree and wrap it in a function.
val fnName = freshName("subExpr")
val inputVars = inputVarsForAllFuncs(i)
val argList =
inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}")
val subExprState = localSubExprEliminationExprs.remove(ExpressionEquals(expr)).get
val fn =
s"""
|private void $fnName(${argList.mkString(", ")}) {
| ${eval.code}
| $isNullEvalCode
| $value = ${eval.value};
| ${subExprState.eval.code}
|}
""".stripMargin

Expand All @@ -1188,7 +1168,7 @@ class CodegenContext extends Logging {
val inputVariables = inputVars.map(_.variableName).mkString(", ")
val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
val state = SubExprEliminationState(
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
subExprState.eval.copy(code = code),
childrenSubExprs.toSeq)
localSubExprEliminationExprs.put(ExpressionEquals(expr), state)
}
Expand All @@ -1201,67 +1181,16 @@ class CodegenContext extends Logging {
throw new IllegalStateException(errMsg)
} else {
logInfo(errMsg)
(localSubExprEliminationExprsForNonSplit, Seq.empty)
(localSubExprEliminationExprs, Seq.empty)
}
}
} else {
(localSubExprEliminationExprsForNonSplit, Seq.empty)
(localSubExprEliminationExprs, Seq.empty)
}
subExprsMap.foreach(subExprEliminationExprs += _)
SubExprCodes(subExprsMap.toMap, exprCodes.flatten)
}

/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getCommonSubexpressions
commonExprs.foreach { expr =>
val fnName = freshName("subExpr")
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
val value = addMutableState(javaType(expr.dataType), "subExprValue")

// Generate the code for this expression tree and wrap it in a function.
val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
""".stripMargin

// Add a state and a mapping of the common subexpressions that are associate with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
//
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
subexprFunctions += subExprCode
val state = SubExprEliminationState(
ExprCode(code"$subExprCode",
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType)))
subExprEliminationExprs += ExpressionEquals(expr) -> state
}
}

/**
* Generates code for expressions. If doSubexpressionElimination is true, subexpression
* elimination will be performed. Subexpression elimination assumes that the code for each
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
ExprCode(TrueLiteral, oneVar),
ExprCode(TrueLiteral, twoVar))

val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx))
}
val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values)
ctx.subexpressionElimination(exprs)
val subExprsCode = ctx.subexprFunctionsCode

val codeBody = s"""
public java.lang.Object generate(Object[] references) {
Expand Down Expand Up @@ -408,7 +405,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel

val exprs = Seq(add1, add1, add2, add2)
val ctx = new CodegenContext()
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
val subExprs = ctx.subexpressionElimination(exprs)

val add2State = subExprs.states(ExpressionEquals(add2))
val add1State = subExprs.states(ExpressionEquals(add1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,14 @@ trait AggregateCodegenSupport
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val subExprs = if (conf.subexpressionEliminationEnabled) {
ctx.subexpressionElimination(boundUpdateExprs.flatten).states
} else {
Map.empty[ExpressionEquals, SubExprEliminationState]
}
val effectiveCodes = ctx.subexprFunctionsCode
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}

val aggNames = functions.map(_.prettyName)
Expand Down Expand Up @@ -256,11 +258,11 @@ trait AggregateCodegenSupport
boundUpdateExprs: Seq[Seq[Expression]],
aggNames: Seq[String],
aggCodeBlocks: Seq[Block],
subExprs: SubExprCodes): String = {
subExprs: Map[ExpressionEquals, SubExprEliminationState]): String = {
val aggCodes = if (conf.codegenSplitAggregateFunc &&
aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
val maybeSplitCodes = splitAggregateExpressions(
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs)

maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -728,12 +728,14 @@ case class HashAggregateExec(
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val subExprs = if (conf.subexpressionEliminationEnabled) {
ctx.subexpressionElimination(boundUpdateExprs.flatten).states
} else {
Map.empty[ExpressionEquals, SubExprEliminationState]
}
val effectiveCodes = ctx.subexprFunctionsCode
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}

val aggCodeBlocks = updateExprs.indices.map { i =>
Expand Down Expand Up @@ -774,12 +776,14 @@ case class HashAggregateExec(
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val subExprs = if (conf.subexpressionEliminationEnabled) {
ctx.subexpressionElimination(boundUpdateExprs.flatten).states
} else {
Map.empty[ExpressionEquals, SubExprEliminationState]
}
val effectiveCodes = ctx.subexprFunctionsCode
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
}

val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case (fastRowEvalsForOneFunc, i) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,13 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val exprs = bindReferences[Expression](projectList, child.output)
val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
// subexpression elimination
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx))
}
(ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars,
subExprs.exprCodesNeedEvaluate)
val localValInputs = if (conf.subexpressionEliminationEnabled) {
ctx.subexpressionElimination(exprs).exprCodesNeedEvaluate
} else {
("", exprs.map(_.genCode(ctx)), Seq.empty)
Seq.empty
}
val resultVars = exprs.map(_.genCode(ctx))
val subExprsCode = ctx.subexprFunctionsCode

// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
Expand Down