Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,27 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}

/**
* Condition for redundant null check based on intolerant expressions.
* @param ifNullExpr expression that takes place if checkedExpr is null
* @param ifNotNullExpr expression that takes place if checkedExpr is not null
* @param checkedExpr expression that is checked for null value
*/
private def isRedundantNullCheck(
ifNullExpr: Expression,
ifNotNullExpr: Expression,
checkedExpr: Expression): Boolean = {
val isNullIntolerant = ifNotNullExpr.find { x =>
!x.isInstanceOf[NullIntolerant] && x.find(e => e.semanticEquals(checkedExpr)).nonEmpty
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually i think we need slightly different logic. Consider these two examples where x will be the null-checked column:

  1. substring(x, coalesce(a, b), c)

  2. substring(coalesce(x, d), a, c)

For 1. we need to be null-intolerant (even though coalesce is null-tolerant), so if x is null, we replace the substring with null value no matter what are the other children. For 2. we need to be null-tolerant and we will not replace the substring by null value. So we need to check the expression with respect to the position of x (the column that is being null-checked). Does it make sense?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably, you meant FiterExec.isNullIntolerant(ifNotNullExpr) || additional checks for the case having null-tolerant exprs inside ifNotNullExpr? (FiterExec.isNullIntolerant is private though...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, FilterExec.isNullIntolerant(ifNotNullExpr) is a stronger condition than we need so in case there is null-tolerant expr inside we need to check if the null-checked column is in its subtree. Using the logic from FilterExec.isNullIntolerant the function could look like this:

def isNullIntolerant(expr: Expression): Boolean = expr match {
  case e: NullIntolerant => e.children.forall(isNullIntolerant)
  case e if e.find(x => x.semanticEquals(checkedExpr)).isEmpty => true
  case _ => false
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. For better code readability, could you split the condition into the two parts as I suggested above? Also, I think its better to leave some comments about why we need more checks there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the committed code is not very intuitive so i can think of this way which seems to be more readable (added also some comments):

private def isRedundantNullCheck(
    ifNullExpr: Expression,
    ifNotNullExpr: Expression,
    checkedExpr: Expression): Boolean = {
  
  // checks if expr is null-intolerant with respect to checkedExpr
  def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    // if some child is null-tolerant but the checkedEpxr is not in its subtree
    // we can still consider the whole expr as null-intolerant 
    // with respect to checkedExpr
    case e if e.find(x => x.semanticEquals(checkedExpr)).isEmpty => true
    case _ => false
  }
  
  isNullIntolerant(ifNotNullExpr) && {
    (ifNullExpr.semanticEquals(checkedExpr) ||
      (ifNullExpr.foldable && ifNullExpr.eval() == null)) &&
      // we still need to make sure that checkedExpr is inside ifNotNullExpr
      ifNotNullExpr.find(x => x.semanticEquals(checkedExpr)).nonEmpty
  }
}

But not sure if this is what you had in mind when suggesting to split the condition. Can you think of a better way how to compose this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, that still looks complicated.. If we cannot avoid the complexity for the stronger condition, as another option, I think we can cover the simple case (FiterExec.isNullIntolerant(ifNotNullExpr)) only in this pr. If necessary, we might be able to optimize the condition in future work. I think keeping the code simple is more important. WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, the think is that if we use the simple version with FilterExec.isNullIntolerant(ifNutNullExpr) we will loose (because of the recursive check) all expressions that contain literals (because literals are null-tolerant), so for example expressions like this substring(title#5, 0, 3) will not be included in the optimization (which the jira was targeted for in the first place). So I suggest one of these 2 options:

  1. Use the complex version of the code and thus include more expressions in the optimization
  2. Have the code more simple and use the original version before the generalization step, i.e.
private def isRedundantNullCheck(
    ifNullExpr: Expression,
    ifNotNullExpr: Expression,
    checkedExpr: Expression): Boolean = {
  ifNotNullExpr.isInstanceOf[NullIntolerant] && {
    (ifNullExpr == checkedExpr || ifNullExpr == Literal.create(null, checkedExpr.dataType)) &&
      ifNotNullExpr.children.contains(checkedExpr)
  }
}

where checkedExpr must be direct child and thus we don't have to check the whole subtree for null-intolerance (so expressions that have Literals in the subtree are still included).
I am fine with either of these 2 options. What do you think?

}.isEmpty

isNullIntolerant && {
(ifNullExpr.semanticEquals(checkedExpr) ||
(ifNullExpr.foldable && ifNullExpr.eval() == null)) &&
ifNotNullExpr.find(x => x.semanticEquals(checkedExpr)).nonEmpty
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: How about this style?

  private def isRedundantNullCheck(
      ifNullExpr: Expression,
      ifNotNullExpr: Expression,
      checkedExpr: Expression): Boolean = {
    ifNotNullExpr.isInstanceOf[NullIntolerant] && {
      (ifNullExpr == checkedExpr || ifNullExpr == Literal.create(null, checkedExpr.dataType)) &&
        ifNotNullExpr.children.contains(checkedExpr)
    }
  }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first condition ifNullExpr == checkedExpr -> ifNullExpr.semanticEquals(checkedExpr)? e.g., if isnull(a + b) b + a else xxx

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second condition ifNullExpr == Literal.create(null, checkedExpr.dataType) -> ifNullExpr.foldable && ifNullExpr.eval() == null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you generalize the last condition more? e.g., how about the case, substring(other_func(title#5), 0, 3) in the example you described?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that should be possible.


def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
Expand All @@ -442,6 +463,15 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
case If(cond, trueValue, falseValue)
if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue

case i @ If(cond, trueValue, falseValue) => cond match {
// If the null-check is redundant, remove it
case IsNull(child)
if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
case IsNotNull(child)
if isRedundantNullCheck(falseValue, trueValue, child) => trueValue
case _ => i
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this format?;

      // If the null-check is redundant, remove it
      case If(IsNull(child), trueValue, falseValue)
        if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
      case If(IsNotNull(child), trueValue, falseValue)
        if isRedundantNullCheck(falseValue, trueValue, child) => trueValue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you add the inner pattern-matching (cond match { )? I think its better to avoid unnecessary pattern matching (In the current fix, all the cases for If exprs can be matched in the line 466).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. You are right, i do not need the inner pattern match, i will fix that.


case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
Expand Down Expand Up @@ -483,6 +513,21 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
} else {
e.copy(branches = branches.take(i).map(branch => (branch._1, elseValue)))
}

// remove redundant null checks for CaseWhen with one branch
case CaseWhen(Seq((IsNotNull(child), trueValue)), Some(falseValue))
if isRedundantNullCheck(falseValue, trueValue, child) => trueValue
case CaseWhen(Seq((IsNull(child), trueValue)), Some(falseValue))
if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
case CaseWhen(Seq((IsNotNull(child), trueValue)), None)
if isRedundantNullCheck(Literal.create(null, child.dataType), trueValue, child) => trueValue
case e @ CaseWhen(Seq((IsNull(child), trueValue)), None) =>
val nullValue = Literal.create(null, child.dataType)
if (isRedundantNullCheck(trueValue, nullValue, child)) {
nullValue
} else {
e
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this?

      // remove redundant null checks for CaseWhen with one branch
      case CaseWhen(Seq((IsNotNull(child), trueValue)), Some(falseValue))
        if isRedundantNullCheck(falseValue, trueValue, child) => trueValue
      case CaseWhen(Seq((IsNull(child), trueValue)), Some(falseValue))
        if isRedundantNullCheck(trueValue, falseValue, child) => falseValue
      case CaseWhen(Seq((IsNotNull(child), trueValue)), None)
        if isRedundantNullCheck(Literal.create(null, child.dataType), trueValue, child) => trueValue
      case e @ CaseWhen(Seq((IsNull(child), trueValue)), None) =>
        val nullValue = Literal.create(null, child.dataType)
        if (isRedundantNullCheck(trueValue, nullValue, child)) {
          nullValue
        } else {
          e
        }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright

}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{IntegerType, NullType}


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You need to avoid unnecessary changes like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

Expand All @@ -35,8 +36,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
}

protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
val correctAnswer = Project(Alias(e2, "out")() :: Nil, LocalRelation('a.int)).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, LocalRelation('a.int)).analyze)
comparePlans(actual, correctAnswer)
}

Expand All @@ -45,7 +46,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val unreachableBranch = (FalseLiteral, Literal(20))
private val nullBranch = (Literal.create(null, NullType), Literal(30))

val isNotNullCond = IsNotNull(UnresolvedAttribute(Seq("a")))
private val nullValue = Literal.create(null, IntegerType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you change from NullType to IntegerType here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need the same dataType as i have for the a attribute. But i can just add another nullValue to the test and keep the previous with the original dataType.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I think its better to avoid the behaviour changes in the existing tests.

private val colA = UnresolvedAttribute(Seq("a"))
private val nullIntolerantExp = Abs(colA)
private val nullTolerantExp = Coalesce(Seq(colA, Literal(5)))

val isNullCondA = IsNull(colA)
val isNotNullCond = IsNotNull(colA)
val isNullCond = IsNull(UnresolvedAttribute("b"))
val notCond = Not(UnresolvedAttribute("c"))

Expand Down Expand Up @@ -80,6 +87,74 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
Literal(9)))
}

test("remove redundant null-check for If based on null-Intolerant expressions") {
assertEquivalent(
If(isNullCondA, nullValue, nullIntolerantExp),
nullIntolerantExp)

assertEquivalent(
If(isNullCondA, colA, nullIntolerantExp),
nullIntolerantExp)

assertEquivalent(
If(isNotNullCond, nullIntolerantExp, nullValue),
nullIntolerantExp)

assertEquivalent(
If(isNotNullCond, nullIntolerantExp, colA),
nullIntolerantExp)

// Try also more complex case
assertEquivalent(
If(isNotNullCond, Abs(nullIntolerantExp), colA),
Abs(nullIntolerantExp))

// We do not remove the null check if the expression is not null-intolerant
assertEquivalent(
If(isNullCondA, nullValue, nullTolerantExp),
If(isNullCondA, nullValue, nullTolerantExp))

assertEquivalent(
If(isNotNullCond, nullTolerantExp, nullValue),
If(isNotNullCond, nullTolerantExp, nullValue))

// Try also more complex case
assertEquivalent(
If(isNotNullCond, Abs(nullTolerantExp), nullValue),
If(isNotNullCond, Abs(nullTolerantExp), nullValue))
}

test("remove redundant null-check for CaseWhen based on null-Intolerant expressions") {
assertEquivalent(
CaseWhen(Seq((isNullCondA, nullValue)), Some(nullIntolerantExp)),
nullIntolerantExp)

assertEquivalent(
CaseWhen(Seq((isNullCondA, colA)), Some(nullIntolerantExp)),
nullIntolerantExp)

assertEquivalent(
CaseWhen(Seq((isNotNullCond, nullIntolerantExp))),
nullIntolerantExp)

assertEquivalent(
CaseWhen(Seq((isNotNullCond, nullIntolerantExp)), Some(colA)),
nullIntolerantExp)

assertEquivalent(
CaseWhen(Seq((isNotNullCond, nullIntolerantExp)), Some(nullValue)),
nullIntolerantExp)

// We do not remove the null check if the expression is not null-intolerant
assertEquivalent(
CaseWhen(Seq((isNotNullCond, nullTolerantExp))),
CaseWhen(Seq((isNotNullCond, nullTolerantExp))))

assertEquivalent(
CaseWhen(Seq((isNullCondA, nullValue)), Some(nullTolerantExp)),
CaseWhen(Seq((isNullCondA, nullValue)), Some(nullTolerantExp)))
}

test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
Expand Down