diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1510a4796683..383de28d2822 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -131,4 +131,15 @@ package object expressions {
    * constraints.
    */
   trait NullIntolerant
+
+  /**
+   * Recursively explores the expressions which are null intolerant and returns all attributes
+   * in these expressions.
+   */
+  def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+    case a: Attribute => Seq(a)
+    case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+      expr.children.flatMap(scanNullIntolerantExpr)
+    case _ => Seq.empty[Attribute]
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e5e2cd7d27d1..3eaf421b41d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
@@ -564,7 +565,24 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] with PredicateHelpe
       val newFilters = filter.constraints --
         (child.constraints ++ splitConjunctivePredicates(condition))
       if (newFilters.nonEmpty) {
-        Filter(And(newFilters.reduce(And), condition), child)
+        // Get the attributes which are not null given the IsNotNull conditions.
+        val isNotNullAttrs = condition.collect {
+          case c: IsNotNull => c
+        }.flatMap(expressions.scanNullIntolerantExpr(_))
+        val dedupFilters = newFilters.filter { cond =>
+          cond match {
+            // If the newly added IsNotNull condition is guaranteed given the current condition,
+            // we don't need to include it.
+            case IsNotNull(a: Attribute) if isNotNullAttrs.contains(a) => false
+            case _ => true
+          }
+        }
+        val newCondition = if (dedupFilters.isEmpty) {
+          condition
+        } else {
+          And(dedupFilters.reduce(And), condition)
+        }
+        Filter(newCondition, child)
       } else {
         filter
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 45ee2964d4db..d5efdd15e7fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.apache.spark.sql.catalyst.{expressions => CatalystExpression}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -47,7 +48,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
     // First, we propagate constraints from the null intolerant expressions.
     var isNotNullConstraints: Set[Expression] =
-      constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
+      constraints.flatMap(CatalystExpression.scanNullIntolerantExpr).map(IsNotNull(_))
 
     // Second, we infer additional constraints from non-nullable attributes that are part of the
     // operator's output
@@ -57,17 +58,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     isNotNullConstraints -- constraints
   }
 
-  /**
-   * Recursively explores the expressions which are null intolerant and returns all attributes
-   * in these expressions.
-   */
-  private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
-    case a: Attribute => Seq(a)
-    case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
-      expr.children.flatMap(scanNullIntolerantExpr)
-    case _ => Seq.empty[Attribute]
-  }
-
   // Collect aliases from expressions, so we may avoid producing recursive constraints.
   private lazy val aliasMap = AttributeMap(
     (expressions ++ children.flatMap(_.expressions)).collect {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 9f57f66a2ea2..13f6bf793a26 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -124,6 +124,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("no redundant isnotnull condition inferred from constraints") {
+    val originalQuery = testRelation.where('a === 1 && IsNotNull('a + 2)).analyze
+    // Make sure isnotnull('a) is in the constraints.
+    val isNotNullForA = originalQuery.constraints.find { c =>
+      c match {
+        case IsNotNull(a: Attribute) if a.semanticEquals(originalQuery.output(0)) => true
+        case _ => false
+      }
+    }
+    assert(isNotNullForA.isDefined)
+    // We don't need to add another isnotnull('a) although it is in the constraints.
+    val correctAnswer = originalQuery
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("inner join with alias: alias contains multiple attributes") {
     val t1 = testRelation.subquery('t1)
     val t2 = testRelation.subquery('t2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 33b3b78c9f04..5f03ac2b7f63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,7 +28,8 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
+import org.apache.spark.sql.catalyst.expressions.{Expression, IsNotNull}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -1635,6 +1636,51 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("Skip additional isnotnull condition inferred from constraints if possible") {
+    def getFilter(df: DataFrame): Filter = {
+      df.queryExecution.optimizedPlan.collect {
+        case f: Filter => f
+      }.head.asInstanceOf[Filter]
+    }
+    def getIsNotNull(expr: Expression): Seq[Expression] = {
+      expr collect {
+        case e: IsNotNull => e
+      }
+    }
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+
+    // The following isnotnull conditions don't need an additional isnotnull condition from
+    // the constraints.
+    val expr1 = "_2 + Rand()"
+    val filter1 = getFilter(df.where(s"isnotnull($expr1)"))
+    assert(getIsNotNull(filter1.condition).size == 1)
+
+    val expr2 = "_2 + coalesce(_1, 0)"
+    val filter2 = getFilter(df.where(s"isnotnull($expr2)"))
+    assert(getIsNotNull(filter2.condition).size == 1)
+
+    val expr3 = "_1 + _2 * 3"
+    val filter3 = getFilter(df.where(s"isnotnull($expr3)"))
+    assert(getIsNotNull(filter3.condition).size == 1)
+
+    val expr4 = "_1"
+    val filter4 = getFilter(df.where(s"isnotnull($expr4)"))
+    assert(getIsNotNull(filter4.condition).size == 1)
+
+    val expr5 = "cast((_1 + _2) as boolean)"
+    val filter5 = getFilter(df.where(s"isnotnull($expr5)"))
+    assert(getIsNotNull(filter5.condition).size == 1)
+
+    // For the condition _2 > 1, we would add an additional condition IsNotNull(_2).
+    // This null check is put ahead of _2 > 1 in the physical Filter op so we can skip it early.
+    val expr6 = "_2 > 1"
+    val filter6 = getFilter(df.where(s"$expr6"))
+    assert(getIsNotNull(filter6.condition).size == 1)
+  }
+
   test("SPARK-17123: Performing set operations that combine non-scala native types") {
     val dates = Seq(
       (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),