diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9831b13ea754..1dc18b3a130c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -417,6 +417,10 @@ class CodegenContext extends Logging { // The collection of sub-expression result resetting methods that need to be called on each row. private val subexprFunctions = mutable.ArrayBuffer.empty[String] + // The collection of reset sub-expression, in lazy evaluation sub-expression, we should invoke + // after processing sub-expression. + private val subexprResetFunctions = mutable.ArrayBuffer.empty[String] + val outerClassName = "OuterClass" /** @@ -1012,6 +1016,14 @@ class CodegenContext extends Logging { splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } + /** + * Returns the code for reset subexpression after splitting it if necessary. + */ + def subexprResetFunctionCode: String = { + assert(currentVars == null || subexprResetFunctions.isEmpty) + splitExpressions(subexprResetFunctions.toSeq, "subexprResetFunc_split", Seq()) + } + /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1136,7 +1148,9 @@ class CodegenContext extends Logging { * common subexpressions, generates the functions that evaluate those expressions and populates * the mapping of common subexpressions to the generated functions. 
*/ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + private def subexpressionElimination( + expressions: Seq[Expression], + lazyEvaluation: Boolean = false): Unit = { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -1145,40 +1159,107 @@ class CodegenContext extends Logging { val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) commonExprs.foreach { e => val expr = e.head - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} + + val subExprValue = addMutableState(javaType(expr.dataType), "subExprValue") + val subExprValueIsNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val evalSubExprValueFnName = freshName("evalSubExprValue") + + if (!lazyEvaluation) { + val fn = + s""" + |private void $evalSubExprValueFnName(InternalRow $INPUT_ROW) { + | ${eval.code} + | $subExprValueIsNull = ${eval.isNull}; + | $subExprValue = ${eval.value}; + |} """.stripMargin - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. 
- // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState( - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType)) - subExprEliminationExprs ++= e.map(_ -> state).toMap + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + subexprFunctions += s"${addNewFunction(evalSubExprValueFnName, fn)}($INPUT_ROW);" + val state = SubExprEliminationState( + JavaCode.isNullGlobal(subExprValueIsNull), + JavaCode.global(subExprValue, expr.dataType)) + subExprEliminationExprs ++= e.map(_ -> state).toMap + } else { + + // the variable to check if a subexpression is evaluated or not. 
+ val isSubExprEval = addMutableState(JAVA_BOOLEAN, "isSubExprEval") + + val evalSubExprValueFnName = freshName("evalSubExprValue") + val evalSubExprValueFn = + s""" + |private void $evalSubExprValueFnName(InternalRow ${INPUT_ROW}) { + | ${eval.code} + | $subExprValueIsNull = ${eval.isNull}; + | $subExprValue = ${eval.value}; + | $isSubExprEval = true; + |} + |""".stripMargin + + val splitEvalSubExprValueFnName = splitExpressions(Seq( + s"${addNewFunction(evalSubExprValueFnName, evalSubExprValueFn)}($INPUT_ROW)"), + s"${evalSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + + val getSubExprValueFnName = freshName("getSubExprValue") + val getSubExprValueFn = + s""" + |private ${boxedType(expr.dataType)} $getSubExprValueFnName(InternalRow ${INPUT_ROW}) { + | if (!$isSubExprEval) { + | $splitEvalSubExprValueFnName; + | } + | return ${subExprValue}; + |} + |""".stripMargin + + val getSubExprValueIsNullFnName = freshName("getSubExprValueIsNull") + val getSubExprValueIsNullFn = + s""" + |private boolean ${getSubExprValueIsNullFnName}(InternalRow ${INPUT_ROW}) { + | if (!$isSubExprEval) { + | $splitEvalSubExprValueFnName; + | } + | return $subExprValueIsNull; + |} + |""".stripMargin + + // the function for reset subexpression after processing. 
+ val resetFnName = freshName("resetSubExpr") + val resetFn = + s""" + |private void $resetFnName() { + | $isSubExprEval = false; + |} + |""".stripMargin + + subexprResetFunctions += s"${addNewFunction(resetFnName, resetFn)}();" + + val splitIsNull = splitExpressions(Seq( + s"${addNewFunction(getSubExprValueIsNullFnName, getSubExprValueIsNullFn)}($INPUT_ROW)"), + s"${getSubExprValueIsNullFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + val splitValue = splitExpressions( + Seq(s"${addNewFunction(getSubExprValueFnName, getSubExprValueFn)}($INPUT_ROW)"), + s"${getSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(splitIsNull), + JavaCode.global(splitValue, expr.dataType)) + + subExprEliminationExprs ++= e.map(_ -> state).toMap + } } } @@ -1189,8 +1270,9 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { - if (doSubexpressionElimination) subexpressionElimination(expressions) + doSubexpressionElimination: Boolean = false, + lazyEvalSubexpression: Boolean = false): Seq[ExprCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions, lazyEvalSubexpression) expressions.map(e => e.genCode(this)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101..743d265be923 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -42,26 +42,31 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], inputSchema: Seq[Attribute], 
useSubexprElimination: Boolean): MutableProjection = { - create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination, false) } - def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { - create(canonicalize(expressions), useSubexprElimination) + def generate( + expressions: Seq[Expression], + useSubexprElimination: Boolean, + lazyEvaluation: Boolean = false): MutableProjection = { + create(canonicalize(expressions), useSubexprElimination, lazyEvaluation) } protected def create(expressions: Seq[Expression]): MutableProjection = { - create(expressions, false) + create(expressions, false, false) } private def create( expressions: Seq[Expression], - useSubexprElimination: Boolean): MutableProjection = { + useSubexprElimination: Boolean, + lazyEvaluation: Boolean): MutableProjection = { val ctx = newCodeGenContext() val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination, + lazyEvaluation) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b..3d64ac73f726 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,8 +38,9 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = 
newCodeGenContext() // Do sub-expression elimination for predicates. - val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head - val evalSubexpr = ctx.subexprFunctionsCode + val eval = + ctx.generateExpressions(Seq(predicate), useSubexprElimination, useSubexprElimination).head + val subExprReset = ctx.subexprResetFunctionCode val codeBody = s""" public SpecificPredicate generate(Object[] references) { @@ -60,9 +61,10 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { - $evalSubexpr ${eval.code} - return !${eval.isNull} && ${eval.value}; + boolean result = !${eval.isNull} && ${eval.value}; + $subExprReset; + return result; } ${ctx.declareAddedFunctions()} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba1..ab4c438c62dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -286,8 +286,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode( ctx: CodegenContext, expressions: Seq[Expression], - useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + useSubexprElimination: Boolean = false, + lazyEvaluation: Boolean = false): ExprCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination, lazyEvaluation) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -323,19 +324,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def generate( 
      expressions: Seq[Expression],
-      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
-    create(canonicalize(expressions), subexpressionEliminationEnabled)
+      subexpressionEliminationEnabled: Boolean,
+      lazyEvaluation: Boolean = false): UnsafeProjection = {
+    create(canonicalize(expressions), subexpressionEliminationEnabled, lazyEvaluation)
   }
 
   protected def create(references: Seq[Expression]): UnsafeProjection = {
-    create(references, subexpressionEliminationEnabled = false)
+    create(references, subexpressionEliminationEnabled = false, lazyEvaluation = false)
   }
 
   private def create(
       expressions: Seq[Expression],
-      subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
+      subexpressionEliminationEnabled: Boolean,
+      lazyEvaluation: Boolean): UnsafeProjection = {
     val ctx = newCodeGenContext()
-    val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
+    val eval = createCode(ctx, expressions, subexpressionEliminationEnabled, lazyEvaluation)
 
     val codeBody =
       s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 44b6aa6b6271..fedb67876b53 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -535,8 +535,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       Add(BoundReference(colIndex, DoubleType, true),
         BoundReference(numOfExprs + colIndex, DoubleType, true))))
     // these should not fail to compile due to 64K limit
-    GenerateUnsafeProjection.generate(exprs, true)
-    GenerateMutableProjection.generate(exprs, true)
+    GenerateUnsafeProjection.generate(exprs, true, false)
+    GenerateMutableProjection.generate(exprs, true, false)
+    GenerateUnsafeProjection.generate(exprs, true, true)
+    
GenerateMutableProjection.generate(exprs, true, true) } test("SPARK-32624: Use CodeGenerator.typeName() to fix byte[] compile issue") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3e810a453377..d125ebc763ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2907,6 +2907,35 @@ class DataFrameSuite extends QueryTest } } } + + test("SPARK-35688: subexpressions should be lazy evaluation in GeneratePredicate") { + withTempPath { dir => + Seq( + ("true", "false"), + ("false", "true"), + ("false", "false"), + ("true", "true") + ).foreach { case (subExprEliminationEnabled, codegenEnabled) => + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled, + "spark.sql.ansi.enabled" -> "true") { + Seq( + (1 to 10).toArray, + (1 to 5).toArray + ).toDF("c1") + .write + .mode("overwrite") + .save(dir.getCanonicalPath) + val df = spark.read.load(dir.getCanonicalPath) + .filter("size(c1) > 5 and (element_at(c1, 7) = 8 or element_at(c1, 7) = 7)") + checkAnswer( + df, Row((1 to 10).toArray) :: Nil + ) + } + } + } + } } case class GroupByKey(a: Int, b: Int)