Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,10 @@ class CodegenContext extends Logging {
// The collection of sub-expression result resetting methods that need to be called on each row.
private val subexprFunctions = mutable.ArrayBuffer.empty[String]

// The collection of calls that reset lazily evaluated sub-expressions. When sub-expressions
// are evaluated lazily, these calls should be invoked after the sub-expressions are processed.
private val subexprResetFunctions = mutable.ArrayBuffer.empty[String]

val outerClassName = "OuterClass"

/**
Expand Down Expand Up @@ -1012,6 +1016,14 @@ class CodegenContext extends Logging {
splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW))
}

/**
 * Builds the code that invokes every registered subexpression reset function, splitting
 * it if necessary.
 */
def subexprResetFunctionCode: String = {
  val resetCalls = subexprResetFunctions.toSeq
  // Reset calls are only expected when no current variables are set.
  assert(currentVars == null || resetCalls.isEmpty)
  splitExpressions(resetCalls, "subexprResetFunc_split", Seq())
}

/**
* Perform a function which generates a sequence of ExprCodes with a given mapping between
* expressions and common expressions, instead of using the mapping in current context.
Expand Down Expand Up @@ -1136,7 +1148,9 @@ class CodegenContext extends Logging {
* common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]): Unit = {
private def subexpressionElimination(
expressions: Seq[Expression],
lazyEvaluation: Boolean = false): Unit = {
// Add each expression tree and compute the common subexpressions.
expressions.foreach(equivalentExpressions.addExprTree(_))

Expand All @@ -1145,40 +1159,107 @@ class CodegenContext extends Logging {
val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("subExpr")
val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
val value = addMutableState(javaType(expr.dataType), "subExprValue")

// Generate the code for this expression tree and wrap it in a function.
val eval = expr.genCode(this)
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}

val subExprValue = addMutableState(javaType(expr.dataType), "subExprValue")
val subExprValueIsNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
val evalSubExprValueFnName = freshName("evalSubExprValue")

if (!lazyEvaluation) {
val fn =
s"""
|private void $evalSubExprValueFnName(InternalRow $INPUT_ROW) {
| ${eval.code}
| $subExprValueIsNull = ${eval.isNull};
| $subExprValue = ${eval.value};
|}
""".stripMargin

// Add a state and a mapping of the common subexpressions that are associated with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
//
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(
JavaCode.isNullGlobal(isNull),
JavaCode.global(value, expr.dataType))
subExprEliminationExprs ++= e.map(_ -> state).toMap
// Add a state and a mapping of the common subexpressions that are associated with this
// state. Adding this expression to subExprEliminationExprMap means it will call `fn`
// when it is code generated. This decision should be a cost based one.
//
// The cost of doing subexpression elimination is:
// 1. Extra function call, although this is probably *good* as the JIT can decide to
// inline or not.
// The benefit doing subexpression elimination is:
// 1. Running the expression logic. Even for a simple expression, it is likely more than 3
// above.
// 2. Less code.
// Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with
// at least two nodes) as the cost of doing it is expected to be low.

subexprFunctions += s"${addNewFunction(evalSubExprValueFnName, fn)}($INPUT_ROW);"
val state = SubExprEliminationState(
JavaCode.isNullGlobal(subExprValueIsNull),
JavaCode.global(subExprValue, expr.dataType))
subExprEliminationExprs ++= e.map(_ -> state).toMap
} else {

// the variable to check if a subexpression is evaluated or not.
val isSubExprEval = addMutableState(JAVA_BOOLEAN, "isSubExprEval")

val evalSubExprValueFnName = freshName("evalSubExprValue")
val evalSubExprValueFn =
s"""
|private void $evalSubExprValueFnName(InternalRow ${INPUT_ROW}) {
| ${eval.code}
| $subExprValueIsNull = ${eval.isNull};
| $subExprValue = ${eval.value};
| $isSubExprEval = true;
|}
|""".stripMargin

val splitEvalSubExprValueFnName = splitExpressions(Seq(
s"${addNewFunction(evalSubExprValueFnName, evalSubExprValueFn)}($INPUT_ROW)"),
s"${evalSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW))

val getSubExprValueFnName = freshName("getSubExprValue")
val getSubExprValueFn =
s"""
|private ${boxedType(expr.dataType)} $getSubExprValueFnName(InternalRow ${INPUT_ROW}) {
| if (!$isSubExprEval) {
| $splitEvalSubExprValueFnName;
| }
| return ${subExprValue};
|}
|""".stripMargin

val getSubExprValueIsNullFnName = freshName("getSubExprValueIsNull")
val getSubExprValueIsNullFn =
s"""
|private boolean ${getSubExprValueIsNullFnName}(InternalRow ${INPUT_ROW}) {
| if (!$isSubExprEval) {
| $splitEvalSubExprValueFnName;
| }
| return $subExprValueIsNull;
|}
|""".stripMargin

// The function that resets the subexpression's "evaluated" flag after processing.
val resetFnName = freshName("resetSubExpr")
val resetFn =
s"""
|private void $resetFnName() {
| $isSubExprEval = false;
|}
|""".stripMargin

subexprResetFunctions += s"${addNewFunction(resetFnName, resetFn)}();"

val splitIsNull = splitExpressions(Seq(
s"${addNewFunction(getSubExprValueIsNullFnName, getSubExprValueIsNullFn)}($INPUT_ROW)"),
s"${getSubExprValueIsNullFnName}_split", Seq("InternalRow" -> INPUT_ROW))
val splitValue = splitExpressions(
Seq(s"${addNewFunction(getSubExprValueFnName, getSubExprValueFn)}($INPUT_ROW)"),
s"${getSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW))
val state = SubExprEliminationState(
JavaCode.isNullGlobal(splitIsNull),
JavaCode.global(splitValue, expr.dataType))

subExprEliminationExprs ++= e.map(_ -> state).toMap
}
}
}

Expand All @@ -1189,8 +1270,9 @@ class CodegenContext extends Logging {
*/
def generateExpressions(
expressions: Seq[Expression],
doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
if (doSubexpressionElimination) subexpressionElimination(expressions)
doSubexpressionElimination: Boolean = false,
lazyEvalSubexpression: Boolean = false): Seq[ExprCode] = {
if (doSubexpressionElimination) subexpressionElimination(expressions, lazyEvalSubexpression)
expressions.map(e => e.genCode(this))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,31 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
expressions: Seq[Expression],
inputSchema: Seq[Attribute],
useSubexprElimination: Boolean): MutableProjection = {
create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination, false)
}

