Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If}
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, Expression, If, LambdaFunction, Literal, MapFilter, Or}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.BooleanType
Expand Down Expand Up @@ -56,6 +55,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case p: LogicalPlan => p transformExpressions {
// For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no
// difference, as `null <=> true` and `false <=> true` both return false.
case EqualNullSafe(left, TrueLiteral) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change is needed here to not fail the existing tests. When we convert If to EqualNullSafe, we should make sure the rule ReplaceNullWithFalseInPredicate still applies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if someone can pull this out and open a separated PR/JIRA. Otherwise, I'll do it tomorrow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NullPropagation should optimize EqualNullSafe(null, TrueLiteral) to false too, isn't?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can, but not with And, e.g. EqualNullSafe(col && null, true) => EqualNullSafe(col && false, true) => EqualNullSafe(false, true) => false.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. Makes sense.

EqualNullSafe(replaceNullWithFalse(left), TrueLiteral)
case EqualNullSafe(TrueLiteral, right) =>
EqualNullSafe(TrueLiteral, replaceNullWithFalse(right))
case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
case cw @ CaseWhen(branches, _) =>
val newBranches = branches.map { case (cond, value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -475,8 +475,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
case If(Literal(null, _), _, falseValue) => falseValue
case If(cond, TrueLiteral, FalseLiteral) => cond
case If(cond, FalseLiteral, TrueLiteral) => Not(cond)
case If(cond, TrueLiteral, FalseLiteral) =>
if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond
case If(cond, FalseLiteral, TrueLiteral) =>
if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond)
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite

test("Push down EqualTo through If") {
assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral)
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a))
assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a <=> TrueLiteral))

// Push down at most one not foldable expressions.
assertEquivalent(
Expand Down Expand Up @@ -102,7 +102,7 @@ class PushFoldableIntoBranchesSuite
assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3)))
assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)),
If(a, Literal(2.0), Literal(3.0)))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a))
assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a <=> TrueLiteral))
assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable}
Expand Down Expand Up @@ -237,8 +237,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
TrueLiteral,
FalseLiteral)
val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue))
val expectedCond =
CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen)))
val expectedCond = CaseWhen(Seq(
(UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral)))
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand All @@ -253,10 +253,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
Literal(3)),
TrueLiteral,
FalseLiteral)
val expectedCond = Literal(5) > If(
val expectedCond = (Literal(5) > If(
UnresolvedAttribute("i") === Literal(15),
Literal(null, IntegerType),
Literal(3))
Literal(3))) <=> TrueLiteral
testFilter(originalCond = condition, expectedCond = expectedCond)
testJoin(originalCond = condition, expectedCond = expectedCond)
testDelete(originalCond = condition, expectedCond = expectedCond)
Expand Down Expand Up @@ -443,9 +443,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
val lambda1 = LambdaFunction(
function = If(cond, Literal(null, BooleanType), TrueLiteral),
arguments = lambdaArgs)
// the optimized lambda body is: if(arg > 0, false, true) => arg <= 0
// the optimized lambda body is: if(arg > 0, false, true) => !((arg > 0) <=> true)
val lambda2 = LambdaFunction(
function = LessThanOrEqual(condArg, Literal(0)),
function = !(cond <=> TrueLiteral),
arguments = lambdaArgs)
testProjection(
originalExpr = createExpr(argument, lambda1) as 'x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,39 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P
}

test("SPARK-33845: remove unnecessary if when the outputs are boolean type") {
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
IsNotNull(UnresolvedAttribute("a")))
assertEquivalent(
If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
IsNull(UnresolvedAttribute("a")))
// verify the boolean equivalence of all transformations involved
val fields = Seq(
'cond.boolean.notNull,
'cond_nullable.boolean,
'a.boolean,
'b.boolean
)
val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) }

val exprs = Seq(
// actual expressions of the transformations: original -> transformed
If(cond, true, false) -> cond,
If(cond, false, true) -> !cond,
If(cond_nullable, true, false) -> (cond_nullable <=> true),
If(cond_nullable, false, true) -> (!(cond_nullable <=> true)))

// check plans
for ((originalExpr, expectedExpr) <- exprs) {
assertEquivalent(originalExpr, expectedExpr)
}

assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral),
GreaterThan(Rand(0), UnresolvedAttribute("a")))
assertEquivalent(
If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral),
LessThanOrEqual(Rand(0), UnresolvedAttribute("a")))
// check evaluation
val binaryBooleanValues = Seq(true, false)
val ternaryBooleanValues = Seq(true, false, null)
for (condVal <- binaryBooleanValues;
condNullableVal <- ternaryBooleanValues;
aVal <- ternaryBooleanValues;
bVal <- ternaryBooleanValues;
(originalExpr, expectedExpr) <- exprs) {
val inputRow = create_row(condVal, condNullableVal, aVal, bVal)
val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
checkEvaluation(originalExpr, optimizedVal, inputRow)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check result here to make sure there is no correctness issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

}
}

test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {
Expand Down