Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ abstract class Expression extends TreeNode[Expression] {
ctx.subExprEliminationExprs.get(this).map { subExprState =>
// This expression is repeated which means that the code to evaluate it has already been added
// as a function before. In that case, we just re-use it.
ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value)
ExprCode(
ctx.registerComment(this.toString),
subExprState.eval.isNull,
subExprState.eval.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,38 @@ object ExprCode {
/**
* State used for subexpression elimination.
*
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param value A term for a value of a common sub-expression. Not valid if `isNull`
* is set to `true`.
* @param eval The source code for evaluating the subexpression.
* @param children The sequence of subexpressions as the children expressions. Before
* evaluating this subexpression, we should evaluate all children
* subexpressions first. This is used if we want to selectively evaluate
* particular subexpressions, instead of all at once. In the case, we need
* to make sure we evaluate all children subexpressions too.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your previous PR improves EquivalentExpressions to always return child subexpression first. It seems that PR is not useful after this PR because we track the children explicitly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly. We need to return child subexpressions first. So we can make sure child subexpression is codegen-ed and put into the map before parent subexpression. When we want to codegen parent subexpression, it can look up the child subexpression and put it as child of the parent.

Copy link
Member Author

@viirya viirya Jun 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I have a new idea for how to codegen subexpression following child-parent orders without sorting. It is more reliable than the sorting approach. I will open another PR for that.

*/
case class SubExprEliminationState(isNull: ExprValue, value: ExprValue)
case class SubExprEliminationState(
eval: ExprCode,
children: Seq[SubExprEliminationState])

object SubExprEliminationState {
def apply(eval: ExprCode): SubExprEliminationState = {
new SubExprEliminationState(eval, Seq.empty)
}

def apply(
eval: ExprCode,
children: Seq[SubExprEliminationState]): SubExprEliminationState = {
new SubExprEliminationState(eval, children.reverse)
}
}

/**
* Codes and common subexpressions mapping used for subexpression elimination.
*
* @param codes Strings representing the codes that evaluate common subexpressions.
* @param states Foreach expression that is participating in subexpression elimination,
* the state to use.
* @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before
* calling common subexpressions.
*/
case class SubExprCodes(
codes: Seq[String],
states: Map[Expression, SubExprEliminationState],
exprCodesNeedEvaluate: Seq[ExprCode])

Expand Down Expand Up @@ -1030,11 +1044,55 @@ class CodegenContext extends Logging {
}

/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
* common subexpressions, generates the code snippets that evaluate those expressions and
* populates the mapping of common subexpressions to the generated code snippets. The generated
* code snippets will be returned and should be inserted into generated codes before these
* common subexpressions actually are used first time.
* Evaluates a sequence of `SubExprEliminationState` which represent subexpressions. After
* evaluating a subexpression, this method will clean up the code block to avoid duplicate
* evaluation.
*/
def evaluateSubExprEliminationState(subExprStates: Iterable[SubExprEliminationState]): String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Iterable -> Seq?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All its caller side use Iterable. If changing to Seq here, all callers need to add .toSeq.

val code = new StringBuilder()

subExprStates.foreach { state =>
val currentCode = evaluateSubExprEliminationState(state.children) + "\n" + state.eval.code
code.append(currentCode + "\n")
state.eval.code = EmptyBlock
}

code.toString()
}

/**
* Checks and sets up the state and codegen for subexpression elimination in whole-stage codegen.
*
* This finds the common subexpressions, generates the code snippets that evaluate those
* expressions and populates the mapping of common subexpressions to the generated code snippets.
*
* The generated code snippet for subexpression is wrapped in `SubExprEliminationState`, which
* contains an `ExprCode` and the children `SubExprEliminationState` if any. The `ExprCode`
* includes java source code, result variable name and is-null variable name of the subexpression.
*
* Besides, this also returns a sequences of `ExprCode` which are expression codes that need to
* be evaluated (as their input parameters) before evaluating subexpressions.
*
* To evaluate the returned subexpressions, please call `evaluateSubExprEliminationState` with
* the `SubExprEliminationState`s to be evaluated. During generating the code, it will cleanup
* the states to avoid duplicate evaluation.
*
* The details of subexpression generation:
* 1. Gets subexpression set. See `EquivalentExpressions`.
* 2. Generate code of subexpressions as a whole block of code (non-split case)
* 3. Check if the total length of the above block is larger than the split-threshold. If so,
* try to split it in step 4, otherwise returning the non-split code block.
* 4. Check if parameter lengths of all subexpressions satisfy the JVM limitation, if so,
* try to split, otherwise returning the non-split code block.
* 5. For each subexpression, generating a function and put the code into it. To evaluate the
* subexpression, just call the function.
*
* The explanation of subexpression codegen:
* 1. Wrapping in `withSubExprEliminationExprs` call with current subexpression map. Each
* subexpression may depends on other subexpressions (children). So when generating code
* for subexpressions, we iterate over each subexpression and put the mapping between
* (subexpression -> `SubExprEliminationState`) into the map. So in next subexpression
* evaluation, we can look for generated subexpressions and do replacement.
*/
def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
// Create a clear EquivalentExpressions and SubExprEliminationState mapping
Expand All @@ -1049,17 +1107,25 @@ class CodegenContext extends Logging {
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)

val nonSplitExprCode = {
val nonSplitCode = {
val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
commonExprs.map { exprs =>
val eval = withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
val eval = exprs.head.genCode(this)
// Generate the code for this expression tree.
val state = SubExprEliminationState(eval.isNull, eval.value)
// Collects other subexpressions from the children.
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
exprs.head.foreach {
case e if subExprEliminationExprs.contains(e) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add some comments to explain the assumption: this code works because EquivalentExpressions returns child expressions first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW collecting child expressions here looks really inefficient, but I don't have a better idea for now ...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. This is not general expression but special (subexpr) ones, so we don't do collecting child expressions in general but in limited range. Except that if you have many subexpr and they are highly nested.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add some comments to explain the assumption: this code works because EquivalentExpressions returns child expressions first.

As I commented before, I plan to remove the sorting. A better idea is to add SubExprEliminationState first into the map (not codegen yet). Then during codegen, we can look at the map to chain children.

childrenSubExprs += subExprEliminationExprs(e)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Is it difficult to add some tests for this new behaviour?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me add a few tests.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new test.

case _ =>
}
val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
allStates += state
Seq(eval)
}.head
eval.code.toString
}
}
allStates.toSeq
}

// For some operators, they do not require all its child's outputs to be evaluated in advance.
Expand All @@ -1071,14 +1137,13 @@ class CodegenContext extends Logging {
(inputVars.toSeq, exprCodes.toSeq)
}.unzip

val splitThreshold = SQLConf.get.methodSplitThreshold

val (codes, subExprsMap, exprCodes) = if (nonSplitExprCode.map(_.length).sum > splitThreshold) {
val needSplit = nonSplitCode.map(_.eval.code.length).sum > SQLConf.get.methodSplitThreshold
val (subExprsMap, exprCodes) = if (needSplit) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR: so here we repeat the logic of generating SubExprEliminationStates with splitting the code? nonSplitCode is totally wasted?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously it is lazy so we can do non-split conditionally. Now we nestedly generate subExprs so it cannot be lazy now. SubExprEliminationStates are needed to nestedly generate code for them.

if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
val localSubExprEliminationExprs =
mutable.HashMap.empty[Expression, SubExprEliminationState]

val splitCodes = commonExprs.zipWithIndex.map { case (exprs, i) =>
commonExprs.zipWithIndex.foreach { case (exprs, i) =>
val expr = exprs.head
val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
Seq(expr.genCode(this))
Expand Down Expand Up @@ -1111,24 +1176,34 @@ class CodegenContext extends Logging {
|}
""".stripMargin

val state = SubExprEliminationState(isNull, JavaCode.global(value, expr.dataType))
exprs.foreach(localSubExprEliminationExprs.put(_, state))
// Collects other subexpressions from the children.
val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
exprs.head.foreach {
case e if localSubExprEliminationExprs.contains(e) =>
childrenSubExprs += localSubExprEliminationExprs(e)
case _ =>
}

val inputVariables = inputVars.map(_.variableName).mkString(", ")
s"${addNewFunction(fnName, fn)}($inputVariables);"
val code = code"${addNewFunction(fnName, fn)}($inputVariables);"
val state = SubExprEliminationState(
ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
childrenSubExprs.toSeq)
exprs.foreach(localSubExprEliminationExprs.put(_, state))
}
(splitCodes, localSubExprEliminationExprs, exprCodesNeedEvaluate)
(localSubExprEliminationExprs, exprCodesNeedEvaluate)
} else {
if (Utils.isTesting) {
throw QueryExecutionErrors.failedSplitSubExpressionError(MAX_JVM_METHOD_PARAMS_LENGTH)
} else {
logInfo(QueryExecutionErrors.failedSplitSubExpressionMsg(MAX_JVM_METHOD_PARAMS_LENGTH))
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
(localSubExprEliminationExprsForNonSplit, Seq.empty)
}
}
} else {
(nonSplitExprCode, localSubExprEliminationExprsForNonSplit, Seq.empty)
(localSubExprEliminationExprsForNonSplit, Seq.empty)
}
SubExprCodes(codes, subExprsMap.toMap, exprCodes.flatten)
SubExprCodes(subExprsMap.toMap, exprCodes.flatten)
}

/**
Expand Down Expand Up @@ -1174,10 +1249,12 @@ class CodegenContext extends Logging {
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

val subExprCode = s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we use subexprFunctions += subExprCode here? otherwise we are calling addNewFunction twice.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, as the functions in class is a map, it will overwrite. But yes, we should use subExprCode. Let me submit a followup.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val state = SubExprEliminationState(
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType))
ExprCode(code"$subExprCode",
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType)))
subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
Expand Down Expand Up @@ -1776,9 +1853,8 @@ object CodeGenerator extends Logging {
while (stack.nonEmpty) {
stack.pop() match {
case e if subExprs.contains(e) =>
val SubExprEliminationState(isNull, value) = subExprs(e)
collectLocalVariable(value)
collectLocalVariable(isNull)
collectLocalVariable(subExprs(e).eval.value)
collectLocalVariable(subExprs(e).eval.isNull)

case ref: BoundReference if ctx.currentVars != null &&
ctx.currentVars(ref.ordinal) != null =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,15 +463,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
val add1 = Add(ref, ref)
val add2 = Add(add1, add1)
val dummy = SubExprEliminationState(
JavaCode.variable("dummy", BooleanType),
JavaCode.variable("dummy", BooleanType))
ExprCode(EmptyBlock,
JavaCode.variable("dummy", BooleanType),
JavaCode.variable("dummy", BooleanType)))

// raw testing of basic functionality
{
val ctx = new CodegenContext
val e = ref.genCode(ctx)
// before
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
ctx.subExprEliminationExprs += ref -> SubExprEliminationState(
ExprCode(EmptyBlock, e.isNull, e.value))
assert(ctx.subExprEliminationExprs.contains(ref))
// call withSubExprEliminationExprs
ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx))
}
val subExprsCode = subExprs.codes.mkString("\n")
val subExprsCode = ctx.evaluateSubExprEliminationState(subExprs.states.values)

val codeBody = s"""
public java.lang.Object generate(Object[] references) {
Expand Down Expand Up @@ -392,6 +392,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
}

test("SPARK-35829: SubExprEliminationState keeps children sub exprs") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(add1, add1)

val exprs = Seq(add1, add1, add2, add2)
val ctx = new CodegenContext()
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)

val add2State = subExprs.states(add2)
val add1State = subExprs.states(add1)
assert(add2State.children.contains(add1State))

subExprs.states.values.foreach { state =>
assert(state.eval.code != EmptyBlock)
}
ctx.evaluateSubExprEliminationState(subExprs.states.values)
subExprs.states.values.foreach { state =>
assert(state.eval.code == EmptyBlock)
}
}

test("SPARK-35886: PromotePrecision should not overwrite genCode") {
val p = PromotePrecision(Literal(Decimal("10.1")))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,9 @@ case class HashAggregateExec(
aggBufferUpdatingExprs: Seq[Seq[Expression]],
aggCodeBlocks: Seq[Block],
subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
s.eval.value :: s.eval.isNull :: Nil
}
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
// we give up splitting functions if it exists in `subExprs`.
Expand Down Expand Up @@ -363,7 +365,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
Expand Down Expand Up @@ -989,7 +991,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
Expand Down Expand Up @@ -1035,7 +1037,7 @@ case class HashAggregateExec(
bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
ctx.withSubExprEliminationExprs(subExprs.states) {
boundUpdateExprsForOneFunc.map(_.genCode(ctx))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
exprs.map(_.genCode(ctx))
}
(subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate)
(ctx.evaluateSubExprEliminationState(subExprs.states.values), genVars,
subExprs.exprCodesNeedEvaluate)
} else {
("", exprs.map(_.genCode(ctx)), Seq.empty)
}
Expand Down