From 13f5ca6fb3c1afdede210fd8e90f01fd9758ef53 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 3 Sep 2019 13:34:04 +0800 Subject: [PATCH 1/6] Structurally equivalent subexpression elimination. --- .../catalyst/expressions/BoundAttribute.scala | 34 ++++ .../expressions/EquivalentExpressions.scala | 157 +++++++++++++++--- .../sql/catalyst/expressions/Expression.scala | 11 +- .../expressions/codegen/CodeGenerator.scala | 121 +++++++++++++- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 11 ++ .../expressions/CodeGenerationSuite.scala | 79 ++++----- .../ExpressionEvalHelperSuite.scala | 2 + ...ucturalSubexpressionEliminationSuite.scala | 112 +++++++++++++ 10 files changed, 456 insertions(+), 75 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 7ae5924b20fa..8af41eb359ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -65,6 +65,40 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } +/** + * This bound reference points to a parameterized slot in an input tuple. It is used in + * common sub-expression elimination. When some common sub-expressions have same structural + * but different slots of input tuple, we replace `BoundReference` with this parameterized + * version. The slot position is parameterized and is given at runtime. 
+ */ +case class ParameterizedBoundReference(parameter: String, dataType: DataType, nullable: Boolean) + extends LeafExpression { + + override def toString: String = s"input[$parameter, ${dataType.simpleString}, $nullable]" + + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException( + "ParameterizedBoundReference does not implement eval") + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + assert(ctx.currentVars == null && ctx.INPUT_ROW != null, + "ParameterizedBoundReference can not be used in whole-stage codegen yet.") + val javaType = JavaCode.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, parameter) + if (nullable) { + ev.copy(code = + code""" + |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($parameter); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); + """.stripMargin) + } else { + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + } + } +} + object BindReferences extends Logging { def bindReference[A <: Expression]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 72ff9361d8f7..c3b690531dda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback} import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable /** @@ -40,23 +40,47 @@ class EquivalentExpressions { override def 
hashCode: Int = e.semanticHash() } + /** + * Wrapper around an Expression that provides structural semantic equality. + */ + case class StructuralExpr(e: Expression) { + def normalized(expr: Expression): Expression = { + expr.transformUp { + case b: ParameterizedBoundReference => + b.copy(parameter = "") + } + } + override def equals(o: Any): Boolean = o match { + case other: StructuralExpr => + normalized(e).semanticEquals(normalized(other.e)) + case _ => false + } + + override def hashCode: Int = normalized(e).semanticHash() + } + + type EquivalenceMap = mutable.HashMap[Expr, mutable.ArrayBuffer[Expression]] + // For each expression, the set of equivalent expressions. private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] + // For each expression, the set of structurally equivalent expressions. + private val structEquivalenceMap = mutable.HashMap.empty[StructuralExpr, EquivalenceMap] + /** * Adds each expression to this data structure, grouping them with existing equivalent * expressions. Non-recursive. * Returns true if there was already a matching expression. */ - def addExpr(expr: Expression): Boolean = { + def addExpr(expr: Expression, exprMap: EquivalenceMap = this.equivalenceMap): Boolean = { if (expr.deterministic) { val e: Expr = Expr(expr) - val f = equivalenceMap.get(e) + val f = exprMap.get(e) if (f.isDefined) { f.get += expr true } else { - equivalenceMap.put(e, mutable.ArrayBuffer(expr)) + exprMap.put(e, mutable.ArrayBuffer(expr)) false } } else { @@ -65,35 +89,102 @@ class EquivalentExpressions { } /** - * Adds the expression to this data structure recursively. Stops if a matching expression - * is found. That is, if `expr` has already been added, its children are not added. + * Adds each expression to structural expression data structure, grouping them with existing + * structurally equivalent expressions. Non-recursive. 
+ */ + def addStructExpr(ctx: CodegenContext, expr: Expression): Unit = { + if (expr.deterministic) { + val refs = expr.collect { + case b: BoundReference => b + } + + // For structural equivalent expressions, we need to pass in int type ordinals into + // split functions. If the number of ordinals is more than JVM function limit, we skip + // this expression. + // We calculate function parameter length by the number of ints plus `INPUT_ROW` plus + // a int type result array index. + val parameterLength = CodeGenerator.calculateParamLength(refs.map(_ => Literal(0))) + 2 + if (CodeGenerator.isValidParamLength(parameterLength)) { + val parameterizedExpr = parameterizedBoundReferences(ctx, expr) + + val e: StructuralExpr = StructuralExpr(parameterizedExpr) + val f = structEquivalenceMap.get(e) + if (f.isDefined) { + addExpr(expr, f.get) + } else { + val exprMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] + addExpr(expr, exprMap) + structEquivalenceMap.put(e, exprMap) + } + } + } + } + + /** + * Replaces bound references in given expression by parameterized bound references. */ - def addExprTree(expr: Expression): Unit = { - val skip = expr.isInstanceOf[LeafExpression] || + private def parameterizedBoundReferences(ctx: CodegenContext, expr: Expression): Expression = { + expr.transformUp { + case b: BoundReference => + val param = ctx.freshName("boundInput") + ParameterizedBoundReference(param, b.dataType, b.nullable) + } + } + + /** + * Checks if we skip add sub-expressions for given expression. + */ + private def skipExpr(expr: Expression): Boolean = { + expr.isInstanceOf[LeafExpression] || // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. expr.find(_.isInstanceOf[LambdaVariable]).isDefined + } + + + // There are some special expressions that we should not recurse into all of its children. + // 1. 
CodegenFallback: it's children will not be used to generate code (call eval() instead) + // 2. If: common subexpressions will always be evaluated at the beginning, but the true and + // false expressions in `If` may not get accessed, according to the predicate + // expression. We should only recurse into the predicate expression. + // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain + // condition. We should only recurse into the first condition expression as it + // will always get accessed. + // 4. Coalesce: it's also a conditional expression, we should only recurse into the first + // children, because others may not get accessed. + private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { + case _: CodegenFallback => Nil + case i: If => i.predicate :: Nil + case c: CaseWhen => c.children.head :: Nil + case c: Coalesce => c.children.head :: Nil + case s: SortPrefix => s.child.child :: Nil + case other => other.children + } + + /** + * Adds the expression to this data structure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + */ + def addExprTree( + expr: Expression, + exprMap: EquivalenceMap = this.equivalenceMap): Unit = { + val skip = skipExpr(expr) - // There are some special expressions that we should not recurse into all of its children. - // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) - // 2. If: common subexpressions will always be evaluated at the beginning, but the true and - // false expressions in `If` may not get accessed, according to the predicate - // expression. We should only recurse into the predicate expression. - // 3. CaseWhen: like `If`, the children of `CaseWhen` only get accessed in a certain - // condition. We should only recurse into the first condition expression as it - // will always get accessed. - // 4. 
Coalesce: it's also a conditional expression, we should only recurse into the first - // children, because others may not get accessed. - def childrenToRecurse: Seq[Expression] = expr match { - case _: CodegenFallback => Nil - case i: If => i.predicate :: Nil - case c: CaseWhen => c.children.head :: Nil - case c: Coalesce => c.children.head :: Nil - case other => other.children + if (!skip && !addExpr(expr, exprMap)) { + childrenToRecurse(expr).foreach(addExprTree(_, exprMap)) } + } + + /** + * Adds the expression to structural data structure recursively. Stops if a matching expression + * is found. + */ + def addStructuralExprTree(ctx: CodegenContext, expr: Expression): Unit = { + val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback] - if (!skip && !addExpr(expr)) { - childrenToRecurse.foreach(addExprTree) + if (!skip) { + addStructExpr(ctx, expr) + childrenToRecurse(expr).foreach(addStructuralExprTree(ctx, _)) } } @@ -112,6 +203,20 @@ class EquivalentExpressions { equivalenceMap.values.map(_.toSeq).toSeq } + def getStructurallyEquivalentExprs(ctx: CodegenContext, e: Expression): Seq[Seq[Expression]] = { + val parameterizedExpr = parameterizedBoundReferences(ctx, e) + + val key = StructuralExpr(parameterizedExpr) + structEquivalenceMap.get(key).map(_.values.map(_.toSeq).toSeq).getOrElse(Seq.empty) + } + + /** + * Returns all the structurally equivalent sets of expressions. + */ + def getAllStructuralExpressions: Map[StructuralExpr, EquivalenceMap] = { + structEquivalenceMap.toMap + } + /** * Returns the state of the data structure as a string. If `all` is false, skips sets of * equivalent expressions with cardinality 1. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4632957e7afd..0c1442c44726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -167,13 +167,20 @@ abstract class Expression extends TreeNode[Expression] { "" } + // Appends necessary parameters from sub-expression elimination. + val arguments = (ctx.subExprEliminationParameters.map(_.parameter) ++ + Seq(ctx.INPUT_ROW)).mkString(", ") + + val parameterString = (ctx.subExprEliminationParameters.map(p => s"int ${p.parameter}") ++ + Seq(s"InternalRow ${ctx.INPUT_ROW}")).mkString(", ") + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) val funcFullName = ctx.addNewFunction(funcName, s""" - |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { + |private $javaType $funcName($parameterString) { | ${eval.code} | $setIsNull | return ${eval.value}; @@ -181,7 +188,7 @@ abstract class Expression extends TreeNode[Expression] { """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName($arguments);" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 95fad412002e..98a6367c7252 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,6 +408,9 @@ class CodegenContext { // Foreach expression that is 
participating in subexpression elimination, the state to use. var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState] + // This tracks the current parameters needed to pass into functions of common sub-expressions. + var subExprEliminationParameters = Seq.empty[ParameterizedBoundReference] + // The collection of sub-expression result resetting methods that need to be called on each row. val subexprFunctions = mutable.ArrayBuffer.empty[String] @@ -825,10 +828,11 @@ class CodegenContext { if (INPUT_ROW == null || currentVars != null) { expressions.mkString("\n") } else { + val structuralSubExpressionsArgs = subExprEliminationParameters.map("int" -> _.parameter) splitExpressions( expressions, funcName, - ("InternalRow", INPUT_ROW) +: extraArguments, + ("InternalRow", INPUT_ROW) +: (extraArguments ++ structuralSubExpressionsArgs), returnType, makeSplitFunction, foldFunctions) @@ -1023,7 +1027,7 @@ class CodegenContext { val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // Add each expression tree and compute the common subexpressions. - expressions.foreach(equivalentExpressions.addExprTree) + expressions.foreach(equivalentExpressions.addExprTree(_)) // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. @@ -1040,11 +1044,103 @@ class CodegenContext { } /** - * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpressions, generates the functions that evaluate those expressions and populates - * the mapping of common subexpressions to the generated functions. + * Returns the code for subexpression elimination after splitting it if necessary. 
*/ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + def subexprFunctionsCode: String = { + // Whole-stage codegen's subexpression elimination is handled in another code path + splitExpressions(subexprFunctions, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) + } + + /** + * This is for sub-expression elimination targeting structurally equivalent expressions. + * This is only supported in non whole-stage codegen. + * + * Two expressions are structurally equivalent if they are the same except for the individual + * input slot in current processing row (i.e., `INPUT_ROW`). + * + * For example, expression a is input[1] + input[2], expression b is input[3] + input[4]. They + * are not semantically equivalent in SparkSQL, but they have the same computation on different + * input data. + * + * This method generates a common function for a set of structurally equivalent expressions. + * Among the set, the expressions with same semantics are replaced with a function call to the + * generated function, by passing in input slots of current processing row. + */ + private def structuralSubexpressionElimination(expressions: Seq[Expression]): Unit = { + // Add each expression tree and compute the structurally common subexpressions. 
+ expressions.foreach(equivalentExpressions.addStructuralExprTree(this, _)) + + val structuralExprs = equivalentExpressions.getAllStructuralExpressions + + structuralExprs.flatMap { case (key, exprMap) => + val exprGroups = exprMap.values.flatMap { + case a if a.size > 1 => Some(a) + case _ => None + } + if (exprGroups.isEmpty) { + None + } else { + Some((key.e, exprGroups)) + } + }.foreach { case (expr, exprGroups) => + val parameters = expr.collect { + case b: ParameterizedBoundReference => b + } + val resultIndex = freshName("resultIndex") + val parameterString = (parameters.map(p => s"int ${p.parameter}") ++ + Seq(s"InternalRow $INPUT_ROW", s"int $resultIndex")).mkString(", ") + + val fnName = freshName("subExpr") + val isNull = addMutableState(s"${JAVA_BOOLEAN}[]", "subExprIsNull", + v => s"$v = new ${JAVA_BOOLEAN}[${exprGroups.size}];") + + val resultJavaType = javaType(expr.dataType) + val value = if (resultJavaType.contains("[]")) { + // If this expr returns an (multi-dimension) array. + val baseIdx = resultJavaType.indexOf("[]") + val baseType = resultJavaType.substring(0, baseIdx) + val arrayPart = resultJavaType.substring(baseIdx, resultJavaType.length) + addMutableState(s"$resultJavaType[]", "subExprValue", + v => s"$v = new $baseType[${exprGroups.size}]$arrayPart;") + } else { + addMutableState(s"${javaType(expr.dataType)}[]", "subExprValue", + v => s"$v = new ${javaType(expr.dataType)}[${exprGroups.size}];") + } + + // Generate the code for this expression tree and wrap it in a function. + // Sets the current parameters. 
+ subExprEliminationParameters = parameters + val eval = expr.genCode(this) + val fn = + s""" + |private void $fnName($parameterString) { + | ${eval.code} + | $isNull[$resultIndex] = ${eval.isNull}; + | $value[$resultIndex] = ${eval.value}; + |} + """.stripMargin + + val funcName = addNewFunction(fnName, fn) + + // Generate sub-expr function calls to the generated function, for each group + // of semantically equivalent sub-expressions. + exprGroups.zipWithIndex.foreach { case (exprs, idx) => + val e = exprs.head + val arguments = (e.collect { + case b: BoundReference => b.ordinal.toString + } ++ Seq(INPUT_ROW, idx)).mkString(", ") + subexprFunctions += s"$funcName($arguments);" + + val state = SubExprEliminationState( + JavaCode.isNullGlobal(s"$isNull[$idx]"), + JavaCode.global(s"$value[$idx]", expr.dataType)) + subExprEliminationExprs ++= exprs.map(_ -> state).toMap + } + } + subExprEliminationParameters = Seq.empty + } + + private def semanticSubexpressionElimination(expressions: Seq[Expression]): Unit = { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -1090,6 +1186,19 @@ class CodegenContext { } } + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpressions, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + if (SQLConf.get.structuralSubexpressionEliminationEnabled) { + structuralSubexpressionElimination(expressions) + } else { + semanticSubexpressionElimination(expressions) + } + } + /** * Generates code for expressions. If doSubexpressionElimination is true, subexpression * elimination will be performed. 
Subexpression elimination assumes that the code for each diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 838bd1c679e4..2e018de07101 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -92,7 +92,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP } // Evaluate all the subexpressions. - val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val evalSubexpr = ctx.subexprFunctionsCode val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index fb1d8a3c8e73..8da7f65bdeee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -299,7 +299,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});") // Evaluate all the subexpression. 
- val evalSubexpr = ctx.subexprFunctions.mkString("\n") + val evalSubexpr = ctx.subexprFunctionsCode val writeExpressions = writeExpressionsToBuffer( ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 52990cb6a244..3dc64618d3a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -345,6 +345,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED = + buildConf("spark.sql.structuralSubexpressionElimination.enabled") + .internal() + .doc("When true, structurally equivalent common subexpressions will be eliminated. " + + "This config is effective only when spark.sql.subexpressionElimination.enabled is true.") + .booleanConf + .createWithDefault(true) + val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive") .internal() .doc("Whether the query analyzer should be case sensitive or not. 
" + @@ -2158,6 +2166,9 @@ class SQLConf extends Serializable with Logging { def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) + def structuralSubexpressionEliminationEnabled: Boolean = + getConf(STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 4e64313da136..f6e7ed443347 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -465,49 +465,50 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") { + withSQLConf(SQLConf.STRUCTURAL_SUBEXPRESSION_ELIMINATION_ENABLED.key -> "false") { + val ref = BoundReference(0, IntegerType, true) + val add1 = Add(ref, ref) + val add2 = Add(add1, add1) + val dummy = SubExprEliminationState( + JavaCode.variable("dummy", BooleanType), + JavaCode.variable("dummy", BooleanType)) + + // raw testing of basic functionality + { + val ctx = new CodegenContext + val e = ref.genCode(ctx) + // before + ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) + assert(ctx.subExprEliminationExprs.contains(ref)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(add1)) + assert(!ctx.subExprEliminationExprs.contains(ref)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) + assert(ctx.subExprEliminationExprs.contains(ref)) + 
assert(!ctx.subExprEliminationExprs.contains(add1)) + } - val ref = BoundReference(0, IntegerType, true) - val add1 = Add(ref, ref) - val add2 = Add(add1, add1) - val dummy = SubExprEliminationState( - JavaCode.variable("dummy", BooleanType), - JavaCode.variable("dummy", BooleanType)) - - // raw testing of basic functionality - { - val ctx = new CodegenContext - val e = ref.genCode(ctx) - // before - ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value) - assert(ctx.subExprEliminationExprs.contains(ref)) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) { + // emulate an actual codegen workload + { + val ctx = new CodegenContext + // before + ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE + assert(ctx.subExprEliminationExprs.contains(add1)) + // call withSubExprEliminationExprs + ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { + assert(ctx.subExprEliminationExprs.contains(ref)) + assert(!ctx.subExprEliminationExprs.contains(add1)) + Seq.empty + } + // after + assert(ctx.subExprEliminationExprs.nonEmpty) assert(ctx.subExprEliminationExprs.contains(add1)) assert(!ctx.subExprEliminationExprs.contains(ref)) - Seq.empty - } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(ref)) - assert(!ctx.subExprEliminationExprs.contains(add1)) - } - - // emulate an actual codegen workload - { - val ctx = new CodegenContext - // before - ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE - assert(ctx.subExprEliminationExprs.contains(add1)) - // call withSubExprEliminationExprs - ctx.withSubExprEliminationExprs(Map(ref -> dummy)) { - assert(ctx.subExprEliminationExprs.contains(ref)) - assert(!ctx.subExprEliminationExprs.contains(add1)) - Seq.empty } - // after - assert(ctx.subExprEliminationExprs.nonEmpty) - assert(ctx.subExprEliminationExprs.contains(add1)) - 
assert(!ctx.subExprEliminationExprs.contains(ref)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 54ef9641bee0..65e57a8b1f00 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -50,6 +50,8 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper * instances of the expression. */ case class BadCodegenExpression() extends LeafExpression { + // This is invalid expression. Prevent structural sub-expression elimination work on this. + override lazy val deterministic: Boolean = false override def nullable: Boolean = false override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala new file mode 100644 index 000000000000..15fc15f7cbf6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.{DataType, IntegerType} + +class StructuralSubexpressionEliminationSuite extends SparkFunSuite { + private val ctx = new CodegenContext + + test("Structurally Expression Equivalence") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllStructuralExpressions.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).isEmpty) + assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + equivalence.addStructExpr(ctx, oneA) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).size == 1) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA)(0).size == 1) + + assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) + equivalence.addStructExpr(ctx, oneA) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).size == 1) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA)(0).size == 2) + + // Add B and make sure they can see each other. + equivalence.addStructExpr(ctx, oneB) + // Use exists and reference equality because of how equals is defined. 
+ assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).flatten.exists(_ eq oneB)) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).flatten.exists(_ eq oneA)) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneB).flatten.exists(_ eq oneA)) + assert(equivalence.getStructurallyEquivalentExprs(ctx, oneB).flatten.exists(_ eq oneB)) + assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) + + assert(equivalence.getAllStructuralExpressions.size == 1) + assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.size == 3) + assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.contains(oneA)) + assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addStructExpr(ctx, add1) + equivalence.addStructExpr(ctx, add2) + + assert(equivalence.getAllStructuralExpressions.size == 2) + assert(equivalence.getStructurallyEquivalentExprs(ctx, add2).flatten.exists(_ eq add1)) + assert(equivalence.getStructurallyEquivalentExprs(ctx, add2).flatten.size == 2) + assert(equivalence.getStructurallyEquivalentExprs(ctx, add1).flatten.exists(_ eq add2)) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addStructExpr(ctx, sum) + equivalence.addStructExpr(ctx, sum) + assert(equivalence.getAllStructuralExpressions.isEmpty) + } + + test("CodegenFallback and children") { + val one = Literal(1) + val two = Literal(2) + val add = Add(one, two) + val fallback = CodegenFallbackExpression(add) + val add2 = Add(add, fallback) + + val equivalence = new EquivalentExpressions + equivalence.addStructuralExprTree(ctx, add2) + // `fallback` and the `add` inside should not be added + assert(equivalence.getAllStructuralExpressions.values + .map(_.values.count(_.size > 1)).sum == 0) + 
assert(equivalence.getAllStructuralExpressions.values + .map(_.values.count(_.size == 1)).sum == 2) // add, add2 + } + + test("Children of conditional expressions") { + val condition = And(Literal(true), Literal(false)) + val add = Add(Literal(1), Literal(2)) + val ifExpr = If(condition, add, add) + + val equivalence = new EquivalentExpressions + equivalence.addStructuralExprTree(ctx, ifExpr) + // the `add` inside `If` should not be added + assert(equivalence.getAllStructuralExpressions.values + .map(_.values.count(_.size > 1)).sum == 0) + // only ifExpr and its predicate expression + assert(equivalence.getAllStructuralExpressions.values + .map(_.values.count(_.size == 1)).sum == 2) + } +} From f52cbde8e966e8b287ffdb159da65e1ab61f8303 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Sep 2019 08:21:04 -0700 Subject: [PATCH 2/6] Solve merging conflict. --- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 8bc2ad1ef8cd..2a6f8650025c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1053,14 +1053,6 @@ class CodegenContext { SubExprCodes(codes, localSubExprEliminationExprs.toMap) } - /** - * Returns the code for subexpression elimination after splitting it if necessary. - */ - def subexprFunctionsCode: String = { - // Whole-stage codegen's subexpression elimination is handled in another code path - splitExpressions(subexprFunctions, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) - } - /** * This is for sub-expression elimination targeting structurally equivalent expressions. 
* This is only supported in non whole-stage codegen. From cc0ee12eff89f24e4f5a197c012dec6ab85882a1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Sep 2019 17:05:51 -0700 Subject: [PATCH 3/6] Address comments. --- .../catalyst/expressions/BoundAttribute.scala | 34 ----------- .../expressions/EquivalentExpressions.scala | 37 +++++++----- .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 22 ++++--- .../codegen/ParameterizedBoundReference.scala | 57 +++++++++++++++++++ ...ucturalSubexpressionEliminationSuite.scala | 11 +++- 6 files changed, 106 insertions(+), 59 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 8af41eb359ee..7ae5924b20fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -65,40 +65,6 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } -/** - * This bound reference points to a parameterized slot in an input tuple. It is used in - * common sub-expression elimination. When some common sub-expressions have same structural - * but different slots of input tuple, we replace `BoundReference` with this parameterized - * version. The slot position is parameterized and is given at runtime. 
- */ -case class ParameterizedBoundReference(parameter: String, dataType: DataType, nullable: Boolean) - extends LeafExpression { - - override def toString: String = s"input[$parameter, ${dataType.simpleString}, $nullable]" - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException( - "ParameterizedBoundReference does not implement eval") - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - assert(ctx.currentVars == null && ctx.INPUT_ROW != null, - "ParameterizedBoundReference can not be used in whole-stage codegen yet.") - val javaType = JavaCode.javaType(dataType) - val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, parameter) - if (nullable) { - ev.copy(code = - code""" - |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($parameter); - |$javaType ${ev.value} = ${ev.isNull} ? - | ${CodeGenerator.defaultValue(dataType)} : ($value); - """.stripMargin) - } else { - ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) - } - } -} - object BindReferences extends Logging { def bindReference[A <: Expression]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index c3b690531dda..07241a54737b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ParameterizedBoundReference} import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable /** @@ 
-47,7 +47,7 @@ class EquivalentExpressions { def normalized(expr: Expression): Expression = { expr.transformUp { case b: ParameterizedBoundReference => - b.copy(parameter = "") + b.copy(ordinalParam = "") } } override def equals(o: Any): Boolean = o match { @@ -90,20 +90,20 @@ class EquivalentExpressions { /** * Adds each expression to structural expression data structure, grouping them with existing - * structurally equivalent expressions. Non-recursive. + * structurally equivalent expressions. Non-recursive. Returns false if this doesn't add input + * expression actually. */ - def addStructExpr(ctx: CodegenContext, expr: Expression): Unit = { + def addStructExpr(ctx: CodegenContext, expr: Expression): Boolean = { if (expr.deterministic) { - val refs = expr.collect { - case b: BoundReference => b - } - // For structural equivalent expressions, we need to pass in int type ordinals into // split functions. If the number of ordinals is more than JVM function limit, we skip // this expression. // We calculate function parameter length by the number of ints plus `INPUT_ROW` plus // a int type result array index. 
- val parameterLength = CodeGenerator.calculateParamLength(refs.map(_ => Literal(0))) + 2 + val refs = expr.collect { + case _: BoundReference => Literal(0) + } + val parameterLength = CodeGenerator.calculateParamLength(refs) + 2 if (CodeGenerator.isValidParamLength(parameterLength)) { val parameterizedExpr = parameterizedBoundReferences(ctx, expr) @@ -116,7 +116,12 @@ class EquivalentExpressions { addExpr(expr, exprMap) structEquivalenceMap.put(e, exprMap) } + true + } else { + false } + } else { + false } } @@ -126,7 +131,7 @@ class EquivalentExpressions { private def parameterizedBoundReferences(ctx: CodegenContext, expr: Expression): Expression = { expr.transformUp { case b: BoundReference => - val param = ctx.freshName("boundInput") + val param = ctx.freshName("ordinal") ParameterizedBoundReference(param, b.dataType, b.nullable) } } @@ -176,15 +181,17 @@ class EquivalentExpressions { } /** - * Adds the expression to structural data structure recursively. Stops if a matching expression - * is found. + * Adds the expression to structural data structure recursively. Returns false if this doesn't add + * the input expression actually. 
*/ - def addStructuralExprTree(ctx: CodegenContext, expr: Expression): Unit = { + def addStructuralExprTree(ctx: CodegenContext, expr: Expression): Boolean = { val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback] - if (!skip) { - addStructExpr(ctx, expr) + if (!skip && addStructExpr(ctx, expr)) { childrenToRecurse(expr).foreach(addStructuralExprTree(ctx, _)) + true + } else { + false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0c1442c44726..300f7b39762b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -168,10 +168,10 @@ abstract class Expression extends TreeNode[Expression] { } // Appends necessary parameters from sub-expression elimination. - val arguments = (ctx.subExprEliminationParameters.map(_.parameter) ++ + val arguments = (ctx.subExprEliminationParameters.map(_.ordinalParam) ++ Seq(ctx.INPUT_ROW)).mkString(", ") - val parameterString = (ctx.subExprEliminationParameters.map(p => s"int ${p.parameter}") ++ + val parameterString = (ctx.subExprEliminationParameters.map(p => s"int ${p.ordinalParam}") ++ Seq(s"InternalRow ${ctx.INPUT_ROW}")).mkString(", ") val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2a6f8650025c..4ed66861bb7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,8 +402,10 @@ class CodegenContext { * * equivalentExpressions will match the tree containing 
`col1 + col2` and it will only * be evaluated once. + * + * Visible for testing. */ - private val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + private[expressions] val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions // Foreach expression that is participating in subexpression elimination, the state to use. // Visible for testing. @@ -829,7 +831,7 @@ class CodegenContext { if (INPUT_ROW == null || currentVars != null) { expressions.mkString("\n") } else { - val structuralSubExpressionsArgs = subExprEliminationParameters.map("int" -> _.parameter) + val structuralSubExpressionsArgs = subExprEliminationParameters.map("int" -> _.ordinalParam) splitExpressions( expressions, funcName, @@ -1068,9 +1070,12 @@ class CodegenContext { * Among the set, the expressions with same semantics are replaced with a function call to the * generated function, by passing in input slots of current processing row. */ - private def structuralSubexpressionElimination(expressions: Seq[Expression]): Unit = { + private def structuralSubexpressionElimination(expressions: Seq[Expression]): Seq[Expression] = { // Add each expression tree and compute the structurally common subexpressions. - expressions.foreach(equivalentExpressions.addStructuralExprTree(this, _)) + // Those expressions are not added into structurally common subexpressions, defer them to + // semantically common subexpression. 
+ val exprsOut = + expressions.filterNot(equivalentExpressions.addStructuralExprTree(this, _)) val structuralExprs = equivalentExpressions.getAllStructuralExpressions @@ -1089,7 +1094,7 @@ class CodegenContext { case b: ParameterizedBoundReference => b } val resultIndex = freshName("resultIndex") - val parameterString = (parameters.map(p => s"int ${p.parameter}") ++ + val parameterString = (parameters.map(p => s"int ${p.ordinalParam}") ++ Seq(s"InternalRow $INPUT_ROW", s"int $resultIndex")).mkString(", ") val fnName = freshName("subExpr") @@ -1140,6 +1145,8 @@ class CodegenContext { } } subExprEliminationParameters = Seq.empty + + exprsOut } private def semanticSubexpressionElimination(expressions: Seq[Expression]): Unit = { @@ -1194,11 +1201,12 @@ class CodegenContext { * the mapping of common subexpressions to the generated functions. */ private def subexpressionElimination(expressions: Seq[Expression]): Unit = { - if (SQLConf.get.structuralSubexpressionEliminationEnabled) { + val exprs = if (SQLConf.get.structuralSubexpressionEliminationEnabled) { structuralSubexpressionElimination(expressions) } else { - semanticSubexpressionElimination(expressions) + expressions } + semanticSubexpressionElimination(exprs) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala new file mode 100644 index 000000000000..c5110907e6c5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ParameterizedBoundReference.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.DataType + +/** + * This bound reference points to a parameterized slot in an input tuple. It is used in + * common sub-expression elimination. When some common sub-expressions have the same structure
+ */ +case class ParameterizedBoundReference(ordinalParam: String, dataType: DataType, nullable: Boolean) + extends LeafExpression { + + override def toString: String = s"input[$ordinalParam, ${dataType.simpleString}, $nullable]" + + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException( + "ParameterizedBoundReference does not implement eval") + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + assert(ctx.currentVars == null && ctx.INPUT_ROW != null, + "ParameterizedBoundReference can not be used in whole-stage codegen yet.") + val javaType = JavaCode.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinalParam) + if (nullable) { + ev.copy(code = + code""" + |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinalParam); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); + """.stripMargin) + } else { + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala index 15fc15f7cbf6..953274130529 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.types.{DataType, IntegerType} class StructuralSubexpressionEliminationSuite extends SparkFunSuite { private val ctx = new CodegenContext @@ -109,4 +108,14 @@ class StructuralSubexpressionEliminationSuite extends 
SparkFunSuite { assert(equivalence.getAllStructuralExpressions.values .map(_.values.count(_.size == 1)).sum == 2) } + + test("Expressions not for structural expr elimination can go non-structural mode") { + val fallback1 = CodegenFallbackExpression(Literal(1)) + val fallback2 = CodegenFallbackExpression(Literal(1)) + + val ctx = new CodegenContext() + ctx.generateExpressions(Seq(fallback1, fallback2), doSubexpressionElimination = true) + assert(ctx.equivalentExpressions.getAllStructuralExpressions.isEmpty) + assert(ctx.equivalentExpressions.getEquivalentExprs(fallback1).length == 2) + } } From f44704222e270048ba5abf0ac1224d1182b8e31f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 9 Sep 2019 21:54:42 -0700 Subject: [PATCH 4/6] Add few comment. --- .../spark/sql/catalyst/expressions/EquivalentExpressions.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 07241a54737b..13189ffacf92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -65,6 +65,9 @@ class EquivalentExpressions { private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]] // For each expression, the set of structurally equivalent expressions. + // Among expressions with the same structure, there are different subsets of expressions + // which are semantically different from each other. Thus, under each key, the value is + // the map structure used to do semantic sub-expression elimination.
private val structEquivalenceMap = mutable.HashMap.empty[StructuralExpr, EquivalenceMap] /** From 71e02390d67d9671c4e3dcac586b9788d0ecd050 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Sep 2019 15:44:27 -0700 Subject: [PATCH 5/6] Try again to address comment. --- .../expressions/EquivalentExpressions.scala | 33 ++++-------- .../expressions/codegen/CodeGenerator.scala | 15 ++++-- ...ucturalSubexpressionEliminationSuite.scala | 51 +++++++++---------- 3 files changed, 43 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 13189ffacf92..b50d6f97357f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -46,8 +46,8 @@ class EquivalentExpressions { case class StructuralExpr(e: Expression) { def normalized(expr: Expression): Expression = { expr.transformUp { - case b: ParameterizedBoundReference => - b.copy(ordinalParam = "") + case b: BoundReference => + b.copy(ordinal = -1) } } override def equals(o: Any): Boolean = o match { @@ -96,7 +96,7 @@ class EquivalentExpressions { * structurally equivalent expressions. Non-recursive. Returns false if this doesn't add input * expression actually. */ - def addStructExpr(ctx: CodegenContext, expr: Expression): Boolean = { + def addStructExpr(expr: Expression): Boolean = { if (expr.deterministic) { // For structural equivalent expressions, we need to pass in int type ordinals into // split functions. 
If the number of ordinals is more than JVM function limit, we skip @@ -108,9 +108,7 @@ class EquivalentExpressions { } val parameterLength = CodeGenerator.calculateParamLength(refs) + 2 if (CodeGenerator.isValidParamLength(parameterLength)) { - val parameterizedExpr = parameterizedBoundReferences(ctx, expr) - - val e: StructuralExpr = StructuralExpr(parameterizedExpr) + val e: StructuralExpr = StructuralExpr(expr) val f = structEquivalenceMap.get(e) if (f.isDefined) { addExpr(expr, f.get) @@ -128,17 +126,6 @@ class EquivalentExpressions { } } - /** - * Replaces bound references in given expression by parameterized bound references. - */ - private def parameterizedBoundReferences(ctx: CodegenContext, expr: Expression): Expression = { - expr.transformUp { - case b: BoundReference => - val param = ctx.freshName("ordinal") - ParameterizedBoundReference(param, b.dataType, b.nullable) - } - } - /** * Checks if we skip add sub-expressions for given expression. */ @@ -187,11 +174,11 @@ class EquivalentExpressions { * Adds the expression to structural data structure recursively. Returns false if this doesn't add * the input expression actually. 
*/ - def addStructuralExprTree(ctx: CodegenContext, expr: Expression): Boolean = { + def addStructuralExprTree(expr: Expression): Boolean = { val skip = skipExpr(expr) || expr.isInstanceOf[CodegenFallback] - if (!skip && addStructExpr(ctx, expr)) { - childrenToRecurse(expr).foreach(addStructuralExprTree(ctx, _)) + if (!skip && addStructExpr(expr)) { + childrenToRecurse(expr).foreach(addStructuralExprTree) true } else { false @@ -213,10 +200,8 @@ class EquivalentExpressions { equivalenceMap.values.map(_.toSeq).toSeq } - def getStructurallyEquivalentExprs(ctx: CodegenContext, e: Expression): Seq[Seq[Expression]] = { - val parameterizedExpr = parameterizedBoundReferences(ctx, e) - - val key = StructuralExpr(parameterizedExpr) + def getStructurallyEquivalentExprs(e: Expression): Seq[Seq[Expression]] = { + val key = StructuralExpr(e) structEquivalenceMap.get(key).map(_.values.map(_.toSeq).toSeq).getOrElse(Seq.empty) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4ed66861bb7e..3228ff83f505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1074,8 +1074,8 @@ class CodegenContext { // Add each expression tree and compute the structurally common subexpressions. // Those expressions are not added into structurally common subexpressions, defer them to // semantically common subexpression. 
- val exprsOut = - expressions.filterNot(equivalentExpressions.addStructuralExprTree(this, _)) + val nonApplicableExprs = + expressions.filterNot(equivalentExpressions.addStructuralExprTree) val structuralExprs = equivalentExpressions.getAllStructuralExpressions @@ -1090,7 +1090,12 @@ class CodegenContext { Some((key.e, exprGroups)) } }.foreach { case (expr, exprGroups) => - val parameters = expr.collect { + val parameterizedExpr = expr.transformUp { + case b: BoundReference => + val param = freshName("ordinal") + ParameterizedBoundReference(param, b.dataType, b.nullable) + } + val parameters = parameterizedExpr.collect { case b: ParameterizedBoundReference => b } val resultIndex = freshName("resultIndex") @@ -1117,7 +1122,7 @@ class CodegenContext { // Generate the code for this expression tree and wrap it in a function. // Sets the current parameters. subExprEliminationParameters = parameters - val eval = expr.genCode(this) + val eval = parameterizedExpr.genCode(this) val fn = s""" |private void $fnName($parameterString) { @@ -1146,7 +1151,7 @@ class CodegenContext { } subExprEliminationParameters = Seq.empty - exprsOut + nonApplicableExprs } private def semanticSubexpressionElimination(expressions: Seq[Expression]): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala index 953274130529..5adb66401a92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StructuralSubexpressionEliminationSuite.scala @@ -20,8 +20,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext class StructuralSubexpressionEliminationSuite extends SparkFunSuite { - private val ctx = 
new CodegenContext - test("Structurally Expression Equivalence") { val equivalence = new EquivalentExpressions assert(equivalence.getAllStructuralExpressions.isEmpty) @@ -29,29 +27,28 @@ class StructuralSubexpressionEliminationSuite extends SparkFunSuite { val oneA = Literal(1) val oneB = Literal(1) val twoA = Literal(2) - var twoB = Literal(2) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).isEmpty) - assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) + assert(equivalence.getStructurallyEquivalentExprs(oneA).isEmpty) + assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty) // Add oneA and test if it is returned. Since it is a group of one, it does not. - equivalence.addStructExpr(ctx, oneA) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).size == 1) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA)(0).size == 1) + equivalence.addStructExpr(oneA) + assert(equivalence.getStructurallyEquivalentExprs(oneA).size == 1) + assert(equivalence.getStructurallyEquivalentExprs(oneA)(0).size == 1) - assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) - equivalence.addStructExpr(ctx, oneA) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).size == 1) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA)(0).size == 2) + assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty) + equivalence.addStructExpr(oneA) + assert(equivalence.getStructurallyEquivalentExprs(oneA).size == 1) + assert(equivalence.getStructurallyEquivalentExprs(oneA)(0).size == 2) // Add B and make sure they can see each other. - equivalence.addStructExpr(ctx, oneB) + equivalence.addStructExpr(oneB) // Use exists and reference equality because of how equals is defined. 
- assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).flatten.exists(_ eq oneB)) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneA).flatten.exists(_ eq oneA)) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneB).flatten.exists(_ eq oneA)) - assert(equivalence.getStructurallyEquivalentExprs(ctx, oneB).flatten.exists(_ eq oneB)) - assert(equivalence.getStructurallyEquivalentExprs(ctx, twoA).isEmpty) + assert(equivalence.getStructurallyEquivalentExprs(oneA).flatten.exists(_ eq oneB)) + assert(equivalence.getStructurallyEquivalentExprs(oneA).flatten.exists(_ eq oneA)) + assert(equivalence.getStructurallyEquivalentExprs(oneB).flatten.exists(_ eq oneA)) + assert(equivalence.getStructurallyEquivalentExprs(oneB).flatten.exists(_ eq oneB)) + assert(equivalence.getStructurallyEquivalentExprs(twoA).isEmpty) assert(equivalence.getAllStructuralExpressions.size == 1) assert(equivalence.getAllStructuralExpressions.values.head.values.flatten.toSeq.size == 3) @@ -61,20 +58,20 @@ class StructuralSubexpressionEliminationSuite extends SparkFunSuite { val add1 = Add(oneA, oneB) val add2 = Add(oneA, oneB) - equivalence.addStructExpr(ctx, add1) - equivalence.addStructExpr(ctx, add2) + equivalence.addStructExpr(add1) + equivalence.addStructExpr(add2) assert(equivalence.getAllStructuralExpressions.size == 2) - assert(equivalence.getStructurallyEquivalentExprs(ctx, add2).flatten.exists(_ eq add1)) - assert(equivalence.getStructurallyEquivalentExprs(ctx, add2).flatten.size == 2) - assert(equivalence.getStructurallyEquivalentExprs(ctx, add1).flatten.exists(_ eq add2)) + assert(equivalence.getStructurallyEquivalentExprs(add2).flatten.exists(_ eq add1)) + assert(equivalence.getStructurallyEquivalentExprs(add2).flatten.size == 2) + assert(equivalence.getStructurallyEquivalentExprs(add1).flatten.exists(_ eq add2)) } test("Expression equivalence - non deterministic") { val sum = Add(Rand(0), Rand(0)) val equivalence = new EquivalentExpressions - 
equivalence.addStructExpr(ctx, sum) - equivalence.addStructExpr(ctx, sum) + equivalence.addStructExpr(sum) + equivalence.addStructExpr(sum) assert(equivalence.getAllStructuralExpressions.isEmpty) } @@ -86,7 +83,7 @@ class StructuralSubexpressionEliminationSuite extends SparkFunSuite { val add2 = Add(add, fallback) val equivalence = new EquivalentExpressions - equivalence.addStructuralExprTree(ctx, add2) + equivalence.addStructuralExprTree(add2) // `fallback` and the `add` inside should not be added assert(equivalence.getAllStructuralExpressions.values .map(_.values.count(_.size > 1)).sum == 0) @@ -100,7 +97,7 @@ class StructuralSubexpressionEliminationSuite extends SparkFunSuite { val ifExpr = If(condition, add, add) val equivalence = new EquivalentExpressions - equivalence.addStructuralExprTree(ctx, ifExpr) + equivalence.addStructuralExprTree(ifExpr) // the `add` inside `If` should not be added assert(equivalence.getAllStructuralExpressions.values .map(_.values.count(_.size > 1)).sum == 0) From 4700b89004380e48ed484311018686856df3027e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Sep 2019 13:28:35 -0700 Subject: [PATCH 6/6] Add doc to SortPrefix. --- .../spark/sql/catalyst/expressions/EquivalentExpressions.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index b50d6f97357f..cddb3416b084 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -147,6 +147,7 @@ class EquivalentExpressions { // will always get accessed. // 4. Coalesce: it's also a conditional expression, we should only recurse into the first // children, because others may not get accessed. + // 5. 
SortPrefix: skip the direct child of SortPrefix which is an unevaluable SortOrder. private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match { case _: CodegenFallback => Nil case i: If => i.predicate :: Nil