Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and subsequently queried for expression equality. Expression trees are
* considered equal if for the same input(s), the same result is produced.
*/
class EquivalentExpressions {
class EquivalentExpressions(
skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {

// For each expression, the set of equivalent expressions.
private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]

Expand Down Expand Up @@ -129,13 +132,27 @@ class EquivalentExpressions {
}
}

/**
 * Narrows a short-circuiting boolean expression to the operand that is kept for
 * subexpression elimination.
 *
 * When `skipForShortcutEnable` is set, an `And` or `Or` is replaced by its left operand; any
 * other expression, and every expression when the flag is off, is returned unchanged.
 *
 * Rationale: a subexpression may not need to be evaluated even if it appears more than once.
 * e.g., in `if(or(a, and(b, b)))` the expression `b` would be skipped if `a` is true.
 *
 * @param expr the expression to narrow
 * @return the left operand of an `And`/`Or` when the flag is on, otherwise `expr` itself
 */
private def skipForShortcut(expr: Expression): Expression = expr match {
  case conjunction: And if skipForShortcutEnable => conjunction.left
  case disjunction: Or if skipForShortcutEnable => disjunction.left
  case _ => expr
}

// There are some special expressions that we should not recurse into all of their children.
//   1. CodegenFallback: its children will not be used to generate code (eval() is called
//      instead), so nothing is returned for it.
//   2. ConditionalExpression: only the inputs that are always evaluated are used, each one
//      narrowed by `skipForShortcut`.
//   3. Any other expression: the children of `skipForShortcut(expr)`.
// The superseded `ConditionalExpression`/`other` cases that did not call `skipForShortcut`
// were removed: they matched first and made the shortcut-aware cases unreachable.
// NOTE(review): in case 3, when `expr` is an `And`/`Or` and the shortcut flag is on, the result
// is the children of the left operand, so the left operand itself is not part of the returned
// sequence - confirm that this is intended.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
  case _: CodegenFallback => Nil
  case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
  case other => skipForShortcut(other).children
}

// For some special expressions we cannot just recurse into all of its children, but we can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,16 @@ object SQLConf {
.checkValue(_ >= 0, "The maximum must not be negative")
.createWithDefault(100)

// Internal boolean flag, disabled by default. It is read through
// `SQLConf.subexpressionEliminationSkipForShotcutExpr`, which supplies the default value of
// the `skipForShortcutEnable` constructor argument of `EquivalentExpressions`.
val SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR =
buildConf("spark.sql.subexpressionElimination.skipForShortcutExpr")
.internal()
.doc("When true, shortcut eliminate subexpression with `AND`, `OR`. " +
"The subexpression may not need to eval even if it appears more than once. " +
"e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true.")
.version("3.5.0")
.booleanConf
.createWithDefault(false)

val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
.internal()
.doc("Whether the query analyzer should be case sensitive or not. " +
Expand Down Expand Up @@ -4610,6 +4620,9 @@ class SQLConf extends Serializable with Logging {
// Current value of the SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES setting.
def subexpressionEliminationCacheMaxEntries: Int =
getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES)

// Current value of the SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR setting.
def subexpressionEliminationSkipForShortcutExpr: Boolean =
  getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR)

// Misspelled name ("Shotcut") kept as an alias so that existing callers still compile.
// Prefer `subexpressionEliminationSkipForShortcutExpr` in new code.
def subexpressionEliminationSkipForShotcutExpr: Boolean =
  subexpressionEliminationSkipForShortcutExpr

// Current value of the AUTO_BROADCASTJOIN_THRESHOLD setting.
def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

// Current value of the LIMIT_INITIAL_NUM_PARTITIONS setting.
def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions

import java.util.Properties

import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -424,7 +426,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") {
try {
// suppose we are in executor
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null, cpus = 0)
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, new Properties, null, cpus = 0)
TaskContext.setTaskContext(context1)

val equivalence = new EquivalentExpressions
Expand Down Expand Up @@ -465,6 +467,33 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val cseState = equivalence.getExprState(expr)
assert(hasMatching == cseState.isDefined)
}

test("SPARK-42815: Subexpression elimination support shortcut conditional expression") {
  // `add` occurs twice inside `equal`, so it is the only candidate common subexpression.
  val add = Add(Literal(1), Literal(0))
  val equal = EqualTo(add, add)

  // Checks the number of common subexpressions found with the shortcut flag enabled, once with
  // `expr` used as the predicate of an `If` and once with `expr` as the root itself.
  def checkShortcut(expr: Expression, numCommonExpr: Int): Unit = {
    val roots = Seq[Expression](If(expr, Literal(1), Literal(2)), expr)
    roots.foreach { root =>
      val equivalence = new EquivalentExpressions(skipForShortcutEnable = true)
      equivalence.addExprTree(root)
      assert(equivalence.getCommonSubexpressions.size == numCommonExpr)
    }
  }

  // `equal` is the right operand: it is skipped, so nothing is reported as common.
  checkShortcut(And(Literal(false), equal), 0)
  checkShortcut(Or(Literal(true), equal), 0)
  checkShortcut(Not(And(Literal(true), equal)), 0)

  // `equal` is the left operand: its repeated `add` is still reported.
  checkShortcut(And(equal, Literal(false)), 1)
  checkShortcut(Or(equal, Literal(true)), 1)
  checkShortcut(Not(And(equal, Literal(false))), 1)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down