Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}

/**
* Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions as
* well as non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] =
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints)

// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
Expand All @@ -57,14 +56,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
isNotNullConstraints -- constraints
}

/**
 * Derives attribute-level `IsNotNull` constraints from a single constraint by collecting the
 * attributes reachable through its null intolerant child expressions.
 *
 * @param constraint the constraint to inspect
 * @return one `IsNotNull(attribute)` expression per attribute collected from the constraint
 */
private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = {
  // Pick the expression from which the attribute scan starts.
  val scanRoot = constraint match {
    // The root is already an IsNotNull: push it down by scanning its child expression.
    case IsNotNull(child) => child
    // A constraint always evaluates to true for all the inputs, so it never returns null.
    // The scan can therefore start at the constraint itself.
    case other => other
  }
  scanNullIntolerantAttribute(scanRoot).map(attribute => IsNotNull(attribute))
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*
* An `Attribute` yields itself, a `NullIntolerant` expression yields the attributes collected
* from its children, and any other expression yields an empty sequence. (The variant that
* recurses via `scanNullIntolerantExpr` additionally descends through an `IsNotNull` that
* wraps a `NullIntolerant` child.)
*
* @param expr the expression to scan
* @return the attributes found under null intolerant expressions, or an empty sequence
*/
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
expr.children.flatMap(scanNullIntolerantExpr)
case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute)
case _ => Seq.empty[Attribute]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,15 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c")))))

verifyConstraints(
tr.where('a.attr === 1 && IsNotNull(resolveColumn(tr, "b")) &&
IsNotNull(resolveColumn(tr, "c"))).analyze.constraints,
ExpressionSet(Seq(
resolveColumn(tr, "a") === 1,
IsNotNull(resolveColumn(tr, "c")),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")))))
}

test("infer IsNotNull constraints from non-nullable attributes") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1649,6 +1649,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
}

test("SPARK-17897: Fixed IsNotNull Constraint Inference Rule") {
  // A single nullable column holding one non-null and one null value.
  val keys = Seq[java.lang.Integer](1, null).toDF("key")
  // Negating an IsNotNull predicate must keep exactly the null row.
  val onlyNullRow = Row(null)

  // Predicate applied directly to the column.
  checkAnswer(keys.filter(!$"key".isNotNull), onlyNullRow)
  // Predicate applied to the negated column (`- key`).
  checkAnswer(keys.filter(!(- $"key").isNotNull), onlyNullRow)
}

test("SPARK-17957: outer join + na.fill") {
val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
Expand Down