Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] {
* Intersect:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
* because we will not have non-deterministic expressions.
* with deterministic condition.
*
* Except:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
* because we will not have non-deterministic expressions.
* with deterministic condition.
*/
object SetOperationPushDown extends Rule[LogicalPlan] {
object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {

/**
* Maps Attributes from the left side to the corresponding Attribute on the right side.
Expand All @@ -129,34 +129,63 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
result.asInstanceOf[A]
}

/**
* Splits the condition expression into small conditions by `And`, and partition them by
* deterministic, and finally recombine them by `And`.
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a comment to explain the meanings of returned Expressions (i.e. the first Expression is for deterministic expressions and the second Expression is for non-deterministic expressions).

private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
val andConditions = splitConjunctivePredicates(condition)
andConditions.partition(_.deterministic) match {
case (deterministic, nondeterministic) =>
deterministic.reduceOption(And).getOrElse(Literal(true)) ->
nondeterministic.reduceOption(And).getOrElse(Literal(true))
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down filter into union
case Filter(condition, u @ Union(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(u)
Union(
Filter(condition, left),
Filter(pushToRight(condition, rewrites), right))

// Push down projection through UNION ALL
case Project(projectList, u @ Union(left, right)) =>
val rewrites = buildRewrites(u)
Union(
Project(projectList, left),
Project(projectList.map(pushToRight(_, rewrites)), right))
Filter(nondeterministic,
Union(
Filter(deterministic, left),
Filter(pushToRight(deterministic, rewrites), right)
)
)

// Push down deterministic projection through UNION ALL
case p @ Project(projectList, u @ Union(left, right)) =>
if (projectList.forall(_.deterministic)) {
val rewrites = buildRewrites(u)
Union(
Project(projectList, left),
Project(projectList.map(pushToRight(_, rewrites)), right))
} else {
p
}

// Push down filter through INTERSECT
case Filter(condition, i @ Intersect(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(i)
Intersect(
Filter(condition, left),
Filter(pushToRight(condition, rewrites), right))
Filter(nondeterministic,
Intersect(
Filter(deterministic, left),
Filter(pushToRight(deterministic, rewrites), right)
)
)

// Push down filter through EXCEPT
case Filter(condition, e @ Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
val rewrites = buildRewrites(e)
Except(
Filter(condition, left),
Filter(pushToRight(condition, rewrites), right))
Filter(nondeterministic,
Except(
Filter(deterministic, left),
Filter(pushToRight(deterministic, rewrites), right)
)
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest {
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
SetOperationPushDown) :: Nil
SetOperationPushDown,
SimplifyFilters) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
Expand Down
39 changes: 39 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -916,4 +916,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(intersect.count() === 30)
assert(except.count() === 70)
}

test("SPARK-10740: handle nondeterministic expressions correctly for set operations") {
val df1 = (1 to 20).map(Tuple1.apply).toDF("i")
val df2 = (1 to 10).map(Tuple1.apply).toDF("i")

def expected(df: DataFrame): Seq[Row] = {
df.rdd.collectPartitions().zipWithIndex.flatMap {
case (data, index) =>
val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be good to have a comment to say we need to match the behavior of RAND when generating expected results.

data.filter(_.getInt(0) < rng.nextDouble() * 10)
}
}

val union = df1.unionAll(df2)
checkAnswer(
union.filter('i < rand(7) * 10),
expected(union)
)
checkAnswer(
union.select(rand(7)),
union.rdd.collectPartitions().zipWithIndex.flatMap {
case (data, index) =>
val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index)
data.map(_ => rng.nextDouble()).map(i => Row(i))
}
)

val intersect = df1.intersect(df2)
checkAnswer(
intersect.filter('i < rand(7) * 10),
expected(intersect)
)

val except = df1.except(df2)
checkAnswer(
except.filter('i < rand(7) * 10),
expected(except)
)
}
}