From c4cb28f6b83ed0791278dbcefc24b3bce5fb0ed1 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Fri, 17 Mar 2023 18:15:25 +0800 Subject: [PATCH 1/2] Subexpression elimination support shortcut expression --- .../expressions/EquivalentExpressions.scala | 23 ++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 13 ++++++++ .../SubexpressionEliminationSuite.scala | 31 ++++++++++++++++++- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f47391c04929..aea91c0aef87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -23,6 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -30,7 +31,9 @@ import org.apache.spark.util.Utils * to this class and they subsequently query for expression equality. Expression trees are * considered equal if for the same input(s), the same result is produced. */ -class EquivalentExpressions { +class EquivalentExpressions( + skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) { + // For each expression, the set of equivalent expressions. private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] @@ -129,13 +132,27 @@ class EquivalentExpressions { } } + private def skipForShortcut(expr: Expression): Expression = { + if (skipForShortcutEnable) { + // The subexpression may not need to eval even if it appears more than once. + // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true. + expr match { + case and: And => skipForShortcut(and.left) + case or: Or => skipForShortcut(or.left) + case other => other + } + } else { + expr + } + } + // There are some special expressions that we should not recurse into all of its children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ConditionalExpression: use its children that will always be evaluated. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil - case c: ConditionalExpression => c.alwaysEvaluatedInputs - case other => other.children + case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut) + case other => skipForShortcut(other).children } // For some special expressions we cannot just recurse into all of its children, but we can diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d369eaf45072..ae69e4cf6985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -864,6 +864,16 @@ object SQLConf { .checkValue(_ >= 0, "The maximum must not be negative") .createWithDefault(100) + val SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR = + buildConf("spark.sql.subexpressionElimination.skipForShortcutExpr") + .internal() + .doc("When true, shortcut eliminate subexpression with `AND`, `OR`. " + + "The subexpression may not need to eval even if it appears more than once. " + + "e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true.") + .version("3.5.0") + .booleanConf + .createWithDefault(false) + val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive") .internal() .doc("Whether the query analyzer should be case sensitive or not. " + @@ -4610,6 +4620,9 @@ class SQLConf extends Serializable with Logging { def subexpressionEliminationCacheMaxEntries: Int = getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES) + def subexpressionEliminationSkipForShotcutExpr: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 44d8ea3a112e..f369635a3267 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.expressions +import java.util.Properties + import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -424,7 +426,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") { try { // suppose we are in executor - val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null, cpus = 0) + val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, new Properties, null, cpus = 0) TaskContext.setTaskContext(context1) val equivalence = new EquivalentExpressions @@ -465,6 +467,33 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel val cseState = equivalence.getExprState(expr) assert(hasMatching == cseState.isDefined) } + + test("SPARK-42815: Subexpression elimination support shortcut conditional expression") { + val add = Add(Literal(1), Literal(0)) + val equal = EqualTo(add, add) + + def checkShortcut(expr: Expression, numCommonExpr: Int): Unit = { + val e1 = If(expr, Literal(1), Literal(2)) + val ee1 = new EquivalentExpressions(true) + ee1.addExprTree(e1) + assert(ee1.getCommonSubexpressions.size == numCommonExpr) + + val e2 = expr + val ee2 = new EquivalentExpressions(true) + ee2.addExprTree(e2) + assert(ee2.getCommonSubexpressions.size == numCommonExpr) + } + + // shortcut right child + checkShortcut(And(Literal(false), equal), 0) + checkShortcut(Or(Literal(true), equal), 0) + checkShortcut(Not(And(Literal(true), equal)), 0) + + // always eliminate subexpression for left child + checkShortcut((And(equal, Literal(false))), 1) + checkShortcut(Or(equal, Literal(true)), 1) + checkShortcut(Not(And(equal, Literal(false))), 1) + } } case class CodegenFallbackExpression(child: Expression) From 3a714523c5a3f128a0663a877d4476a84ec75183 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 21 Mar 2023 19:41:31 +0800 Subject: [PATCH 2/2] address comments --- .../sql/catalyst/expressions/EquivalentExpressions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index aea91c0aef87..1a84859cc3a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -137,8 +137,8 @@ class EquivalentExpressions( // The subexpression may not need to eval even if it appears more than once. // e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true. expr match { - case and: And => skipForShortcut(and.left) - case or: Or => skipForShortcut(or.left) + case and: And => and.left + case or: Or => or.left case other => other } } else {