diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 269c8500f95c..99239b549e9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -434,6 +434,27 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
     case _ => false
   }
 
+  /**
+   * Returns true if a null check on `checkedExpr` is redundant given null-intolerant branches.
+   * @param ifNullExpr expression evaluated when `checkedExpr` is null
+   * @param ifNotNullExpr expression evaluated when `checkedExpr` is not null
+   * @param checkedExpr expression that is checked for a null value
+   */
+  private def isRedundantNullCheck(
+      ifNullExpr: Expression,
+      ifNotNullExpr: Expression,
+      checkedExpr: Expression): Boolean = {
+    val isNullIntolerant = ifNotNullExpr.find { x =>
+      !x.isInstanceOf[NullIntolerant] && x.find(e => e.semanticEquals(checkedExpr)).nonEmpty
+    }.isEmpty
+
+    isNullIntolerant && {
+      (ifNullExpr.semanticEquals(checkedExpr) ||
+        (ifNullExpr.foldable && ifNullExpr.eval() == null)) &&
+        ifNotNullExpr.find(x => x.semanticEquals(checkedExpr)).nonEmpty
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
@@ -442,6 +463,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
       case If(cond, trueValue, falseValue)
         if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue
 
+      case i @ If(cond, trueValue, falseValue) => cond match {
+        // If the null check is redundant, remove it
+        case IsNull(child)
+          if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
+        case IsNotNull(child)
+          if isRedundantNullCheck(falseValue, trueValue, child) => trueValue
+        case _ => i
+      }
+
       case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
         // If there are branches that are always false, remove them.
         // If there are no more branches left, just use the else value.
@@ -483,6 +513,21 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
         } else {
           e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
         }
+
+      // Remove redundant null checks for CaseWhen with a single branch
+      case CaseWhen(Seq((IsNotNull(child), trueValue)), Some(falseValue))
+        if isRedundantNullCheck(falseValue, trueValue, child) => trueValue
+      case CaseWhen(Seq((IsNull(child), trueValue)), Some(falseValue))
+        if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
+      case CaseWhen(Seq((IsNotNull(child), trueValue)), None)
+        if isRedundantNullCheck(Literal.create(null, child.dataType), trueValue, child) => trueValue
+      case e @ CaseWhen(Seq((IsNull(child), trueValue)), None) =>
+        val nullValue = Literal.create(null, child.dataType)
+        if (isRedundantNullCheck(trueValue, nullValue, child)) {
+          nullValue
+        } else {
+          e
+        }
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 8ad7c12020b8..a6482591c2be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.{IntegerType, NullType}
 
 
@@ -35,8 +36,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   }
 
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
-    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
-    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
+    val correctAnswer = Project(Alias(e2, "out")() :: Nil, LocalRelation('a.int)).analyze
+    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, LocalRelation('a.int)).analyze)
     comparePlans(actual, correctAnswer)
   }
 
@@ -45,7 +46,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
   private val unreachableBranch = (FalseLiteral, Literal(20))
   private val nullBranch = (Literal.create(null, NullType), Literal(30))
 
-  val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
+  private val nullValue = Literal.create(null, IntegerType)
+  private val colA = UnresolvedAttribute(Seq("a"))
+  private val nullIntolerantExp = Abs(colA)
+  private val nullTolerantExp = Coalesce(Seq(colA, Literal(5)))
+
+  val isNullCondA = IsNull(colA)
+  val isNotNullCond = IsNotNull(colA)
   val isNullCond = IsNull(UnresolvedAttribute("b"))
   val notCond = Not(UnresolvedAttribute("c"))
 
@@ -80,6 +87,74 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
         Literal(9)))
   }
 
+  test("remove redundant null-check for If based on null-intolerant expressions") {
+    assertEquivalent(
+      If(isNullCondA, nullValue, nullIntolerantExp),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      If(isNullCondA, colA, nullIntolerantExp),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      If(isNotNullCond, nullIntolerantExp, nullValue),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      If(isNotNullCond, nullIntolerantExp, colA),
+      nullIntolerantExp)
+
+    // Also try a more complex case
+    assertEquivalent(
+      If(isNotNullCond, Abs(nullIntolerantExp), colA),
+      Abs(nullIntolerantExp))
+
+    // We do not remove the null check if the expression is not null-intolerant
+    assertEquivalent(
+      If(isNullCondA, nullValue, nullTolerantExp),
+      If(isNullCondA, nullValue, nullTolerantExp))
+
+    assertEquivalent(
+      If(isNotNullCond, nullTolerantExp, nullValue),
+      If(isNotNullCond, nullTolerantExp, nullValue))
+
+    // Also try a more complex case
+    assertEquivalent(
+      If(isNotNullCond, Abs(nullTolerantExp), nullValue),
+      If(isNotNullCond, Abs(nullTolerantExp), nullValue))
+  }
+
+  test("remove redundant null-check for CaseWhen based on null-intolerant expressions") {
+    assertEquivalent(
+      CaseWhen(Seq((isNullCondA, nullValue)), Some(nullIntolerantExp)),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      CaseWhen(Seq((isNullCondA, colA)), Some(nullIntolerantExp)),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      CaseWhen(Seq((isNotNullCond, nullIntolerantExp))),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      CaseWhen(Seq((isNotNullCond, nullIntolerantExp)), Some(colA)),
+      nullIntolerantExp)
+
+    assertEquivalent(
+      CaseWhen(Seq((isNotNullCond, nullIntolerantExp)), Some(nullValue)),
+      nullIntolerantExp)
+
+    // We do not remove the null check if the expression is not null-intolerant
+    assertEquivalent(
+      CaseWhen(Seq((isNotNullCond, nullTolerantExp))),
+      CaseWhen(Seq((isNotNullCond, nullTolerantExp))))
+
+    assertEquivalent(
+      CaseWhen(Seq((isNullCondA, nullValue)), Some(nullTolerantExp)),
+      CaseWhen(Seq((isNullCondA, nullValue)), Some(nullTolerantExp)))
+  }
+
   test("remove unreachable branches") {
     // i.e. removing branches whose conditions are always false
     assertEquivalent(