Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,22 @@ class EquivalentExpressions(
private def updateExprInMap(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
useCount: Int = 1): Boolean = {
useCount: Int = 1,
conditional: Boolean = false): Boolean = {
if (expr.deterministic) {
val wrapper = ExpressionEquals(expr)
map.get(wrapper) match {
case Some(stats) =>
stats.useCount += useCount
if (stats.useCount > 0) {
val count = if (conditional) {
stats.conditionalUseCount += useCount
stats.conditionalUseCount
} else {
stats.useCount += useCount
stats.useCount
}
if (count > 0) {
true
} else if (stats.useCount == 0) {
} else if (count == 0) {
map -= wrapper
false
} else {
Expand All @@ -79,7 +86,12 @@ class EquivalentExpressions(
}
case _ =>
if (useCount > 0) {
map.put(wrapper, ExpressionStats(expr)(useCount))
val stats = if (conditional) {
ExpressionStats(expr)(useCount = 0, conditionalUseCount = useCount)
} else {
ExpressionStats(expr)(useCount)
}
map.put(wrapper, stats)
}
false
}
Expand All @@ -89,44 +101,47 @@ class EquivalentExpressions(
}

/**
* Adds or removes only expressions which are common in each of given expressions, in a recursive
* Returns a list of expressions which are common in each of given expressions, in a recursive
* way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
* expression `(c + 1)` will be added into `equivalenceMap`.
* expression `(c + 1)` will be returned.
*
* Note that as we don't know in advance if any child node of an expression will be common across
* all given expressions, we compute local equivalence maps for all given expressions and filter
* only the common nodes.
* Those common nodes are then removed from the local map and added to the final map of
* Those common nodes are then removed from the local map and added to the final list of
* expressions.
*
* Conditional expressions are not considered because we are simply looking for expressions
* evaluated once in each parent expression.
*/
private def updateCommonExprs(
exprs: Seq[Expression],
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we update the doc of this method? no equivalenceMap in this method now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep good call, haven't kept up with with some of the docs

useCount: Int): Unit = {
private def getCommonExprs(exprs: Seq[Expression]): Seq[ExpressionEquals] = {
assert(exprs.length > 1)
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
updateExprTree(exprs.head, localEquivalenceMap)
updateExprTree(exprs.head, localEquivalenceMap, conditionalsEnabled = false)

exprs.tail.foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
updateExprTree(expr, otherLocalEquivalenceMap)
updateExprTree(expr, otherLocalEquivalenceMap, conditionalsEnabled = false)
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
otherLocalEquivalenceMap.contains(key)
}
}

val commonExpressions = mutable.ListBuffer.empty[ExpressionEquals]

// Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
// The remaining highest expression in `localEquivalenceMap` is also common expression so loop
// until `localEquivalenceMap` is not empty.
var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
while (statsOption.nonEmpty) {
val stats = statsOption.get
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
updateExprTree(stats.expr, map, useCount)
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount, conditionalsEnabled = false)
commonExpressions += ExpressionEquals(stats.expr)

statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
}
commonExpressions.toSeq
}

private def skipForShortcut(expr: Expression): Expression = {
Expand All @@ -143,21 +158,17 @@ class EquivalentExpressions(
}
}

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ConditionalExpression: use its children that will always be evaluated.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
case other => skipForShortcut(other).children
}

// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.branchGroups
case _ => Nil
/**
* There are some expressions that need special handling:
* 1. CodegenFallback: It's children will not be used to generate code (call eval() instead).
* 2. ConditionalExpression: use its children that will always be evaluated.
*/
private def childrenToRecurse(expr: Expression): RecurseChildren = expr match {
case _: CodegenFallback => RecurseChildren(Nil)
case c: ConditionalExpression =>
RecurseChildren(c.alwaysEvaluatedInputs.map(skipForShortcut), c.branchGroups,
c.conditionallyEvaluatedInputs)
case other => RecurseChildren(skipForShortcut(other).children)
}

private def supportedExpression(e: Expression): Boolean = {
Expand All @@ -184,13 +195,48 @@ class EquivalentExpressions(
private def updateExprTree(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
useCount: Int = 1): Unit = {
val skip = useCount == 0 || expr.isInstanceOf[LeafExpression]
useCount: Int = 1,
conditionalsEnabled: Boolean = SQLConf.get.subexpressionEliminationConditionalsEnabled,
conditional: Boolean = false,
skipExpressions: Set[ExpressionEquals] = Set.empty[ExpressionEquals]
): Unit = {
val skip = useCount == 0 ||
expr.isInstanceOf[LeafExpression] ||
skipExpressions.contains(ExpressionEquals(expr))

if (!skip && !updateExprInMap(expr, map, useCount)) {
if (!skip && !updateExprInMap(expr, map, useCount, conditional)) {
val uc = useCount.sign
childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
val recurseChildren = childrenToRecurse(expr)
recurseChildren.alwaysChildren.foreach { child =>
updateExprTree(child, map, uc, conditionalsEnabled, conditional, skipExpressions)
}

/**
* If the `commonExpressions` already appears in the equivalence map, calling
* `updateExprTree` will increase the `useCount` and mark it as a common subexpression.
* Otherwise, `updateExprTree` will recursively add `commonExpressions` and its descendant to
* the equivalence map, in case they also appear in other places. For example,
* `If(a + b > 1, a + b + c, a + b + c)`, `a + b` also appears in the condition and should
* be treated as common subexpression.
*/
val commonExpressions = recurseChildren.commonChildren.flatMap { exprs =>
if (exprs.nonEmpty) {
getCommonExprs(exprs)
} else {
Nil
}
}
commonExpressions.foreach { ce =>
updateExprTree(ce.e, map, uc, conditionalsEnabled, conditional, skipExpressions)
}

if (conditionalsEnabled) {
// Add all conditional expressions, skipping those that were already counted as common
// expressions.
recurseChildren.conditionalChildren.foreach { cc =>
updateExprTree(cc, map, uc, true, true, commonExpressions.toSet)
}
}
}
}

Expand All @@ -208,7 +254,7 @@ class EquivalentExpressions(

// Exposed for testing.
private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
equivalenceMap.filter(_._2.useCount > count).toSeq.sortBy(_._1.height).map(_._2)
equivalenceMap.filter(_._2.getUseCount() > count).toSeq.sortBy(_._1.height).map(_._2)
}

/**
Expand All @@ -225,8 +271,11 @@ class EquivalentExpressions(
def debugString(all: Boolean = false): String = {
val sb = new java.lang.StringBuilder()
sb.append("Equivalent expressions:\n")
equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
equivalenceMap.values.filter(stats => all || stats.getUseCount() > 1).foreach { stats =>
sb.append(" ")
.append(s"${stats.expr}: useCount = ${stats.useCount} ")
.append(s"conditionalUseCount = ${stats.conditionalUseCount}")
.append('\n')
}
sb.toString()
}
Expand Down Expand Up @@ -255,4 +304,32 @@ case class ExpressionEquals(e: Expression) {
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
* useCount in this wrapper in-place.
*/
case class ExpressionStats(expr: Expression)(var useCount: Int)
case class ExpressionStats(expr: Expression)(
var useCount: Int = 1,
var conditionalUseCount: Int = 0) {
def getUseCount(): Int = if (useCount > 0) {
useCount + conditionalUseCount
} else {
0
}
}

/**
* A wrapper for the different types of children of expressions.
*
* `alwaysChildren` are child expressions that will always be evaluated and should be considered
* for subexpressions.
*
* `commonChildren` are children such that each of the children might be evaluated, but at last once
* will definitely be evaluated. If there are any common expressions among them, those expressions
* will definitely be evaluated and should be considered for subexpressions.
*
* `conditionalChildren` are children that are conditionally evaluated, such as in If, CaseWhen,
* or Coalesce expressions, and should only be considered for subexpressions if they are evaluated
* non-conditionally elsewhere.
*/
case class RecurseChildren(
alwaysChildren: Seq[Expression],
commonChildren: Seq[Seq[Expression]] = Nil,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have an example this commonChildren?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the docs a little bit to clarify. Currently it's only If and CaseWhen expressions that commonChildren applies too, should I put one of those as an example in the doc?

conditionalChildren: Seq[Expression] = Nil
)
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,12 @@ trait ConditionalExpression extends Expression {
* so that we can eagerly evaluate the common expressions of a group.
*/
def branchGroups: Seq[Seq[Expression]]

/**
* Returns children expressions which are conditionally evaluated. If the same expression
* will be always evaluated elsewhere, we can make it a subexpression.
*/
def conditionallyEvaluatedInputs: Seq[Expression]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi

override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))

override def conditionallyEvaluatedInputs: Seq[Expression] = Seq(trueValue, falseValue)

final override val nodePatterns : Seq[TreePattern] = Seq(IF)

override def checkInputDataTypes(): TypeCheckResult = {
Expand Down Expand Up @@ -241,29 +243,12 @@ case class CaseWhen(
}

override def branchGroups: Seq[Seq[Expression]] = {
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
// if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
// a subexpression among values doesn't need to be in conditions because no matter which
// condition is true, it will be evaluated.
val conditions = if (branches.length > 1) {
branches.map(_._1)
} else {
// If there is only one branch, the first condition is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Nil
}
// For an expression to be in all branch values of a CaseWhen statement, it must also be in
// the elseValue.
val values = if (elseValue.nonEmpty) {
branches.map(_._2) ++ elseValue
} else {
Nil
}

Seq(conditions, values)
// If there's an else value then we will definitely evaluate at least one branch value
if (elseValue.isDefined) Seq(branches.map(_._2) ++ elseValue) else Nil
}

override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail

override def eval(input: InternalRow): Any = {
var i = 0
val size = branches.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,9 @@ case class Coalesce(children: Seq[Expression])
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
// If there is only one child, the first child is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Seq(children)
} else {
Nil
}
override def branchGroups: Seq[Seq[Expression]] = Nil

override def conditionallyEvaluatedInputs: Seq[Expression] = children.tail

override def eval(input: InternalRow): Any = {
var result: Any = null
Expand Down Expand Up @@ -348,7 +344,9 @@ case class NaNvl(left: Expression, right: Expression)
copy(left = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(children)
override def branchGroups: Seq[Seq[Expression]] = Nil

override def conditionallyEvaluatedInputs: Seq[Expression] = right :: Nil

override def eval(input: InternalRow): Any = {
val value = left.eval(input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED =
buildConf("spark.sql.subexpressionElimination.conditionals.enabled")
.internal()
.doc("When true, common conditional subexpressions will be eliminated.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val CASE_SENSITIVE = buildConf(SqlApiConfHelper.CASE_SENSITIVE_KEY)
.internal()
.doc("Whether the query analyzer should be case sensitive or not. " +
Expand Down Expand Up @@ -7314,6 +7322,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def subexpressionEliminationSkipForShotcutExpr: Boolean =
getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR)

def subexpressionEliminationConditionalsEnabled: Boolean =
getConf(SUBEXPRESSION_ELIMINATION_CONDITIONALS_ENABLED)

def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
Expand Down
Loading