Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,18 @@ case class CaseWhen(
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// This variable represents whether the first successful condition is met or not.
// It is initialized to `false` and it is set to `true` when the first condition which
// evaluates to `true` is met and therefore is not needed to go on anymore on the computation
// of the following conditions.
val conditionMet = ctx.freshName("caseWhenConditionMet")
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ctx.addMutableState(ctx.javaType(dataType), ev.value)
// This variable holds the state of the result:
Copy link
Contributor

@cloud-fan cloud-fan Dec 7, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resultState holds the state of the result, which has 3 possible values:

// -1 means the condition is not met yet and the result is unknown.
val NOT_MATCHED = -1
// 0 means the condition is met and result is not null.
val HAS_NONNULL = 0
// 1 means the condition is met and result is null.
val HAS_NULL = 1
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
val tmpResult = ctx.freshName("caseWhenTmpResult")
ctx.addMutableState(ctx.javaType(dataType), tmpResult)

// these blocks are meant to be inside a
// do {
Expand All @@ -200,9 +205,8 @@ case class CaseWhen(
|${cond.code}
|if (!${cond.isNull} && ${cond.value}) {
| ${res.code}
| ${ev.isNull} = ${res.isNull};
| ${ev.value} = ${res.value};
| $conditionMet = true;
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
| $tmpResult = ${res.value};
| continue;
|}
""".stripMargin
Expand All @@ -212,59 +216,63 @@ case class CaseWhen(
val res = elseExpr.genCode(ctx)
s"""
|${res.code}
|${ev.isNull} = ${res.isNull};
|${ev.value} = ${res.value};
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
|$tmpResult = ${res.value};
""".stripMargin
}

val allConditions = cases ++ elseCode

// This generates code like:
// conditionMet = caseWhen_1(i);
// if(conditionMet) {
// caseWhenResultState = caseWhen_1(i);
// if(caseWhenResultState != -1) {
// continue;
// }
// conditionMet = caseWhen_2(i);
// if(conditionMet) {
// caseWhenResultState = caseWhen_2(i);
// if(caseWhenResultState != -1) {
// continue;
// }
// ...
// and the declared methods are:
// private boolean caseWhen_1234() {
// boolean conditionMet = false;
// private byte caseWhen_1234() {
// byte caseWhenResultState = -1;
// do {
// // here the evaluation of the conditions
// } while (false);
// return conditionMet;
// return caseWhenResultState;
// }
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
returnType = ctx.JAVA_BOOLEAN,
returnType = ctx.JAVA_BYTE,
makeSplitFunction = func =>
s"""
|${ctx.JAVA_BOOLEAN} $conditionMet = false;
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
|return $conditionMet;
|return $resultState;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$conditionMet = $funcCall;
|if ($conditionMet) {
|$resultState = $funcCall;
|if ($resultState != $NOT_MATCHED) {
| continue;
|}
""".stripMargin
}.mkString)

ev.copy(code = s"""
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
${ctx.JAVA_BOOLEAN} $conditionMet = false;
do {
$codes
} while (false);""")
ev.copy(code =
s"""
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|$tmpResult = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|// TRUE if any condition is met and the result is null, or no any condition is met.
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
""".stripMargin)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,35 +72,39 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ctx.addMutableState(ctx.javaType(dataType), ev.value)
val tmpIsNull = ctx.freshName("coalesceTmpIsNull")
ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)

// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
val eval = e.genCode(ctx)
s"""
|${eval.code}
|if (!${eval.isNull}) {
| ${ev.isNull} = false;
| $tmpIsNull = false;
| ${ev.value} = ${eval.value};
| continue;
|}
""".stripMargin
}

val resultType = ctx.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
returnType = resultType,
makeSplitFunction = func =>
s"""
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $func
|} while (false);
|return ${ev.value};
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (!${ev.isNull}) {
|${ev.value} = $funcCall;
|if (!$tmpIsNull) {
| continue;
|}
""".stripMargin
Expand All @@ -109,11 +113,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression {

ev.copy(code =
s"""
|${ev.isNull} = true;
|${ev.value} = ${ctx.defaultValue(dataType)};
|$tmpIsNull = true;
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|do {
| $codes
|} while (false);
|final boolean ${ev.isNull} = $tmpIsNull;
""".stripMargin)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,37 +237,44 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val javaDataType = ctx.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value)
ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
// inTmpResult has 3 possible values:
// -1 means no matches found and there is at least one value in the list evaluated to null
val HAS_NULL = -1
// 0 means no matches found and all values in the list are not null
val NOT_MATCHED = 0
// 1 means one value in the list is matched
val MATCHED = 1
val tmpResult = ctx.freshName("inTmpResult")
val valueArg = ctx.freshName("valueArg")
// All the blocks are meant to be inside a do { ... } while (false); loop.
// The evaluation of variables can be stopped when we find a matching value.
val listCode = listGen.map(x =>
s"""
|${x.code}
|if (${x.isNull}) {
| ${ev.isNull} = true;
| $tmpResult = $HAS_NULL; // ${ev.isNull} = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `// ${ev.isNull} = true;` comment does not look necessary to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it seems this was suggested by another reviewer. Then it is fine with me.

|} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

| continue;
|}
""".stripMargin)

val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
extraArguments = (javaDataType, valueArg) :: Nil,
extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't read tmpResult in the evaluation code, which indicates that it doesn't need to be a parameter; we can make it a local variable in the split method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tmpResult is read in the evaluation code when neither of the two conditions is satisfied (i.e. no assignment has been performed).

byte splitFunc_0(..., byte $tmpResult) {
  do {
    ${x.code}
    if (${x.isNull}) {
      $tmpResult = -1;
    } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
      $tmpResult = 1;
      continue;
    }
  } while (false);
  return $tmpResult;
}
...
$tmpResult = splitFunc_0(..., $tmpResult, ...);
if ($tmpResult == 1) { continue; }

Since `$tmpResult = $funcCall;` always overwrites $tmpResult, we may have to pass $tmpResult into the function to keep the previous value.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to keep the previous value? It's not read in the method. My proposal is:

byte splitFunc_0(...) {
  byte $tmpResult = 0;
  do {
    ${x.code}
    if (${x.isNull}) {
      $tmpResult = -1;
    } else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) {
      $tmpResult = 1;
      continue;
    }
  } while (false);
  return $tmpResult;
}
...
$tmpResult = splitFunc_0(...);
if ($tmpResult == 1) { continue; }

Copy link
Member Author

@kiszk kiszk Dec 6, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your proposal produces an incorrect result in this case:
In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType), Literal(2)))

The result should be null. This is because the evaluation of the first element (null) assigns -1 to $tmpResult, and the evaluation of the second element (2) does not change $tmpResult.
However, your proposal returns false, since `byte $tmpResult = 0;` always makes the result false unless one of the two conditions is satisfied.

WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it: tmpResult is actually read in the method, by `return $tmpResult;`. If we don't change tmpResult in the method, we need to return the previous value. Sorry for the trouble.

returnType = ctx.JAVA_BYTE,
makeSplitFunction = body =>
s"""
|do {
| $body
|} while (false);
|return $tmpResult;
""".stripMargin,
foldFunctions = _.map { funcCall =>
s"""
|$funcCall;
|if (${ev.value}) {
|$tmpResult = $funcCall;
|if ($tmpResult == $MATCHED) {
| continue;
|}
""".stripMargin
Expand All @@ -276,14 +283,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
ev.copy(code =
s"""
|${valueGen.code}
|${ev.value} = false;
|${ev.isNull} = ${valueGen.isNull};
|if (!${ev.isNull}) {
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
| $tmpResult = 0;
| $javaDataType $valueArg = ${valueGen.value};
| do {
| $codes
| } while (false);
|}
|final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL);
|final boolean ${ev.value} = ($tmpResult == $MATCHED);
""".stripMargin)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types._

class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
IndexedSeq((Literal(12) === Literal(1), Literal(42)),
(Literal(12) === Literal(42), Literal(1))))
}

test("SPARK-22705: case when should use less global variables") {
  // Generate code for a CASE WHEN with a single branch and an else value,
  // then verify that exactly one mutable state was registered on the context.
  val codegenCtx = new CodegenContext()
  val branches = Seq((Literal.create(false, BooleanType), Literal(1)))
  val caseWhen = CaseWhen(branches, Literal(-1))
  caseWhen.genCode(codegenCtx)
  assert(codegenCtx.mutableStates.size == 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Coalesce(inputs), "x_1")
}

test("SPARK-22705: Coalesce should use less global variables") {
  // Generate code for a two-argument Coalesce, then verify that exactly one
  // mutable state was registered on the context.
  val codegenCtx = new CodegenContext()
  val inputs = Seq(Literal("a"), Literal("b"))
  val coalesce = Coalesce(inputs)
  coalesce.genCode(codegenCtx)
  assert(codegenCtx.mutableStates.size == 1)
}

test("AtLeastNNonNulls should not throw 64kb exception") {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -245,6 +246,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal(1.0D), sets), true)
}

test("SPARK-22705: In should use less global variables") {
  // Generate code for an IN predicate over two literals, then verify that no
  // mutable state was registered on the context.
  val codegenCtx = new CodegenContext()
  val candidates = Seq(Literal(1.0D), Literal(2.0D))
  val inExpr = In(Literal(1.0D), candidates)
  inExpr.genCode(codegenCtx)
  assert(codegenCtx.mutableStates.isEmpty)
}

test("INSET") {
val hS = HashSet[Any]() + 1 + 2
val nS = HashSet[Any]() + 1 + 2 + null
Expand Down