@@ -131,4 +131,15 @@ package object expressions {
* constraints.
*/
trait NullIntolerant

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
expr.children.flatMap(scanNullIntolerantExpr)
case _ => Seq.empty[Attribute]
}
}
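For reference, a rough usage sketch of the relocated helper (the attributes are hypothetical; the results assume Add is NullIntolerant while Coalesce is not, as in Catalyst):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._

// Hypothetical resolved attributes, for illustration only.
val a = 'a.int
val b = 'b.int

// Add is NullIntolerant, so the scan descends into its children:
scanNullIntolerantExpr(IsNotNull(a + b))                   // Seq(a, b)
// Coalesce is not NullIntolerant, so its subtree contributes nothing:
scanNullIntolerantExpr(a + Coalesce(Seq(b, Literal(0))))   // Seq(a)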
@@ -27,6 +27,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -564,7 +565,24 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelper {
val newFilters = filter.constraints --
(child.constraints ++ splitConjunctivePredicates(condition))
if (newFilters.nonEmpty) {
Filter(And(newFilters.reduce(And), condition), child)
// Get the attributes which are not null given the IsNotNull conditions.
val isNotNullAttrs = condition.collect {
case c: IsNotNull => c
}.flatMap(expressions.scanNullIntolerantExpr(_))
val dedupFilters = newFilters.filter { cond =>
cond match {
// If the newly added IsNotNull condition is already guaranteed by the current
// condition, we don't need to include it.
case IsNotNull(a: Attribute) if isNotNullAttrs.contains(a) => false
case _ => true
}
}
val newCondition = if (dedupFilters.isEmpty) {
condition
} else {
And(dedupFilters.reduce(And), condition)
}
Filter(newCondition, child)
} else {
filter
}
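To make the dedup step concrete, a worked sketch (attribute names are illustrative, not from the patch):

// condition  = IsNotNull('a + 2)    => isNotNullAttrs = Seq('a)
// newFilters = Set(IsNotNull('a))   => dedupFilters   = Set.empty, condition kept as-is
//
// condition  = 'a > 1               => isNotNullAttrs = Seq.empty
// newFilters = Set(IsNotNull('a))   => dedupFilters   = Set(IsNotNull('a))
//                                      newCondition   = IsNotNull('a) && ('a > 1)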
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.{expressions => CatalystExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}
@@ -47,7 +48,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
// First, we propagate constraints from the null intolerant expressions.
var isNotNullConstraints: Set[Expression] =
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
constraints.flatMap(CatalystExpression.scanNullIntolerantExpr).map(IsNotNull(_))

// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
@@ -57,17 +58,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
isNotNullConstraints -- constraints
}

/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
expr.children.flatMap(scanNullIntolerantExpr)
case _ => Seq.empty[Attribute]
}

// Collect aliases from expressions, so we may avoid producing recursive constraints.
private lazy val aliasMap = AttributeMap(
(expressions ++ children.flatMap(_.expressions)).collect {
@@ -124,6 +124,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("no redundant isnotnull condition inferred from constraints") {
val originalQuery = testRelation.where('a === 1 && IsNotNull('a + 2)).analyze
// Make sure isnotnull('a) is in the constraints.
val isNotNullForA = originalQuery.constraints.find { c =>
c match {
case IsNotNull(a: Attribute) if a.semanticEquals(originalQuery.output(0)) => true
case _ => false
}
}
assert(isNotNullForA.isDefined)
// We don't need to add another isnotnull('a) even though it is in the constraints.
val correctAnswer = originalQuery
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}

test("inner join with alias: alias contains multiple attributes") {
val t1 = testRelation.subquery('t1)
val t2 = testRelation.subquery('t2)
@@ -28,7 +28,8 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
import org.apache.spark.sql.catalyst.expressions.{Expression, IsNotNull}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -1635,6 +1636,51 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

test("Skip additinal isnotnull condition inferred from constraints if possibly") {
def getFilter(df: DataFrame): Filter = {
df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}.head.asInstanceOf[Filter]
}
def getIsNotNull(expr: Expression): Seq[Expression] = {
expr collect {
case e: IsNotNull => e
}
}
val df = sparkContext.parallelize(Seq(
null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

// The following isnotnull conditions don't need an additional isnotnull condition
// inferred from the constraints.
val expr1 = "_2 + Rand()"
val filter1 = getFilter(df.where(s"isnotnull($expr1)"))
assert(getIsNotNull(filter1.condition).size == 1)

val expr2 = "_2 + coalesce(_1, 0)"
val filter2 = getFilter(df.where(s"isnotnull($expr2)"))
assert(getIsNotNull(filter2.condition).size == 1)

val expr3 = "_1 + _2 * 3"
val filter3 = getFilter(df.where(s"isnotnull($expr3)"))
assert(getIsNotNull(filter3.condition).size == 1)

val expr4 = "_1"
val filter4 = getFilter(df.where(s"isnotnull($expr4)"))
assert(getIsNotNull(filter4.condition).size == 1)

val expr5 = "cast((_1 + _2) as boolean)"
val filter5 = getFilter(df.where(s"isnotnull($expr5)"))
assert(getIsNotNull(filter5.condition).size == 1)

// For the condition _2 > 1, we still add the inferred condition IsNotNull(_2).
// This null check is placed ahead of _2 > 1 in the physical Filter so rows with a null _2 can be skipped early.
val expr6 = "_2 > 1"
val filter6 = getFilter(df.where(s"$expr6"))
assert(getIsNotNull(filter6.condition).size == 1)
}
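For that last case, a sketch of the expected optimized condition shape (attribute ids omitted):

// filter6.condition is roughly: isnotnull(_2) && (_2 > 1)
// i.e. the single inferred null check sits ahead of the comparison.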

test("SPARK-17123: Performing set operations that combine non-scala native types") {
val dates = Seq(
(new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),