diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b7884f9b60f3..347bd9619cf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
-import org.apache.spark.sql.catalyst.planning.IntegerIndex
+import org.apache.spark.sql.catalyst.planning.{ExtractJoinOutputAttributes, IntegerIndex}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
 import org.apache.spark.sql.catalyst.rules._
@@ -106,6 +106,8 @@ class Analyzer(
       TimeWindowing ::
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
+    Batch("Solve", Once,
+      SolveIllegalReferences),
     Batch("Nondeterministic", Once,
       PullOutNondeterministic),
     Batch("UDF", Once,
@@ -1442,6 +1444,32 @@ class Analyzer(
     }
   }
 
+  /**
+   * Corrects attribute references in the expression trees of operators (e.g., filters and
+   * projects) that have a join in their subtree, when those references point to columns of the
+   * join's input relations. This is needed because outer joins change the nullability of input
+   * columns, and a stale non-nullable reference can trigger illegal optimizations (e.g., NULL
+   * propagation) and wrong answers. See SPARK-13484 and SPARK-13801 for concrete queries.
+   */
+  object SolveIllegalReferences extends Rule[LogicalPlan] {
+
+    private def replaceReferences(e: Expression, attrMap: AttributeMap[Attribute]) = e.transform {
+      case a: AttributeReference => attrMap.getOrElse(a, a)
+    }
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case q: LogicalPlan =>
+        q.transform {
+          case f @ Filter(filterCondition, ExtractJoinOutputAttributes(_, joinOutputMap)) =>
+            f.copy(condition = replaceReferences(filterCondition, joinOutputMap))
+          case p @ Project(projectList, ExtractJoinOutputAttributes(_, joinOutputMap)) =>
+            p.copy(projectList = projectList.map { e =>
+              replaceReferences(e, joinOutputMap).asInstanceOf[NamedExpression]
+            })
+        }
+    }
+  }
+
   /**
    * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
    * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
@@ -2122,4 +2150,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
     }
   }
 }
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 00656191354f..d617c41408b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -181,6 +181,23 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
   }
 }
 
+/**
+ * An extractor that returns the topmost [[Join]] below a given operator, along with an
+ * [[AttributeMap]] from each of the join's output attributes to itself. [[AttributeMap]] is
+ * keyed by expression id, so stale references resolve to the join's up-to-date attributes.
+ */
+object ExtractJoinOutputAttributes {
+
+  def unapply(plan: LogicalPlan): Option[(Join, AttributeMap[Attribute])] = {
+    plan.collectFirst {
+      case j: Join => j
+    }.map { join =>
+      // A self-mapping suffices because AttributeMap lookups go through exprId rather
+      // than full attribute equality (which includes nullability).
+      (join, AttributeMap(join.output.map(a => (a, a))))
+    }
+  }
+}
 
 /**
  * A pattern that collects all adjacent unions and returns their children as a Seq.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index 1423a8705af2..748579df4158 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -100,7 +100,7 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
     val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
     val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
-      Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
+      Alias(Coalesce(Seq(b, b)), "b")(), a, c)
     checkAnalysis(naturalPlan, expected)
     checkAnalysis(usingPlan, expected)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 031e66b57cbc..4342c039aefc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -204,4 +204,25 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       leftJoin2Inner,
       Row(1, 2, "1", 1, 3, "1") :: Nil)
   }
+
+  test("process outer join results using the non-nullable columns in the join input") {
+    // Filter data using a non-nullable column from the right table
+    val df1 = Seq((0, 0), (1, 0), (2, 0), (3, 0), (4, 0)).toDF("id", "count")
+    val df2 = Seq(Tuple1(0), Tuple1(1)).toDF("id").groupBy("id").count
+    checkAnswer(
+      df1.join(df2, df1("id") === df2("id"), "left_outer").filter(df2("count").isNull),
+      Row(2, 0, null, null) ::
+      Row(3, 0, null, null) ::
+      Row(4, 0, null, null) :: Nil
+    )
+
+    // Coalesce data using non-nullable columns from the input tables
+    val df3 = Seq((1, 1)).toDF("a", "b")
+    val df4 = Seq((2, 2)).toDF("a", "b")
+    checkAnswer(
+      df3.join(df4, df3("a") === df4("a"), "outer")
+        .select(coalesce(df3("a"), df3("b")), coalesce(df4("a"), df4("b"))),
+      Row(1, null) :: Row(null, 2) :: Nil
+    )
+  }
 }
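Note (not part of the patch): a minimal sketch of why the extractor's self-mapping repairs stale nullability, using catalyst's AttributeMap and AttributeReference as they exist on this branch. The attribute name and type below are illustrative only.

    import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
    import org.apache.spark.sql.types.LongType

    // "fresh" stands in for a join output attribute; after an outer join it is nullable.
    val fresh = AttributeReference("count", LongType, nullable = true)()
    // "stale" stands in for a reference resolved against the pre-join relation: it has
    // the same ExprId but still claims nullable = false.
    val stale = fresh.withNullability(false)

    // The map built by ExtractJoinOutputAttributes: each output attribute maps to itself.
    val joinOutputMap = AttributeMap(Seq(fresh -> fresh))

    // AttributeMap keys by ExprId, not by full attribute equality, so looking up the
    // stale reference yields the join's nullable output attribute. This is the
    // substitution replaceReferences applies to filter conditions and project lists.
    assert(joinOutputMap(stale).nullable)

Without this substitution, NULL propagation could fold df2("count").isNull over a reference that still claims nullable = false into a constant false, which is how the unmatched rows were dropped in the first DataFrameJoinSuite test above.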