Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.planning.IntegerIndex
import org.apache.spark.sql.catalyst.planning.{ExtractJoinOutputAttributes, IntegerIndex}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -106,6 +106,8 @@ class Analyzer(
TimeWindowing ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Solve", Once,
SolveIllegalReferences),
Batch("Nondeterministic", Once,
PullOutNondeterministic),
Batch("UDF", Once,
Expand Down Expand Up @@ -1442,6 +1444,32 @@ class Analyzer(
}
}

/**
* Corrects attribute references in the expression trees of some operators (e.g., filters and
* projects) when these operators have a join as a child and the references point to columns of
* the join's input relations. Some joins change the nullability of their input columns, so
* leaving the original references in place could allow illegal optimizations (e.g., NULL
* propagation) and produce wrong answers.
* See SPARK-13484 and SPARK-13801 for concrete queries that hit this case.
*/
object SolveIllegalReferences extends Rule[LogicalPlan] {

/**
 * Walks `e` and swaps every [[AttributeReference]] for the attribute registered for it in
 * `attrMap`; references with no entry in the map are left as they are.
 */
private def replaceReferences(e: Expression, attrMap: AttributeMap[Attribute]) =
  e.transform {
    case ref: AttributeReference => attrMap.getOrElse(ref, ref)
  }

/**
 * Rewrites the given plan so that Filter and Project operators sitting on top of a child
 * matched by ExtractJoinOutputAttributes have their attribute references replaced with the
 * entries of the join-output map returned by that extractor.
 *
 * Note: the `join` value bound by the extractor is not used in either case below; only the
 * attribute map is consumed.
 */
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Matches every operator handed over by resolveOperators, then runs a nested transform on it.
case q: LogicalPlan =>
q.transform {
// Filter: rewrite the references inside the filter condition.
case f @ Filter(filterCondition, ExtractJoinOutputAttributes(join, joinOutputMap)) =>
f.copy(condition = replaceReferences(filterCondition, joinOutputMap))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we use a q.transformUp to fix the nullability in a bottom-up way? For every node, we create an AttributeMap using the output of its child. Then, we use transformExpressions to fix the nullability if necessary. Let me try it out and ping you when I have a version.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I wait your ping.

Copy link
Contributor

@yhuai yhuai May 25, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/apache/spark/pull/13290/files This is the approach that I mentioned above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, I'll check it.

// Project: rewrite the references inside every entry of the project list.
// NOTE(review): the cast back to NamedExpression presumably holds because the rewrite only
// swaps attribute references for other attributes -- confirm.
case p @ Project(projectList, ExtractJoinOutputAttributes(join, joinOutputMap)) =>
p.copy(projectList = projectList.map { e =>
replaceReferences(e, joinOutputMap).asInstanceOf[NamedExpression]
})
}
}
}

/**
* Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
* aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
Expand Down Expand Up @@ -2122,4 +2150,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,23 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
}
}

/**
 * An extractor for the output attributes of a join found in a given plan.
 *
 * The join is located with `collectFirst`, so it is the first [[Join]] encountered in the
 * plan's subtree and is not required to be the root of `plan`.
 */
object ExtractJoinOutputAttributes {

  /**
   * @param plan the plan to search for a [[Join]]
   * @return `None` when `plan` contains no join; otherwise the join together with an
   *         [[AttributeMap]] that maps each [[AttributeReference]] in the join's output to
   *         itself (i.e. to the attribute as it appears in the join output)
   */
  def unapply(plan: LogicalPlan): Option[(Join, AttributeMap[Attribute])] = {
    plan.collectFirst { case j: Join => j }.map { join =>
      // `collect` skips output attributes that are not AttributeReferences instead of
      // failing with a MatchError, and removes the need for a mutable buffer.
      val joinOutput: Seq[(Attribute, Attribute)] = join.output.collect {
        case a: AttributeReference => (a, a)
      }
      (join, AttributeMap(joinOutput))
    }
  }
}

/**
* A pattern that collects all adjacent unions and returns their children as a Seq.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
Alias(Coalesce(Seq(b, b)), "b")(), a, c)
checkAnalysis(naturalPlan, expected)
checkAnalysis(usingPlan, expected)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,25 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
leftJoin2Inner,
Row(1, 2, "1", 1, 3, "1") :: Nil)
}

test("process outer join results using the non-nullable columns in the join input") {
  // Case 1: filter on a non-nullable column that comes from the right side of a left outer join.
  val left = Seq((0, 0), (1, 0), (2, 0), (3, 0), (4, 0)).toDF("id", "count")
  val rightCounts = Seq(Tuple1(0), Tuple1(1)).toDF("id").groupBy("id").count()
  val unmatchedRows = Seq(
    Row(2, 0, null, null),
    Row(3, 0, null, null),
    Row(4, 0, null, null))
  checkAnswer(
    left
      .join(rightCounts, left("id") === rightCounts("id"), "left_outer")
      .filter(rightCounts("count").isNull),
    unmatchedRows)

  // Case 2: coalesce over non-nullable columns of both inputs of a full outer join.
  val lhs = Seq((1, 1)).toDF("a", "b")
  val rhs = Seq((2, 2)).toDF("a", "b")
  checkAnswer(
    lhs
      .join(rhs, lhs("a") === rhs("a"), "outer")
      .select(coalesce(lhs("a"), lhs("b")), coalesce(rhs("a"), rhs("b"))),
    Seq(Row(1, null), Row(null, 2)))
}
}