@@ -267,6 +267,17 @@ object ScalarSubquery {
case _ => false
}.isDefined
}

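// Returns true if the given expression contains a ScalarSubquery anywhere in its tree.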
def hasScalarSubquery(e: Expression): Boolean = {
e.find {
case s: ScalarSubquery => true
case _ => false
}.isDefined
}

def hasScalarSubquery(e: Seq[Expression]): Boolean = {
e.find(hasScalarSubquery(_)).isDefined
Contributor:
e.exists(hasScalarSubquery)

Contributor Author (reply):
@cloud-fan Sure.

}
}
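
A sketch of what the Seq overload might look like after the suggested simplification (a possible form only; the committed code may differ):

  def hasScalarSubquery(e: Seq[Expression]): Boolean = e.exists(hasScalarSubquery)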


Large diffs are not rendered by default.

@@ -120,7 +120,11 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
* Returns whether the expression returns null or false when all inputs are nulls.
*/
private def canFilterOutNull(e: Expression): Boolean = {
if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false
if (!e.deterministic ||
SubqueryExpression.hasCorrelatedSubquery(e) ||
SubExprUtils.containsOuter(e)) {
return false
}
val attributes = e.references.toSeq
val emptyRow = new GenericInternalRow(attributes.length)
val boundE = BindReferences.bindReference(e, attributes)
@@ -147,10 +151,45 @@
}
}

private def buildNewJoinType(
upperJoin: Join,
lowerJoin: Join,
otherTableOutput: AttributeSet): JoinType = {
val conditions = upperJoin.constraints
// Find the predicates that reference only the other table.
val localConditions = conditions.filter(_.references.subsetOf(otherTableOutput))
// Find the predicates that reference only the left table, or that join the
// left table with the other table.
val leftConditions = conditions.filter(_.references.
subsetOf(lowerJoin.left.outputSet ++ otherTableOutput)).diff(localConditions)
// Find the predicates that reference only the right table, or that join the
// right table with the other table.
val rightConditions = conditions.filter(_.references.
subsetOf(lowerJoin.right.outputSet ++ otherTableOutput)).diff(localConditions)

val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull)
val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull)

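// Pick the strongest join type that still preserves the result, given which
// null-supplying side is guaranteed to be filtered by a null-intolerant predicate.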
lowerJoin.joinType match {
case RightOuter if leftHasNonNullPredicate => Inner
case LeftOuter if rightHasNonNullPredicate => Inner
case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner
case FullOuter if leftHasNonNullPredicate => LeftOuter
case FullOuter if rightHasNonNullPredicate => RightOuter
case o => o
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) =>
val newJoinType = buildNewJoinType(f, j)
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
case j @ Join(child @ Join(_, _, RightOuter | LeftOuter | FullOuter, _),
subquery, LeftSemiOrAnti(joinType), joinCond) =>
val newJoinType = buildNewJoinType(j, child, subquery.outputSet)
if (newJoinType == child.joinType) j else {
Join(child.copy(joinType = newJoinType), subquery, joinType, joinCond)
}
}
}
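
To illustrate the new case above, here is a sketch in the Catalyst test DSL (x, y and z stand for hypothetical local relations, and the plan would be analyzed before optimization):

  // "y.b".attr === "z.b".attr evaluates to null when y.b is null, so any row the
  // LeftOuter join pads with nulls on y's side cannot satisfy the semi join.
  val plan = x
    .join(y, LeftOuter, Some("x.a".attr === "y.a".attr))
    .join(z, LeftSemi, Some("y.b".attr === "z.b".attr))
  // EliminateOuterJoin can then rewrite the lower LeftOuter join to Inner.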

@@ -114,3 +114,10 @@ object LeftExistence {
case _ => None
}
}

object LeftSemiOrAnti {
def unapply(joinType: JoinType): Option[JoinType] = joinType match {
case LeftSemi | LeftAnti => Some(joinType)
case _ => None
}
}
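
A minimal sketch of how the new extractor is used in a pattern match (plan stands for an arbitrary LogicalPlan):

  plan transform {
    case j @ Join(_, _, LeftSemiOrAnti(joinType), condition) =>
      // joinType is bound to either LeftSemi or LeftAnti here
      j
  }
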
@@ -314,7 +314,7 @@ case class Join(
left.constraints
.union(right.constraints)
.union(splitConjunctivePredicates(condition.get).toSet)
case LeftSemi if condition.isDefined =>
case LeftSemi | LeftAnti if condition.isDefined =>
left.constraints
.union(splitConjunctivePredicates(condition.get).toSet)
case j: ExistenceJoin =>
@@ -34,11 +34,15 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
Batch("Subquery", Once,
PullupCorrelatedPredicates,
RewritePredicateSubquery) ::
Batch("Filter Pushdown", FixedPoint(10),
CombineFilters,
PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
PushLeftSemiLeftAntiThroughJoin,
CollapseProject) :: Nil
}

@@ -876,12 +880,15 @@ class FilterPushdownSuite extends PlanTest {
.join(y, Inner, Option("x.a".attr === "y.a".attr))
.where(Exists(z.where("x.a".attr === "z.a".attr)))
.analyze

val answer = x
.where(Exists(z.where("x.a".attr === "z.a".attr)))
.join(y, Inner, Option("x.a".attr === "y.a".attr))
.analyze
val optimized = Optimize.execute(Optimize.execute(query))
comparePlans(optimized, answer)

val optimized = Optimize.execute(query)
val expected = Optimize.execute(answer)
comparePlans(optimized, expected)
}

test("predicate subquery: push down complex") {
@@ -900,8 +907,10 @@
.join(x, Inner, Option("w.a".attr === "x.a".attr))
.join(y, LeftOuter, Option("x.a".attr === "y.a".attr))
.analyze
val optimized = Optimize.execute(Optimize.execute(query))
comparePlans(optimized, answer)

val optimized = Optimize.execute(query)
val expected = Optimize.execute(answer)
comparePlans(optimized, expected)
}

test("SPARK-20094: don't push predicate with IN subquery into join condition") {
@@ -915,13 +924,14 @@
("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))))
.analyze

val expectedPlan = x
val answer = x
.join(z, Inner, Some("x.b".attr === "z.b".attr))
.where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))
.analyze

val optimized = Optimize.execute(queryPlan)
comparePlans(optimized, expectedPlan)
val expected = Optimize.execute(answer)
comparePlans(optimized, expected)
}

test("Window: predicate push down -- basic") {
@@ -17,39 +17,43 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.EmptyFunctionRegistry
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf



class RewriteSubquerySuite extends PlanTest {
object Optimize extends Optimizer(
new SessionCatalog(
new InMemoryCatalog,
EmptyFunctionRegistry,
new SQLConf()))

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
Batch("Rewrite Subquery", FixedPoint(1),
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
RemoveRedundantProject) :: Nil
}

test("Column pruning after rewriting predicate subquery") {
val relation = LocalRelation('a.int, 'b.int)
val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)
val schema1 = LocalRelation('a.int, 'b.int)
val schema2 = LocalRelation('x.int, 'y.int, 'z.int)

val relation = LocalRelation.fromExternalRows(schema1.output, Seq(Row(1, 1)))
val relInSubquery = LocalRelation.fromExternalRows(schema2.output, Seq(Row(1, 1, 1)))

val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a)

val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a)
val optimized = Optimize.execute(query.analyze)

val optimized = Optimize.execute(query.analyze)
val correctAnswer = relation
.select('a)
.join(relInSubquery.select('x), LeftSemi, Some('a === 'x))
.analyze
val correctAnswer = relation
.select('a)
.join(relInSubquery.select('x), LeftSemi, Some('a === 'x))
.analyze

comparePlans(optimized, correctAnswer)
comparePlans(optimized, Optimize.execute(correctAnswer))
}

}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType

/**
* Provides helper methods for comparing plans.
@@ -103,7 +104,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
val newCondition =
splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode())
.reduce(And)
Join(left, right, joinType, Some(newCondition))
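// Normalize the auto-generated exists attribute inside ExistenceJoin so that plans
// differing only in that attribute's expression id still compare as equal.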
val maskedJoinType = if (joinType.isInstanceOf[ExistenceJoin]) {
val exists = AttributeReference("exists", BooleanType, false)(exprId = ExprId(0))
ExistenceJoin(exists)
} else joinType
Join(left, right, maskedJoinType, Some(newCondition))
}
}
