diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae5f7140847d..53c3b226895e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,13 +180,18 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the first successful condition is met or not. - // It is initialized to `false` and it is set to `true` when the first condition which - // evaluates to `true` is met and therefore is not needed to go on anymore on the computation - // of the following conditions. - val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + // This variable holds the state of the result: + // -1 means the condition is not met yet and the result is unknown. + val NOT_MATCHED = -1 + // 0 means the condition is met and result is not null. + val HAS_NONNULL = 0 + // 1 means the condition is met and result is null. + val HAS_NULL = 1 + // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, + // We won't go on anymore on the computation. + val resultState = ctx.freshName("caseWhenResultState") + val tmpResult = ctx.freshName("caseWhenTmpResult") + ctx.addMutableState(ctx.javaType(dataType), tmpResult) // these blocks are meant to be inside a // do { @@ -200,9 +205,8 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | ${ev.isNull} = ${res.isNull}; - | ${ev.value} = ${res.value}; - | $conditionMet = true; + | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + | $tmpResult = ${res.value}; | continue; |} """.stripMargin @@ -212,59 +216,63 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |${ev.isNull} = ${res.isNull}; - |${ev.value} = ${res.value}; + |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); + |$tmpResult = ${res.value}; """.stripMargin } val allConditions = cases ++ elseCode // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_1(i); + // if(caseWhenResultState != -1) { // continue; // } - // conditionMet = caseWhen_2(i); - // if(conditionMet) { + // caseWhenResultState = caseWhen_2(i); + // if(caseWhenResultState != -1) { // continue; // } // ... // and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; + // private byte caseWhen_1234() { + // byte caseWhenResultState = -1; // do { // // here the evaluation of the conditions // } while (false); - // return conditionMet; + // return caseWhenResultState; // } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BOOLEAN, + returnType = ctx.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BOOLEAN} $conditionMet = false; + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); - |return $conditionMet; + |return $resultState; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$conditionMet = $funcCall; - |if ($conditionMet) { + |$resultState = $funcCall; + |if ($resultState != $NOT_MATCHED) { | continue; |} """.stripMargin }.mkString) - ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.JAVA_BOOLEAN} $conditionMet = false; - do { - $codes - } while (false);""") + ev.copy(code = + s""" + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |$tmpResult = ${ctx.defaultValue(dataType)}; + |do { + | $codes + |} while (false); + |// TRUE if any condition is met and the result is null, or no any condition is met. + |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); + |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 26c9a41efc9f..294cdcb2e954 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + val tmpIsNull = ctx.freshName("coalesceTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -81,26 +81,30 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | ${ev.isNull} = false; + | $tmpIsNull = false; | ${ev.value} = ${eval.value}; | continue; |} """.stripMargin } + val resultType = ctx.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", + returnType = resultType, makeSplitFunction = func => s""" + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $func |} while (false); + |return ${ev.value}; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (!${ev.isNull}) { + |${ev.value} = $funcCall; + |if (!$tmpIsNull) { | continue; |} """.stripMargin @@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |${ev.isNull} = true; - |${ev.value} = ${ctx.defaultValue(dataType)}; + |$tmpIsNull = true; + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); + |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 04e669492ec6..7445b657f988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,8 +237,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + // inTmpResult has 3 possible values: + // -1 means no matches found and there is at least one value in the list evaluated to null + val HAS_NULL = -1 + // 0 means no matches found and all values in the list are not null + val NOT_MATCHED = 0 + // 1 means one value in the list is matched + val MATCHED = 1 + val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. @@ -246,10 +252,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | ${ev.isNull} = true; + | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; + | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; | continue; |} """.stripMargin) @@ -257,17 +262,19 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: Nil, + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, + returnType = ctx.JAVA_BYTE, makeSplitFunction = body => s""" |do { | $body |} while (false); + |return $tmpResult; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$funcCall; - |if (${ev.value}) { + |$tmpResult = $funcCall; + |if ($tmpResult == $MATCHED) { | continue; |} """.stripMargin @@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |${ev.value} = false; - |${ev.isNull} = ${valueGen.isNull}; - |if (!${ev.isNull}) { + |byte $tmpResult = $HAS_NULL; + |if (!${valueGen.isNull}) { + | $tmpResult = 0; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} + |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); + |final boolean ${ev.value} = ($tmpResult == $MATCHED); """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3e11c3d2d4fe..60d84aae1fa3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper IndexedSeq((Literal(12) === Literal(1), Literal(42)), (Literal(12) === Literal(42), Literal(1)))) } + + test("SPARK-22705: case when should use less global variables") { + val ctx = new CodegenContext() + CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 40ef7770da33..a23cd9563277 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(inputs), "x_1") } + test("SPARK-22705: Coalesce should use less global variables") { + val ctx = new CodegenContext() + Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } + test("AtLeastNNonNulls should not throw 64kb exception") { val inputs = (1 to 4000).map(x => Literal(s"x_$x")) checkEvaluation(AtLeastNNonNulls(1, inputs), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0079e4e8d6f7..c85d24dd245d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -245,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal(1.0D), sets), true) } + test("SPARK-22705: In should use less global variables") { + val ctx = new CodegenContext() + In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null