diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 06985ac85b70..47421090c3fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -119,6 +119,49 @@ trait PredicateHelper {
     case e: Unevaluable => false
     case e => e.children.forall(canEvaluateWithinJoin)
   }
+
+  /**
+   * Given an IsNotNull expression, returns the IDs of expressions whose not-nullness
+   * is implied by that IsNotNull expression.
+   */
+  protected def getImpliedNotNullExprIds(isNotNullExpr: IsNotNull): Set[ExprId] = {
+    // This logic is a little tricky, so we'll use an example to build some intuition.
+    // Consider the expression IsNotNull(f(g(x), y)). By definition, its child is not null:
+    //   f(g(x), y) is not null
+    // In addition, if `f` is NullIntolerant then it would be null if either child was null:
+    //   g(x) is null => f(g(x), y) is null
+    //   y is null => f(g(x), y) is null
+    // Via A => B <=> !A || B, we have:
+    //   g(x) is not null || f(g(x), y) is null
+    //   y is not null || f(g(x), y) is null
+    // Since we know that f(g(x), y) is not null, we must therefore conclude that
+    //   g(x) is not null
+    //   y is not null
+    // By recursively applying this logic, if g is NullIntolerant then x is not null.
+    // However, if g is NOT NullIntolerant (e.g. if g(null) is non-null) then we cannot
+    // conclude anything about x's nullability.
+    def getExprIdIfNamed(expr: Expression): Set[ExprId] = expr match {
+      case ne: NamedExpression => Set(ne.toAttribute.exprId)
+      case _ => Set.empty
+    }
+    def isNullIntolerant(expr: Expression): Boolean = expr match {
+      case _: NullIntolerant => true
+      case Alias(_: NullIntolerant, _) => true
+      case _ => false
+    }
+    // Recurse through the IsNotNull expression's descendants, stopping
+    // once we encounter a null-tolerant expression.
+    def getNotNullDescendants(expr: Expression): Set[ExprId] = {
+      expr.children.map { child =>
+        if (isNullIntolerant(child)) {
+          getExprIdIfNamed(child) ++ getNotNullDescendants(child)
+        } else {
+          getExprIdIfNamed(child)
+        }
+      }.foldLeft(Set.empty[ExprId])(_ ++ _)
+    }
+    getExprIdIfNamed(isNotNullExpr) ++ getNotNullDescendants(isNotNullExpr)
+  }
 }
 
 @ExpressionDescription(
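For extra intuition about the helper above, here is a minimal, self-contained sketch of the same per-branch propagation rule on a toy expression tree. This is illustrative only and not part of the patch: `Node`, `name`, and `nullIntolerant` are stand-ins for Catalyst's `NamedExpression` and `NullIntolerant`, with node names playing the role of `ExprId`s.

```scala
// Toy stand-in for an expression tree: `name` marks "named" nodes (those that
// would contribute an ExprId) and `nullIntolerant` marks nodes that are null
// whenever any child is null.
case class Node(name: Option[String], nullIntolerant: Boolean, children: Node*)

// Mirrors getImpliedNotNullExprIds(IsNotNull(child)): the direct child of
// IsNotNull is non-null by definition; we descend further only through
// null-intolerant operators.
def impliedNotNull(child: Node): Set[String] = {
  def named(n: Node): Set[String] = n.name.toSet
  def descend(n: Node): Set[String] =
    n.children.flatMap { c =>
      if (c.nullIntolerant) named(c) ++ descend(c) else named(c)
    }.toSet
  named(child) ++ (if (child.nullIntolerant) descend(child) else Set.empty[String])
}

val x = Node(Some("x"), true)        // leaf column
val y = Node(Some("y"), true)        // leaf column
val g = Node(Some("g(x)"), true, x)  // null-intolerant, like Add
val f = Node(None, true, g, y)       // null-intolerant, unnamed

// IsNotNull(f(g(x), y)) implies g(x), x, and y are all non-null:
assert(impliedNotNull(f) == Set("g(x)", "x", "y"))

// If g is null-tolerant (like Coalesce), the recursion stops there: g(x)
// itself is still non-null, but nothing is implied about x.
val gTolerant = Node(Some("g(x)"), false, x)
assert(impliedNotNull(Node(None, true, gTolerant, y)) == Set("g(x)", "y"))
```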
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 6bf12cff28f9..fabf8e2903bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -51,7 +51,14 @@ case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
     extends OrderPreservingUnaryNode {
-  override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+  override def output: Seq[Attribute] = {
+    // The child operator may have inferred more precise nullability information
+    // for the projected expressions, so leverage that information if it's available:
+    val childOutputNullability = child.output.map(a => a.exprId -> a.nullable).toMap
+    projectList
+      .map(_.toAttribute)
+      .map { a => childOutputNullability.get(a.exprId).map(a.withNullability).getOrElse(a) }
+  }
 
   override def maxRows: Option[Long] = child.maxRows
 
   override lazy val resolved: Boolean = {
@@ -129,7 +136,22 @@ case class Generate(
 
 case class Filter(condition: Expression, child: LogicalPlan)
   extends OrderPreservingUnaryNode with PredicateHelper {
-  override def output: Seq[Attribute] = child.output
+
+  override def output: Seq[Attribute] = {
+    val impliedNotNullExprIds: Set[ExprId] = {
+      splitConjunctivePredicates(condition)
+        .collect { case isNotNull: IsNotNull => isNotNull }
+        .map(getImpliedNotNullExprIds)
+        .foldLeft(Set.empty[ExprId])(_ ++ _)
+    }
+    child.output.map { a =>
+      if (a.nullable && impliedNotNullExprIds.contains(a.exprId)) {
+        a.withNullability(false)
+      } else {
+        a
+      }
+    }
+  }
 
   override def maxRows: Option[Long] = child.maxRows
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 49fd59c8694f..794623b6d161 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.types._
 
-class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PredicateHelper {
 
   def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
     testFunc(false, BooleanType)
@@ -175,4 +175,22 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
     checkEvaluation(AtLeastNNonNulls(1, inputs), true)
   }
+
+  test("getImpliedNotNullExprIds") {
+    val a = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+    val b = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+
+    // Simple case of IsNotNull of a leaf value:
+    assert(getImpliedNotNullExprIds(IsNotNull(a)) == Set(a.exprId))
+
+    // Even though we can't make claims about its children, a non-NullIntolerant
+    // expression is still considered non-null due to its parent IsNotNull expression:
+    val coalesceExpr = Alias(Coalesce(Seq(a, b)), "c")(exprId = ExprId(3))
+    assert(getImpliedNotNullExprIds(IsNotNull(coalesceExpr)) == Set(coalesceExpr.exprId))
+
+    // NullIntolerant expressions propagate the non-null constraint to all of their children:
+    val addExpr = Alias(Add(a, b), "add")(exprId = ExprId(4))
+    assert(addExpr.child.isInstanceOf[NullIntolerant])
+    assert(getImpliedNotNullExprIds(IsNotNull(addExpr)) == Set(a, b, addExpr).map(_.exprId))
+  }
 }
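Taken together, the two logical-plan changes mean a query's reported schema can tighten. A hedged spark-shell sketch of the expected effect (column names are illustrative, and the annotated output is what this patch should produce, not captured from a run):

```scala
// `a` and `b` start out nullable because of the NULL branches:
val df = spark.range(10).selectExpr(
  "IF(id % 2 = 0, id, NULL) AS a",
  "IF(id % 3 = 0, id, NULL) AS b")

// Add is NullIntolerant, so `a + b IS NOT NULL` implies both inputs are
// non-null; Filter should report that, and Project should carry it through:
df.where("a + b IS NOT NULL").select("a", "b").printSchema()
// root
//  |-- a: long (nullable = false)   <- was `true` before this patch
//  |-- b: long (nullable = false)   <- was `true` before this patch
```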
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 7204548181f6..12e7574e7728 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -89,26 +89,23 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+    case IsNotNull(_) => true
     case _ => false
   }
 
-  // If one expression and its children are null intolerant, it is null intolerant.
-  private def isNullIntolerant(expr: Expression): Boolean = expr match {
-    case e: NullIntolerant => e.children.forall(isNullIntolerant)
-    case _ => false
+  private val impliedNotNullExprIds: Set[ExprId] = {
+    notNullPreds
+      .map { case n: IsNotNull => getImpliedNotNullExprIds(n) }
+      .foldLeft(Set.empty[ExprId])(_ ++ _)
   }
 
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
-  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
-
   // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
   // all the variables at the beginning to take advantage of short circuiting.
   override def usedInputs: AttributeSet = AttributeSet.empty
 
   override def output: Seq[Attribute] = {
     child.output.map { a =>
-      if (a.nullable && notNullAttributes.contains(a.exprId)) {
+      if (a.nullable && impliedNotNullExprIds.contains(a.exprId)) {
         a.withNullability(false)
       } else {
         a
@@ -193,7 +190,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
-      if (notNullAttributes.contains(child.output(i).exprId)) {
+      if (impliedNotNullExprIds.contains(child.output(i).exprId)) {
         ev.isNull = FalseLiteral
       }
       ev
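On the physical side, the payoff of the widened `notNullPreds` check and the `ev.isNull = FalseLiteral` reset is leaner generated code: a conjunct that the old all-or-nothing `isNullIntolerant` walk rejected, such as `IsNotNull(coalesce(a, b) + c)` where the Coalesce poisoned the whole subtree, can now still mark `c` non-nullable. A hedged sketch of how one might eyeball this with the existing codegen debug helper (illustrative columns; the exact generated code varies by version):

```scala
import org.apache.spark.sql.execution.debug._

val df = spark.range(30).selectExpr(
  "IF(id % 2 = 0, id, NULL) AS a",
  "IF(id % 3 = 0, id, NULL) AS b",
  "IF(id % 5 = 0, id, NULL) AS c")

// Under the old rule this IsNotNull conjunct was ignored entirely because
// Coalesce is not NullIntolerant; with this patch, `c` should be treated as
// non-null downstream, so the generated code for `c * 2` can skip its null check.
df.where("coalesce(a, b) + c IS NOT NULL").selectExpr("c * 2").debugCodegen()
```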