def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = {
create(canonicalize(expressions), useSubexprElimination)
def generate(
expressions: Seq[Expression],
useSubexprElimination: Boolean,
lazyEvaluation: Boolean = false): MutableProjection = {
create(canonicalize(expressions), useSubexprElimination, lazyEvaluation)
}

protected def create(expressions: Seq[Expression]): MutableProjection = {
create(expressions, false)
create(expressions, false, false)
}

private def create(
expressions: Seq[Expression],
useSubexprElimination: Boolean): MutableProjection = {
useSubexprElimination: Boolean,
lazyEvaluation: Boolean): MutableProjection = {
val ctx = newCodeGenContext()
val validExpr = expressions.zipWithIndex.filter {
case (NoOp, _) => false
case _ => true
}
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination,
lazyEvaluation)

// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
val ctx = newCodeGenContext()

// Do sub-expression elimination for predicates.
val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head
val evalSubexpr = ctx.subexprFunctionsCode
val eval =
ctx.generateExpressions(Seq(predicate), useSubexprElimination, useSubexprElimination).head
val subExprReset = ctx.subexprResetFunctionCode

val codeBody = s"""
public SpecificPredicate generate(Object[] references) {
Expand All @@ -60,9 +61,10 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
}

public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
$evalSubexpr
${eval.code}
return !${eval.isNull} && ${eval.value};
boolean result = !${eval.isNull} && ${eval.value};
$subExprReset;
return result;
}

${ctx.declareAddedFunctions()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
def createCode(
ctx: CodegenContext,
expressions: Seq[Expression],
useSubexprElimination: Boolean = false): ExprCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
useSubexprElimination: Boolean = false,
lazyEvaluation: Boolean = false): ExprCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination, lazyEvaluation)
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))

val numVarLenFields = exprSchemas.count {
Expand Down Expand Up @@ -323,19 +324,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

def generate(
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
create(canonicalize(expressions), subexpressionEliminationEnabled)
subexpressionEliminationEnabled: Boolean,
lazyEvaluation: Boolean = false): UnsafeProjection = {
create(canonicalize(expressions), subexpressionEliminationEnabled, lazyEvaluation)
}

protected def create(references: Seq[Expression]): UnsafeProjection = {
create(references, subexpressionEliminationEnabled = false)
create(references, subexpressionEliminationEnabled = false, lazyEvaluation = false)
}

private def create(
expressions: Seq[Expression],
subexpressionEliminationEnabled: Boolean): UnsafeProjection = {
subexpressionEliminationEnabled: Boolean,
lazyEvaluation: Boolean): UnsafeProjection = {
val ctx = newCodeGenContext()
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled)
val eval = createCode(ctx, expressions, subexpressionEliminationEnabled, false)

val codeBody =
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
Add(BoundReference(colIndex, DoubleType, true),
BoundReference(numOfExprs + colIndex, DoubleType, true))))
// these should not fail to compile due to 64K limit
GenerateUnsafeProjection.generate(exprs, true)
GenerateMutableProjection.generate(exprs, true)
GenerateUnsafeProjection.generate(exprs, true, false)
GenerateMutableProjection.generate(exprs, true, false)
GenerateUnsafeProjection.generate(exprs, true, true)
GenerateMutableProjection.generate(exprs, true, true)
}

test("SPARK-32624: Use CodeGenerator.typeName() to fix byte[] compile issue") {
Expand Down
29 changes: 29 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2907,6 +2907,35 @@ class DataFrameSuite extends QueryTest
}
}
}

test("SPARK-35688: subexpressions should be lazy evaluation in GeneratePredicate") {
  // (subexpression elimination, whole-stage codegen) flag pairs, in the order they are run.
  val flagCombinations = Seq(
    "true" -> "false",
    "false" -> "true",
    "false" -> "false",
    "true" -> "true")

  withTempPath { dir =>
    flagCombinations.foreach { case (subExprEliminationEnabled, codegenEnabled) =>
      withSQLConf(
        SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled,
        SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled,
        "spark.sql.ansi.enabled" -> "true") {
        // Two rows: a ten-element array and a five-element array.
        val input = Seq((1 to 10).toArray, (1 to 5).toArray).toDF("c1")
        input.write.mode("overwrite").save(dir.getCanonicalPath)

        // NOTE(review): presumably, with ANSI mode on, element_at(c1, 7) fails on the
        // five-element row unless `size(c1) > 5` is evaluated first, so the repeated
        // element_at subexpression must not be evaluated eagerly -- confirm.
        val filtered = spark.read
          .load(dir.getCanonicalPath)
          .filter("size(c1) > 5 and (element_at(c1, 7) = 8 or element_at(c1, 7) = 7)")

        // Only the ten-element row is expected to pass the filter.
        checkAnswer(filtered, Row((1 to 10).toArray) :: Nil)
      }
    }
  }
}
}

case class GroupByKey(a: Int, b: Int)
Expand Down