diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 72ff9361d8f7..cddb3416b084 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ParameterizedBoundReference}
 import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 
 /**
@@ -40,23 +40,50 @@ class EquivalentExpressions {
     override def hashCode: Int = e.semanticHash()
   }
 
+  /**
+   * Wrapper around an Expression that provides structural semantic equality.
+   */
+  case class StructuralExpr(e: Expression) {
+    def normalized(expr: Expression): Expression = {
+      expr.transformUp {
+        case b: BoundReference =>
+          b.copy(ordinal = -1)
+      }
+    }
+
+    override def equals(o: Any): Boolean = o match {
+      case other: StructuralExpr =>
+        normalized(e).semanticEquals(normalized(other.e))
+      case _ => false
+    }
+
+    override def hashCode: Int = normalized(e).semanticHash()
+  }
+
+  type EquivalenceMap = mutable.HashMap[Expr, mutable.ArrayBuffer[Expression]]
+
   // For each expression, the set of equivalent expressions.
   private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
 
+  // For each expression, the set of structurally equivalent expressions.
+  // Among expressions with the same structure there can be subsets that are semantically
+  // different from each other. Thus, under each key, the value is the map used to do semantic
+  // sub-expression elimination within that structural group.
+  private val structEquivalenceMap = mutable.HashMap.empty[StructuralExpr, EquivalenceMap]
+
   /**
    * Adds each expression to this data structure, grouping them with existing equivalent
    * expressions. Non-recursive.
    * Returns true if there was already a matching expression.
    */
-  def addExpr(expr: Expression): Boolean = {
+  def addExpr(expr: Expression, exprMap: EquivalenceMap = this.equivalenceMap): Boolean = {
     if (expr.deterministic) {
       val e: Expr = Expr(expr)
-      val f = equivalenceMap.get(e)
+      val f = exprMap.get(e)
       if (f.isDefined) {
         f.get += expr
         true
       } else {
-        equivalenceMap.put(e, mutable.ArrayBuffer(expr))
+        exprMap.put(e, mutable.ArrayBuffer(expr))
         false
       }
     } else {
@@ -65,35 +92,97 @@ class EquivalentExpressions {
   }
 
   /**
-   * Adds the expression to this data structure recursively. Stops if a matching expression
-   * is found. That is, if `expr` has already been added, its children are not added.
+   * Adds the expression to the structural expression data structure, grouping it with existing
+   * structurally equivalent expressions. Non-recursive. Returns false if the expression was not
+   * actually added.
+   */
+  def addStructExpr(expr: Expression): Boolean = {
+    if (expr.deterministic) {
+      // For structurally equivalent expressions, we need to pass int ordinals into the
+      // split functions. If the number of ordinals exceeds the JVM method parameter limit,
+      // we skip this expression.
+      // The parameter length is calculated from the number of int ordinals, plus `INPUT_ROW`,
+      // plus an int index into the result array.
+      val refs = expr.collect {
+        case _: BoundReference => Literal(0)
+      }
+      val parameterLength = CodeGenerator.calculateParamLength(refs) + 2
+      if (CodeGenerator.isValidParamLength(parameterLength)) {
+        val e: StructuralExpr = StructuralExpr(expr)
+        val f = structEquivalenceMap.get(e)
+        if (f.isDefined) {
+          addExpr(expr, f.get)
+        } else {
+          val exprMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
+          addExpr(expr, exprMap)
+          structEquivalenceMap.put(e, exprMap)
+        }
+        true
+      } else {
+        false
+      }
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Checks whether we should skip adding sub-expressions for the given expression.
    */
-  def addExprTree(expr: Expression): Unit = {
-    val skip = expr.isInstanceOf[LeafExpression] ||
+  private def skipExpr(expr: Expression): Boolean = {
+    expr.isInstanceOf[LeafExpression] ||
       // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
       // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
       expr.find(_.isInstanceOf[LambdaVariable]).isDefined
+  }
+
+
+  // There are some special expressions for which we should not recurse into all of their children.
+  // 1. CodegenFallback: its children will not be used to generate code (eval() is called instead)
+  // 2. If: common subexpressions will always be evaluated at the beginning, but the true and
+  //        false expressions in `If` may not get accessed, according to the predicate
+  //        expression. We should only recurse into the predicate expression.
+  // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed under a certain
+  //              condition. We should only recurse into the first condition expression as it
+  //              will always get accessed.
+  // 4. Coalesce: it's also a conditional expression; we should only recurse into the first
+  //              child, because the others may not get accessed.
+  // 5. SortPrefix: skip the direct child of SortPrefix, which is an unevaluable SortOrder.
+  private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
+    case _: CodegenFallback => Nil
+    case i: If => i.predicate :: Nil
+    case c: CaseWhen => c.children.head :: Nil
+    case c: Coalesce => c.children.head :: Nil
+    case s: SortPrefix => s.child.child :: Nil
+    case other => other.children
+  }
+
+  /**
+   * Adds the expression to this data structure recursively. Stops if a matching expression
+   * is found. That is, if `expr` has already been added, its children are not added.
+   */
+  def addExprTree(
+      expr: Expression,
+      exprMap: EquivalenceMap = this.equivalenceMap): Unit = {
+    val skip = skipExpr(expr)
 
-    // There are some special expressions that we should not recurse into all of its children.
-    // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
-    // 2. If: common subexpressions will always be evaluated at the beginning, but the true and
-    //        false expressions in `If` may not get accessed, according to the predicate
-    //        expression. We should only recurse into the predicate expression.
-    // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain
-    //              condition. We should only recurse into the first condition expression as it
-    //              will always get accessed.
-    // 4. Coalesce: it's also a conditional expression, we should only recurse into the first
-    //              children, because others may not get accessed.
-    def childrenToRecurse: Seq[Expression] = expr match {
-      case _: CodegenFallback => Nil
-      case i: If => i.predicate :: Nil
-      case c: CaseWhen => c.children.head :: Nil
-      case c: Coalesce => c.children.head :: Nil
-      case other => other.children
+    if (!skip && !addExpr(expr, exprMap)) {
+      childrenToRecurse(expr).foreach(addExprTree(_, exprMap))
     }
+  }
+
+  /**
+   * Adds the expression to the structural data structure recursively. Returns false if the
+   * input expression was not actually added.
+   */
+  def addStructuralExprTree(expr: Expression): Boolean = {
+    val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback]
 
-    if (!skip && !addExpr(expr)) {
-      childrenToRecurse.foreach(addExprTree)
+    if (!skip && addStructExpr(expr)) {
+      childrenToRecurse(expr).foreach(addStructuralExprTree)
+      true
+    } else {
+      false
     }
   }
 
@@ -112,6 +201,18 @@ class EquivalentExpressions {
     equivalenceMap.values.map(_.toSeq).toSeq
   }
 
+  def getStructurallyEquivalentExprs(e: Expression): Seq[Seq[Expression]] = {
+    val key = StructuralExpr(e)
+    structEquivalenceMap.get(key).map(_.values.map(_.toSeq).toSeq).getOrElse(Seq.empty)
+  }
+
+  /**
+   * Returns all the structurally equivalent sets of expressions.
+   */
+  def getAllStructuralExpressions: Map[StructuralExpr, EquivalenceMap] = {
+    structEquivalenceMap.toMap
+  }
+
   /**
    * Returns the state of the data structure as a string. If `all` is false, skips sets of
    * equivalent expressions with cardinality 1.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4632957e7afd..300f7b39762b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -167,13 +167,20 @@ abstract class Expression extends TreeNode[Expression] {
         ""
       }
 
+      // Appends the parameters required by sub-expression elimination.
+      val arguments = (ctx.subExprEliminationParameters.map(_.ordinalParam) ++
+        Seq(ctx.INPUT_ROW)).mkString(", ")
+
+      val parameterString = (ctx.subExprEliminationParameters.map(p => s"int ${p.ordinalParam}") ++
+        Seq(s"InternalRow ${ctx.INPUT_ROW}")).mkString(", ")
+
       val javaType = CodeGenerator.javaType(dataType)
       val newValue = ctx.freshName("value")
 
       val funcName = ctx.freshName(nodeName)
       val funcFullName = ctx.addNewFunction(funcName,
         s"""
-           |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
+           |private $javaType $funcName($parameterString) {
            |  ${eval.code}
            |  $setIsNull
            |  return ${eval.value};
@@ -181,7 +188,7 @@ abstract class Expression extends TreeNode[Expression] {
          """.stripMargin)
 
       eval.value = JavaCode.variable(newValue, dataType)
-      eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+      eval.code = code"$javaType $newValue = $funcFullName($arguments);"
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 660a1dbaf0aa..3228ff83f505 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -402,13 +402,18 @@ class CodegenContext {
    *
    * equivalentExpressions will match the tree containing `col1 + col2` and it will only
    * be evaluated once.
+   *
+   * Visible for testing.
   */
-  private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
+  private[expressions] val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
 
   // Foreach expression that is participating in subexpression elimination, the state to use.
   // Visible for testing.
   private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
 
+  // This tracks the current parameters needed to pass into functions of common sub-expressions.
+  var subExprEliminationParameters = Seq.empty[ParameterizedBoundReference]
+
   // The collection of sub-expression result resetting methods that need to be called on each row.
   private val subexprFunctions = mutable.ArrayBuffer.empty[String]
 
@@ -826,10 +831,11 @@ class CodegenContext {
     if (INPUT_ROW == null || currentVars != null) {
       expressions.mkString("\n")
     } else {
+      val structuralSubExpressionsArgs = subExprEliminationParameters.map("int" -> _.ordinalParam)
       splitExpressions(
         expressions,
         funcName,
-        ("InternalRow", INPUT_ROW) +: extraArguments,
+        ("InternalRow", INPUT_ROW) +: (extraArguments ++ structuralSubExpressionsArgs),
         returnType,
         makeSplitFunction,
         foldFunctions)
@@ -1033,7 +1039,7 @@ class CodegenContext {
     val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
 
     // Add each expression tree and compute the common subexpressions.
-    expressions.foreach(equivalentExpressions.addExprTree)
+    expressions.foreach(equivalentExpressions.addExprTree(_))
 
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
@@ -1050,11 +1056,105 @@ class CodegenContext {
   }
 
   /**
-   * Checks and sets up the state and codegen for subexpression elimination. This finds the
-   * common subexpressions, generates the functions that evaluate those expressions and populates
-   * the mapping of common subexpressions to the generated functions.
-   */
-  private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
+   * This is sub-expression elimination targeting structurally equivalent expressions.
+   * It is only supported in non-whole-stage codegen.
+   *
+   * Two expressions are structurally equivalent if they are the same except for the individual
+   * input slots they read from the current processing row (i.e., `INPUT_ROW`).
+   *
+   * For example, expression a is input[1] + input[2] and expression b is input[3] + input[4].
+   * They are not semantically equivalent in Spark SQL, but they perform the same computation on
+   * different input data.
+   *
+   * This method generates a common function for a set of structurally equivalent expressions.
+   * Within the set, the expressions with the same semantics are replaced with a call to the
+   * generated function, passing in the input slots of the current processing row.
+   */
+  private def structuralSubexpressionElimination(expressions: Seq[Expression]): Seq[Expression] = {
+    // Add each expression tree and compute the structurally common subexpressions.
+    // Expressions that cannot be added to the structural groups are deferred to semantic
+    // common subexpression elimination.
+    val nonApplicableExprs =
+      expressions.filterNot(equivalentExpressions.addStructuralExprTree)
+
+    val structuralExprs = equivalentExpressions.getAllStructuralExpressions
+
+    structuralExprs.flatMap { case (key, exprMap) =>
+      val exprGroups = exprMap.values.flatMap {
+        case a if a.size > 1 => Some(a)
+        case _ => None
+      }
+      if (exprGroups.isEmpty) {
+        None
+      } else {
+        Some((key.e, exprGroups))
+      }
+    }.foreach { case (expr, exprGroups) =>
+      val parameterizedExpr = expr.transformUp {
+        case b: BoundReference =>
+          val param = freshName("ordinal")
+          ParameterizedBoundReference(param, b.dataType, b.nullable)
+      }
+      val parameters = parameterizedExpr.collect {
+        case b: ParameterizedBoundReference => b
+      }
+      val resultIndex = freshName("resultIndex")
+      val parameterString = (parameters.map(p => s"int ${p.ordinalParam}") ++
+        Seq(s"InternalRow $INPUT_ROW", s"int $resultIndex")).mkString(", ")
+
+      val fnName = freshName("subExpr")
+      val isNull = addMutableState(s"${JAVA_BOOLEAN}[]", "subExprIsNull",
+        v => s"$v = new ${JAVA_BOOLEAN}[${exprGroups.size}];")
+
+      val resultJavaType = javaType(expr.dataType)
+      val value = if (resultJavaType.contains("[]")) {
+        // If this expr returns a (possibly multi-dimensional) array.
+        val baseIdx = resultJavaType.indexOf("[]")
+        val baseType = resultJavaType.substring(0, baseIdx)
+        val arrayPart = resultJavaType.substring(baseIdx, resultJavaType.length)
+        addMutableState(s"$resultJavaType[]", "subExprValue",
+          v => s"$v = new $baseType[${exprGroups.size}]$arrayPart;")
+      } else {
+        addMutableState(s"${javaType(expr.dataType)}[]", "subExprValue",
+          v => s"$v = new ${javaType(expr.dataType)}[${exprGroups.size}];")
+      }
+
+      // Generate the code for this expression tree and wrap it in a function.
+      // Set the current parameters.
+      subExprEliminationParameters = parameters
+      val eval = parameterizedExpr.genCode(this)
+      val fn =
+        s"""
+           |private void $fnName($parameterString) {
+           |  ${eval.code}
+           |  $isNull[$resultIndex] = ${eval.isNull};
+           |  $value[$resultIndex] = ${eval.value};
+           |}
+         """.stripMargin
+
+      val funcName = addNewFunction(fnName, fn)
+
+      // Generate calls to the generated function, one for each group of semantically
+      // equivalent sub-expressions.
+      exprGroups.zipWithIndex.foreach { case (exprs, idx) =>
+        val e = exprs.head
+        val arguments = (e.collect {
+          case b: BoundReference => b.ordinal.toString
+        } ++ Seq(INPUT_ROW, idx)).mkString(", ")
+        subexprFunctions += s"$funcName($arguments);"
+
+        val state = SubExprEliminationState(
+          JavaCode.isNullGlobal(s"$isNull[$idx]"),
+          JavaCode.global(s"$value[$idx]", expr.dataType))
+        subExprEliminationExprs ++= exprs.map(_ -> state).toMap
+      }
+    }
+    subExprEliminationParameters = Seq.empty
+
+    nonApplicableExprs
+  }
+
+  private def semanticSubexpressionElimination(expressions: Seq[Expression]): Unit = {
     // Add each expression tree and compute the common subexpressions.
     expressions.foreach(equivalentExpressions.addExprTree(_))
 
@@ -1100,6 +1200,20 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Checks and sets up the state and codegen for subexpression elimination. This finds the
+   * common subexpressions, generates the functions that evaluate those expressions, and populates
+   * the mapping of common subexpressions to the generated functions.
+ */ + private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + val exprs = if (SQLConf.get.structuralSubexpressionEliminationEnabled) { + structuralSubexpressionElimination(expressions) + } else { + expressions + } + semanticSubexpressionElimination(exprs) + } + /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression * elimination will be performed. Subexpression elimination assumes that the code for each diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala new file mode 100644 index 000000000000..c5110907e6c5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.DataType + +/** + * This bound reference points to a parameterized slot in an input tuple. It is used in + * common sub-expression elimination. When some common sub-expressions have same structural + * but different slots of input tuple, we replace `BoundReference` with this parameterized + * version. The slot position is parameterized and is given at runtime. + */ +case class ParameterizedBoundReference(ordinalParam: String, dataType: DataType, nullable: Boolean) + extends LeafExpression { + + override def toString: String = s"input[$ordinalParam, ${dataType.simpleString}, $nullable]" + + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException( + "ParameterizedBoundReference does not implement eval") + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + assert(ctx.currentVars == null && ctx.INPUT_ROW != null, + "ParameterizedBoundReference can not be used in whole-stage codegen yet.") + val javaType = JavaCode.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinalParam) + if (nullable) { + ev.copy(code = + code""" + |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinalParam); + |$javaType ${ev.value} = ${ev.isNull} ? 
+              |  ${CodeGenerator.defaultValue(dataType)} : ($value);
+            """.stripMargin)
+    } else {
+      ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d9b0a72618c7..886369be3277 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -394,6 +394,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED =
+    buildConf("spark.sql.structuralSubexpressionElimination.enabled")
+      .internal()
+      .doc("When true, structurally equivalent common subexpressions will be eliminated. " +
+        "This config is effective only when spark.sql.subexpressionElimination.enabled is true.")
+      .booleanConf
+      .createWithDefault(true)
+
   val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
     .internal()
     .doc("Whether the query analyzer should be case sensitive or not. " +
@@ -2232,6 +2240,9 @@ class SQLConf extends Serializable with Logging {
 
   def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED)
 
+  def structuralSubexpressionEliminationEnabled: Boolean =
+    getConf(STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED)
+
   def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
 
   def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 28d2607e6e43..f72d9f95ea11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -465,49 +465,50 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
+    withSQLConf(SQLConf.STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false") {
+      val ref = BoundReference(0, IntegerType, true)
+      val add1 = Add(ref, ref)
+      val add2 = Add(add1, add1)
+      val dummy = SubExprEliminationState(
+        JavaCode.variable("dummy", BooleanType),
+        JavaCode.variable("dummy", BooleanType))
+
+      // raw testing of basic functionality
+      {
+        val ctx = new CodegenContext
+        val e = ref.genCode(ctx)
+        // before
+        ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
+        assert(ctx.subExprEliminationExprs.contains(ref))
+        // call withSubExprEliminationExprs
+        ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
+          assert(ctx.subExprEliminationExprs.contains(add1))
+          assert(!ctx.subExprEliminationExprs.contains(ref))
+          Seq.empty
+        }
+        // after
+        assert(ctx.subExprEliminationExprs.nonEmpty)
+        assert(ctx.subExprEliminationExprs.contains(ref))
+        assert(!ctx.subExprEliminationExprs.contains(add1))
+      }
 
-    val ref = BoundReference(0, IntegerType, true)
-    val add1 = Add(ref, ref)
-    val add2 = Add(add1, add1)
-    val dummy = SubExprEliminationState(
-      JavaCode.variable("dummy", BooleanType),
-      JavaCode.variable("dummy", BooleanType))
-
-    // raw testing of basic functionality
-    {
-      val ctx = new CodegenContext
-      val e = ref.genCode(ctx)
-      // before
-      ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
-      assert(ctx.subExprEliminationExprs.contains(ref))
-      // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
+      // emulate an actual codegen workload
+      {
+        val ctx = new CodegenContext
+        // before
+        ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
+        assert(ctx.subExprEliminationExprs.contains(add1))
+        // call withSubExprEliminationExprs
+        ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
+          assert(ctx.subExprEliminationExprs.contains(ref))
+          assert(!ctx.subExprEliminationExprs.contains(add1))
+          Seq.empty
+        }
+        // after
+        assert(ctx.subExprEliminationExprs.nonEmpty)
         assert(ctx.subExprEliminationExprs.contains(add1))
         assert(!ctx.subExprEliminationExprs.contains(ref))
-        Seq.empty
-      }
-      // after
-      assert(ctx.subExprEliminationExprs.nonEmpty)
-      assert(ctx.subExprEliminationExprs.contains(ref))
-      assert(!ctx.subExprEliminationExprs.contains(add1))
-    }
-
-    // emulate an actual codegen workload
-    {
-      val ctx = new CodegenContext
-      // before
-      ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
-      assert(ctx.subExprEliminationExprs.contains(add1))
-      // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
-        assert(ctx.subExprEliminationExprs.contains(ref))
-        assert(!ctx.subExprEliminationExprs.contains(add1))
-        Seq.empty
-      }
-      // after
-      assert(ctx.subExprEliminationExprs.nonEmpty)
-      assert(ctx.subExprEliminationExprs.contains(add1))
-      assert(!ctx.subExprEliminationExprs.contains(ref))
       }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 54ef9641bee0..65e57a8b1f00 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -50,6 +50,8 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper
  * instances of the expression.
  */
 case class BadCodegenExpression() extends LeafExpression {
+  // This is an invalid expression; prevent structural sub-expression elimination from using it.
+  override lazy val deterministic: Boolean = false
   override def nullable: Boolean = false
   override def eval(input: InternalRow): Any = 10
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala
new file mode 100644
index 000000000000..5adb66401a92
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+
+class StructuralSubexpressionEliminationSuite extends SparkFunSuite {
+  test("Structural expression equivalence") {
+    val equivalence = new EquivalentExpressions
+    assert(equivalence.getAllStructuralExpressions.isEmpty)
+
+    val oneA = Literal(1)
+    val oneB = Literal(1)
+    val twoA = Literal(2)
+
+    assert(equivalence.getStructurallyEquivalentExprs(oneA).isEmpty)
+    assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty)
+
+    // Add oneA and check that it is returned as a single group containing one expression.
+    equivalence.addStructExpr(oneA)
+    assert(equivalence.getStructurallyEquivalentExprs(oneA).size == 1)
+    assert(equivalence.getStructurallyEquivalentExprs(oneA)(0).size == 1)
+
+    assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty)
+    equivalence.addStructExpr(oneA)
+    assert(equivalence.getStructurallyEquivalentExprs(oneA).size == 1)
+    assert(equivalence.getStructurallyEquivalentExprs(oneA)(0).size == 2)
+
+    // Add B and make sure they can see each other.
+    equivalence.addStructExpr(oneB)
+    // Use exists and reference equality because of how equals is defined.
+    assert(equivalence.getStructurallyEquivalentExprs(oneA).flatten.exists(_ eq oneB))
+    assert(equivalence.getStructurallyEquivalentExprs(oneA).flatten.exists(_ eq oneA))
+    assert(equivalence.getStructurallyEquivalentExprs(oneB).flatten.exists(_ eq oneA))
+    assert(equivalence.getStructurallyEquivalentExprs(oneB).flatten.exists(_ eq oneB))
+    assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty)
+
+    assert(equivalence.getAllStructuralExpressions.size == 1)
+    assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.size == 3)
+    assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.contains(oneA))
+    assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.contains(oneB))
+
+    val add1 = Add(oneA, oneB)
+    val add2 = Add(oneA, oneB)
+
+    equivalence.addStructExpr(add1)
+    equivalence.addStructExpr(add2)
+
+    assert(equivalence.getAllStructuralExpressions.size == 2)
+    assert(equivalence.getStructurallyEquivalentExprs(add2).flatten.exists(_ eq add1))
+    assert(equivalence.getStructurallyEquivalentExprs(add2).flatten.size == 2)
+    assert(equivalence.getStructurallyEquivalentExprs(add1).flatten.exists(_ eq add2))
+  }
+
+  test("Expression equivalence - non-deterministic") {
+    val sum = Add(Rand(0), Rand(0))
+    val equivalence = new EquivalentExpressions
+    equivalence.addStructExpr(sum)
+    equivalence.addStructExpr(sum)
+    assert(equivalence.getAllStructuralExpressions.isEmpty)
+  }
+
+  test("CodegenFallback and children") {
+    val one = Literal(1)
+    val two = Literal(2)
+    val add = Add(one, two)
+    val fallback = CodegenFallbackExpression(add)
+    val add2 = Add(add, fallback)
+
+    val equivalence = new EquivalentExpressions
+    equivalence.addStructuralExprTree(add2)
+    // `fallback` and the `add` inside it should not be added
+    assert(equivalence.getAllStructuralExpressions.values
+      .map(_.values.count(_.size > 1)).sum == 0)
+    assert(equivalence.getAllStructuralExpressions.values
+      .map(_.values.count(_.size == 1)).sum == 2) // add, add2
+  }
+
+  test("Children of conditional expressions") {
+    val condition = And(Literal(true), Literal(false))
+    val add = Add(Literal(1), Literal(2))
+    val ifExpr = If(condition, add, add)
+
+    val equivalence = new EquivalentExpressions
+    equivalence.addStructuralExprTree(ifExpr)
+    // the `add` inside `If` should not be added
+    assert(equivalence.getAllStructuralExpressions.values
+      .map(_.values.count(_.size > 1)).sum == 0)
+    // only ifExpr and its predicate expression
+    assert(equivalence.getAllStructuralExpressions.values
+      .map(_.values.count(_.size == 1)).sum == 2)
+  }
+
+  test("Expressions not eligible for structural elimination fall back to semantic elimination") {
+    val fallback1 = CodegenFallbackExpression(Literal(1))
+    val fallback2 = CodegenFallbackExpression(Literal(1))
+
+    val ctx = new CodegenContext()
+    ctx.generateExpressions(Seq(fallback1, fallback2), doSubexpressionElimination = true)
+    assert(ctx.equivalentExpressions.getAllStructuralExpressions.isEmpty)
+    assert(ctx.equivalentExpressions.getEquivalentExprs(fallback1).length == 2)
+  }
+}
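
Reviewer note: a minimal sketch of how the new EquivalentExpressions APIs in this patch group
structurally equivalent expressions. It only uses classes touched by the diff (BoundReference,
Add, EquivalentExpressions); the variable names are made up for illustration.

  import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, EquivalentExpressions}
  import org.apache.spark.sql.types.IntegerType

  // input[0] + input[1] and input[2] + input[3] read different slots, so they are not
  // semantically equivalent, but they are structurally identical.
  val sumA = Add(BoundReference(0, IntegerType, nullable = true),
    BoundReference(1, IntegerType, nullable = true))
  val sumB = Add(BoundReference(2, IntegerType, nullable = true),
    BoundReference(3, IntegerType, nullable = true))

  val equivalence = new EquivalentExpressions
  equivalence.addStructuralExprTree(sumA)
  equivalence.addStructuralExprTree(sumB)

  // Both adds land under one StructuralExpr key (ordinals are normalized to -1), forming
  // two semantic groups of one expression each. The leaf BoundReferences are skipped.
  assert(equivalence.getAllStructuralExpressions.size == 1)
  assert(equivalence.getStructurallyEquivalentExprs(sumA).flatten.size == 2)

At codegen time, structuralSubexpressionElimination then emits one function per structural group,
taking the int ordinals, INPUT_ROW, and a result-array index, and each semantic group is rewritten
into a call to that function with its own slot numbers.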