diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index fc1caed84e27..97424597ccd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -267,6 +267,17 @@ object ScalarSubquery { case _ => false }.isDefined } + + def hasScalarSubquery(e: Expression): Boolean = { + e.find { + case s: ScalarSubquery => true + case _ => false + }.isDefined + } + + def hasScalarSubquery(e: Seq[Expression]): Boolean = { + e.find(hasScalarSubquery(_)).isDefined + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d251eeab848..5b9576ff081f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -61,6 +61,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReorderJoin, EliminateOuterJoin, PushPredicateThroughJoin, + PushLeftSemiLeftAntiThroughJoin, PushDownPredicate, LimitPushDown, ColumnPruning, @@ -139,10 +140,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation early", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: - Batch("Pullup Correlated Expressions", Once, - PullupCorrelatedPredicates) :: Batch("Subquery", Once, - OptimizeSubqueries) :: + OptimizeSubqueries, + PullupCorrelatedPredicates, + RewritePredicateSubquery) :: Batch("Replace Operators", fixedPoint, RewriteExceptAll, RewriteIntersectAll, @@ -172,13 +173,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // "Extract PythonUDF From JoinCondition". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ - Batch("RewriteSubquery", Once, - RewritePredicateSubquery, - ColumnPruning, - CollapseProject, - RemoveRedundantProject) :+ Batch("UpdateAttributeReferences", Once, - UpdateNullabilityInAttributeReferences) + UpdateNullabilityInAttributeReferences) :+ + Batch("Final column pruning", fixedPoint, + FinalColumnPruning, + CollapseProject, + RemoveRedundantProject, + ConvertToLocalRelation) } /** @@ -527,12 +528,37 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper * remove the Project p2 in the following pattern: * * p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet) + * p1 @ Project(_, j @ Join(p2 @ Project(_, child), _, LeftSemiOrAnti(_), _)) * * p2 is usually inserted by this rule and useless, p1 could prune the columns anyway. */ object ColumnPruning extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform { + def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(FinalColumnPruning(plan)) + + /** + * The Project before Filter or LeftSemi/LeftAnti not necessary but conflict with + * PushPredicatesThroughProject, so remove it. Since the Projects have been added + * top-down, we need to remove in bottom-up order, otherwise lower Projects can be missed. + * + * While removing the projects below a self join, we should ensure that the plan remains + * valid after removing the project. The project node could have been added to de-duplicate + * the attributes and thus we need to check for this case before removing the project node. + */ + private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) + if p2.outputSet.subsetOf(child.outputSet) => + p1.copy(child = f.copy(child = child)) + + case p1 @ Project(_, j @ Join(p2 @ Project(_, child), right, LeftSemiOrAnti(_), _)) + if p2.outputSet.subsetOf(child.outputSet) && + child.outputSet.intersect(right.outputSet).isEmpty => + p1.copy(child = j.copy(left = child)) + } +} + +object FinalColumnPruning extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Prunes the unused columns from project list of Project/Aggregate/Expand case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) => p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains))) @@ -619,7 +645,7 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { p } - }) + } /** Applies a projection only when the child is producing unnecessary attributes */ private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = @@ -628,17 +654,6 @@ object ColumnPruning extends Rule[LogicalPlan] { } else { c } - - /** - * The Project before Filter is not necessary but conflict with PushPredicatesThroughProject, - * so remove it. Since the Projects have been added top-down, we need to remove in bottom-up - * order, otherwise lower Projects can be missed. - */ - private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child))) - if p2.outputSet.subsetOf(child.outputSet) => - p1.copy(child = f.copy(child = child)) - } } /** @@ -649,13 +664,16 @@ object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, p2: Project) => - if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { + if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) || + ScalarSubquery.hasScalarSubquery(p1.projectList) || + ScalarSubquery.hasScalarSubquery(p2.projectList)) { p1 } else { p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) } case p @ Project(_, agg: Aggregate) => - if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) { + if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) || + ScalarSubquery.hasScalarSubquery(p.projectList)) { p } else { agg.copy(aggregateExpressions = buildCleanedProjectList( @@ -984,6 +1002,28 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) + // Similar to the above Filter over Project + // LeftSemi/LeftAnti over Project + case join @ Join(p @ Project(pList, grandChild), rightOp, LeftSemiOrAnti(joinType), joinCond) + if pList.forall(_.deterministic) && !ScalarSubquery.hasScalarSubquery(pList) && + canPushThroughCondition(Seq(grandChild), joinCond, rightOp) => + if (joinCond.isEmpty) { + // No join condition, just push down the Join below Project + Project(pList, Join(grandChild, rightOp, joinType, joinCond)) + } else { + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b). + val aliasMap = AttributeMap(pList.collect { + case a: Alias => (a.toAttribute, a.child) + }) + val newJoinCond = if (aliasMap.nonEmpty) { + Option(replaceAlias(joinCond.get, aliasMap)) + } else { + joinCond + } + Project(pList, Join(grandChild, rightOp, joinType, newJoinCond)) + } + case filter @ Filter(condition, aggregate: Aggregate) if aggregate.aggregateExpressions.forall(_.deterministic) && aggregate.groupingExpressions.nonEmpty => @@ -1017,6 +1057,52 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Aggregate + // LeftSemi/LeftAnti over Aggregate + case join @ Join(aggregate: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Aggregate + aggregate.copy(child = Join(aggregate.child, rightOp, joinType, joinCond)) + } else { + // Find all the aliased expressions in the aggregate list that don't include any actual + // AggregateExpression, and create a map from the alias to the expression + val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect { + case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty => + (a.toAttribute, a.child) + }) + + // For each join condition, expand the alias and + // check if the condition can be evaluated using + // attributes produced by the aggregate operator's child operator. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).partition(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + val replaced = replaceAlias(cond, aliasMap) + cond.references.nonEmpty && + replaced.references.subsetOf(aggregate.child.outputSet ++ rightOp.outputSet) + } + + val stayUp = rest ++ containingNonDeterministic + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val replaced = replaceAlias(pushDownPredicate, aliasMap) + val newAggregate = aggregate.copy(child = + Join(aggregate.child, rightOp, joinType, Option(replaced))) + // If there is no more filter to stay up, just return the Aggregate over Join. + // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". + if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate) + } else { + // The join condition is not a subset of the Aggregate's GROUP BY columns, + // no push down. + join + } + } + // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be // pushed beneath must satisfy the following conditions: // 1. All the expressions are part of window partitioning key. The expressions can be compound. @@ -1043,10 +1129,47 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Window + // LeftSemi/LeftAnti over Window + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + if (joinCond.isEmpty) { + // No join condition, just push down Join below Window + w.copy(child = Join(w.child, rightOp, joinType, joinCond)) + } else { + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ + rightOp.outputSet + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(joinCond.get).partition(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) + } + + val stayUp = rest ++ containingNonDeterministic + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownPredicate = pushDown.reduce(And) + val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(pushDownPredicate))) + if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + } else { + // The join condition is not a subset of the Window's PARTITION BY clause, + // no push down. + join + } + } + case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition(_.deterministic) + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).partition(_.deterministic) + val (pushDown, rest) = candidates.partition { cond => !SubExprUtils.containsOuter(cond) } + val stayUp = rest ++ containingNonDeterministic if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) val output = union.output @@ -1068,6 +1191,43 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } + // Similar to the above Filter over Union + // LeftSemi/LeftAnti over Union + case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond) + if canPushThroughCondition(union.children, joinCond, rightOp) => + if (joinCond.isEmpty) { + // Push down the Join below Union + val newGrandChildren = union.children.map { grandchild => + Join(grandchild, rightOp, joinType, joinCond) + } + union.withNewChildren(newGrandChildren) + } else { + // Union could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = + splitConjunctivePredicates(joinCond.get).partition(_.deterministic) + + // Check if the remaining predicates do not contain columns from subquery + val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) + + if (pushDown.nonEmpty && rightOpColumns.isEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = union.output + val newGrandChildren = union.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet)) + Join(grandchild, rightOp, joinType, Option(newCond)) + } + val newUnion = union.withNewChildren(newGrandChildren) + if (stayUp.isEmpty) newUnion else Filter(stayUp.reduceLeft(And), newUnion) + } else { + // Nothing to push down + join + } + } + case filter @ Filter(condition, watermark: EventTimeWatermark) => val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { p => p.deterministic && !p.references.contains(watermark.eventTime) @@ -1083,11 +1243,33 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - case filter @ Filter(_, u: UnaryNode) - if canPushThrough(u) && u.expressions.forall(_.deterministic) => + case filter @ Filter(condition, u: UnaryNode) + if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } + + // Similar to the above Filter over UnaryNode + // LeftSemi/LeftAnti over UnaryNode + case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond) + if canPushThrough(u) => + pushDownJoin(join, u.child) { joinCond => + u.withNewChildren(Seq(Join(u.child, rightOp, joinType, Option(joinCond)))) + } + } + + /** + * Check if we can safely push a join through a project or union by making sure that predicate + * subqueries in the condition do not contain the same attributes as the plan they are moved + * into. This can happen when the plan and predicate subquery have the same source. + */ + private def canPushThroughCondition(plans: Seq[LogicalPlan], condition: Option[Expression], + rightOp: LogicalPlan): Boolean = { + val attributes = AttributeSet(plans.flatMap (_.output)) + if (condition.isDefined) { + val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes) + matched.isEmpty + } else true } private def canPushThrough(p: UnaryNode): Boolean = p match { @@ -1106,20 +1288,20 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } private def pushDownPredicate( - filter: Filter, - grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { // Only push down the predicates that is deterministic and all the referenced attributes // come from grandchild. // TODO: non-deterministic predicates could be pushed through some operators that do not change // the rows. - val (candidates, nonDeterministic) = - splitConjunctivePredicates(filter.condition).partition(_.deterministic) + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(filter.condition).partition(_.deterministic) val (pushDown, rest) = candidates.partition { cond => cond.references.subsetOf(grandchild.outputSet) } - val stayUp = rest ++ nonDeterministic + val stayUp = rest ++ containingNonDeterministic if (pushDown.nonEmpty) { val newChild = insertFilter(pushDown.reduceLeft(And)) @@ -1133,6 +1315,36 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } + private def pushDownJoin( + join: Join, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the join when join condition deterministic and all the referenced attributes + // come from childen of left and right legs of join. + val (candidates, containingNonDeterministic) = if (join.condition.isDefined) { + splitConjunctivePredicates(join.condition.get).partition(_.deterministic) + } else { + (Nil, Nil) + } + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(grandchild.outputSet ++ join.right.outputSet) + } + + val stayUp = rest ++ containingNonDeterministic + + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + join + } + } + + /** * Check if we can safely push a filter through a projection, by making sure that predicate * subqueries in the condition do not contain the same attributes as the plan they are moved @@ -1168,13 +1380,17 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { - val (pushDownCandidates, nonDeterministic) = condition.partition(_.deterministic) + val (candidates, nonDeterministic) = condition.partition(_.deterministic) + val (pushDownCandidates, subquery) = candidates.partition { cond => + !SubExprUtils.containsOuter(cond) + } val (leftEvaluateCondition, rest) = pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic) + (leftEvaluateCondition, rightEvaluateCondition, + subquery ++ commonCondition ++ nonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1262,6 +1478,98 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Pushes down a subquery, in the form of [[Join LeftSemi/LeftAnti]] operator + * to the left or right side of a join below. + */ +object PushLeftSemiLeftAntiThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + /** + * Define an enumeration to identify whether a Exists/In subquery, + * in the form of a LeftSemi/LeftAnti, can be pushed down to + * the left table or the right table. + */ + object subqueryPushdown extends Enumeration { + val toRightTable, toLeftTable, none = Value + } + + /** + * Determine which side of the join an Exists/In subquery (in the form of + * LeftSemi/LeftAnti join) can be pushed down to. + */ + private def pushTo(child: Join, subquery: LogicalPlan, joinCond: Option[Expression]) = { + val left = child.left + val right = child.right + val joinType = child.joinType + val subqueryOutput = subquery.outputSet + + if (joinCond.nonEmpty) { + val noPushdown = (subqueryPushdown.none, None) + val conditions = splitConjunctivePredicates(joinCond.get) + val (candidates, containingNonDeterministic) = conditions.partition(_.deterministic) + lazy val (leftConditions, rest) = + candidates.partition(_.references.subsetOf(left.outputSet ++ subqueryOutput)) + lazy val (rightConditions, commonConditions) = + rest.partition(_.references.subsetOf(right.outputSet ++ subqueryOutput)) + + if (containingNonDeterministic.nonEmpty) { + noPushdown + } else { + if (rest.isEmpty && leftConditions.nonEmpty) { + // When all the join conditions are only between left table and the subquery + // push the subquery to the left table. + (subqueryPushdown.toLeftTable, leftConditions.reduceLeftOption(And)) + } else if (leftConditions.isEmpty && rightConditions.nonEmpty && commonConditions.isEmpty) { + // When all the join conditions are only between right table and the subquery + // push the subquery to the right table. + (subqueryPushdown.toRightTable, rightConditions.reduceLeftOption(And)) + } else { + noPushdown + } + } + } else { + /** + * When there is no correlated predicate, + * 1) if this is a left outer join, push the subquery down to the left table + * 2) if a right outer join, to the right table, + * 3) if an inner join, push to either side. + */ + val action = joinType match { + case RightOuter => + subqueryPushdown.toRightTable + case _: InnerLike | LeftOuter => + subqueryPushdown.toLeftTable + case _ => + subqueryPushdown.none + } + (action, None) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // push LeftSemi/LeftAnti down into the join below + case j @ Join(child @ Join(left, right, _ : InnerLike | LeftOuter | RightOuter, belowJoinCond), + subquery, LeftSemiOrAnti(joinType), joinCond) => + val belowJoinType = child.joinType + val (action, newJoinCond) = pushTo(child, subquery, joinCond) + + action match { + case subqueryPushdown.toLeftTable + if (belowJoinType == LeftOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the left table + val newLeft = Join(left, subquery, joinType, newJoinCond) + Join(newLeft, right, belowJoinType, belowJoinCond) + case subqueryPushdown.toRightTable + if (belowJoinType == RightOuter || belowJoinType.isInstanceOf[InnerLike]) => + // push down the subquery to the right table + val newRight = Join(right, subquery, joinType, newJoinCond) + Join(left, newRight, belowJoinType, belowJoinCond) + case _ => + // Do nothing + j + } + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 6ebb194d71c2..ea37afbd8124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -120,7 +120,11 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { * Returns whether the expression returns null or false when all inputs are nulls. */ private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + if (!e.deterministic || + SubqueryExpression.hasCorrelatedSubquery(e) || + SubExprUtils.containsOuter(e)) { + return false + } val attributes = e.references.toSeq val emptyRow = new GenericInternalRow(attributes.length) val boundE = BindReferences.bindReference(e, attributes) @@ -147,10 +151,45 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } } + private def buildNewJoinType( + upperJoin: Join, + lowerJoin: Join, + otherTableOutput: AttributeSet): JoinType = { + val conditions = upperJoin.constraints + // Find the predicates reference only on the other table. + val localConditions = conditions.filter(_.references.subsetOf(otherTableOutput)) + // Find the predicates reference either the left table or the join predicates + // between the left table and the other table. + val leftConditions = conditions.filter(_.references. + subsetOf(lowerJoin.left.outputSet ++ otherTableOutput)).diff(localConditions) + // Find the predicates reference either the right table or the join predicates + // between the right table and the other table. + val rightConditions = conditions.filter(_.references. + subsetOf(lowerJoin.right.outputSet ++ otherTableOutput)).diff(localConditions) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + + lowerJoin.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => val newJoinType = buildNewJoinType(f, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + case j @ Join(child @ Join(_, _, RightOuter | LeftOuter | FullOuter, _), + subquery, LeftSemiOrAnti(joinType), joinCond) => + val newJoinType = buildNewJoinType(j, child, subquery.outputSet) + if (newJoinType == child.joinType) j else { + Join(child.copy(joinType = newJoinType), subquery, joinType, joinCond) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index c77849035a97..86cdc261329c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -114,3 +114,10 @@ object LeftExistence { case _ => None } } + +object LeftSemiOrAnti { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a26ec4eed864..6f273543e0b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -314,7 +314,7 @@ case class Join( left.constraints .union(right.constraints) .union(splitConjunctivePredicates(condition.get).toSet) - case LeftSemi if condition.isDefined => + case LeftSemi | LeftAnti if condition.isDefined => left.constraints .union(splitConjunctivePredicates(condition.get).toSet) case j: ExistenceJoin => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 82a10254d846..9805ff6f889b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -34,11 +34,15 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: + Batch("Subquery", Once, + PullupCorrelatedPredicates, + RewritePredicateSubquery) :: Batch("Filter Pushdown", FixedPoint(10), CombineFilters, PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, + PushLeftSemiLeftAntiThroughJoin, CollapseProject) :: Nil } @@ -876,12 +880,15 @@ class FilterPushdownSuite extends PlanTest { .join(y, Inner, Option("x.a".attr === "y.a".attr)) .where(Exists(z.where("x.a".attr === "z.a".attr))) .analyze + val answer = x .where(Exists(z.where("x.a".attr === "z.a".attr))) .join(y, Inner, Option("x.a".attr === "y.a".attr)) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("predicate subquery: push down complex") { @@ -900,8 +907,10 @@ class FilterPushdownSuite extends PlanTest { .join(x, Inner, Option("w.a".attr === "x.a".attr)) .join(y, LeftOuter, Option("x.a".attr === "y.a".attr)) .analyze - val optimized = Optimize.execute(Optimize.execute(query)) - comparePlans(optimized, answer) + + val optimized = Optimize.execute(query) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("SPARK-20094: don't push predicate with IN subquery into join condition") { @@ -915,13 +924,14 @@ class FilterPushdownSuite extends PlanTest { ("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr))))) .analyze - val expectedPlan = x + val answer = x .join(z, Inner, Some("x.b".attr === "z.b".attr)) .where("x.a".attr > 1 || "z.c".attr.in(ListQuery(w.select("w.d".attr)))) .analyze val optimized = Optimize.execute(queryPlan) - comparePlans(optimized, expectedPlan) + val expected = Optimize.execute(answer) + comparePlans(optimized, expected) } test("Window: predicate push down -- basic") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 6b3739c372c3..532af60c3460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -17,39 +17,43 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.EmptyFunctionRegistry +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.ListQuery import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf + class RewriteSubquerySuite extends PlanTest { + object Optimize extends Optimizer( + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SQLConf())) - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Column Pruning", FixedPoint(100), ColumnPruning) :: - Batch("Rewrite Subquery", FixedPoint(1), - RewritePredicateSubquery, - ColumnPruning, - CollapseProject, - RemoveRedundantProject) :: Nil - } test("Column pruning after rewriting predicate subquery") { - val relation = LocalRelation('a.int, 'b.int) - val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int) + val schema1 = LocalRelation('a.int, 'b.int) + val schema2 = LocalRelation('x.int, 'y.int, 'z.int) + + val relation = LocalRelation.fromExternalRows(schema1.output, Seq(Row(1, 1))) + val relInSubquery = LocalRelation.fromExternalRows(schema2.output, Seq(Row(1, 1, 1))) + + val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) - val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) + val optimized = Optimize.execute(query.analyze) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = relation - .select('a) - .join(relInSubquery.select('x), LeftSemi, Some('a === 'x)) - .analyze + val correctAnswer = relation + .select('a) + .join(relInSubquery.select('x), LeftSemi, Some('a === 'x)) + .analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized, Optimize.execute(correctAnswer)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 3081ff935f04..b51658296fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.BooleanType /** * Provides helper methods for comparing plans. @@ -103,7 +104,11 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => val newCondition = splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) .reduce(And) - Join(left, right, joinType, Some(newCondition)) + val maskedJoinType = if (joinType.isInstanceOf[ExistenceJoin]) { + val exists = AttributeReference("exists", BooleanType, false)(exprId = ExprId(0)) + ExistenceJoin(exists) + } else joinType + Join(left, right, maskedJoinType, Some(newCondition)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala new file mode 100644 index 000000000000..e199ed958e79 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LeftSemiOrAntiPushdownSuite.scala @@ -0,0 +1,775 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/* + * Writing test cases using combinatorial testing technique + * Dimension 1: (A) Exists or (B) In + * Dimension 2: (A) LeftSemi, (B) LeftAnti, or (C) ExistenceJoin + * Dimension 3: (A) Join over Project, (B) Join over Agg, (C) Join over Window, + * (D) Join over Union, or (E) Join over other UnaryNode + * Dimension 4: (A) join condition is column or (B) expression + * Dimension 5: Subquery is (A) a single table, or (B) more than one table + * Dimension 6: Parent side is (A) a single table, or (B) more than one table + */ +class LeftSemiOrAntiPushdownSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Join} + import org.apache.spark.sql.catalyst.plans.LeftSemiOrAnti + + setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Integer, java.lang.Integer)](_) + + lazy val t1 = Seq( + row((1, 1, 1)), + row((1, 2, 2)), + row((2, 1, null)), + row((3, 1, 2)), + row((null, 0, 3)), + row((4, null, 2)), + row((0, -1, null))).toDF("t1a", "t1b", "t1c") + + lazy val t2 = Seq( + row((1, 1, 1)), + row((2, 1, 1)), + row((2, 1, null)), + row((3, 3, 3)), + row((3, 1, 0)), + row((null, null, 1)), + row((0, 0, -1))).toDF("t2a", "t2b", "t2c") + + lazy val t3 = Seq( + row((1, 1, 1)), + row((2, 1, 0)), + row((2, 1, null)), + row((10, 4, -1)), + row((3, 2, 0)), + row((-2, 1, -1)), + row((null, null, null))).toDF("t3a", "t3b", "t3c") + + lazy val t4 = Seq( + row((1, 1, 2)), + row((1, 2, 1)), + row((2, 1, null))).toDF("t4a", "t4b", "t4c") + + lazy val t5 = Seq( + row((1, 1, 1)), + row((2, null, 0)), + row((2, 1, null))).toDF("t5a", "t5b", "t5c") + + protected override def beforeAll(): Unit = { + super.beforeAll() + t1.createOrReplaceTempView("t1") + t2.createOrReplaceTempView("t2") + t3.createOrReplaceTempView("t3") + t4.createOrReplaceTempView("t4") + t5.createOrReplaceTempView("t5") + } + + private def checkLeftSemiOrAntiPlan(plan: LogicalPlan): Unit = { + plan match { + case j@Join(_, _, LeftSemiOrAnti(_), _) => + // This is the expected result. + case _ => + fail( + s""" + |== FAIL: Top operator must be a LeftSemi or LeftAnti === + |${plan.toString} + """.stripMargin) + } + } + + /** + * TC 1.1: 1A-2B-3A-4B-5A-6A + * Expected result: LeftAnti below Project + * Note that the expression T1A+1 is evaluated twice in Join and Project + * + * TC 1.1.1: Comparing to Inner, we do not push down Inner join under Project + * + * SELECT TX.* + * FROM (SELECT T1A+1 T1A1, T1B + * FROM T1 + * WHERE T1A > 2) TX, T2 + * WHERE T2A = T1A1 + */ + test("TC 1.1: LeftSemi/LeftAnti over Project") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2) tx + | where not exists (select 1 + | from t2 + | where t2a = t1a1) + """.stripMargin) + val plan2 = + sql( + """ + | select t1a+1 t1a1, t1b + | from t1 + | where t1a > 2 + | and not exists (select 1 + | from t2 + | where t2a = t1a+1) + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + + /** + * TC 1.2: 1B-2A-3B-4B-5B-6A + * Expected result: LeftSemi below Aggregate + */ + test("TC 1.2: LeftSemi/LeftAnti over Aggregate") { + val plan1 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | group by coalesce(t1c, 0)) tx + | where t1c_expr in (select t2b + | from t2, t3 + | where t2a = t3a) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select sum(t1a), coalesce(t1c, 0) t1c_expr + | from t1 + | where coalesce(t1c, 0) in (select t2b + | from t2, t3 + | where t2a = t3a) + | group by coalesce(t1c, 0)) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + + /** + * TC 1.3: 1A-2A-3C-4B-5A-6A + * Expected result: LeftSemi below Window + * + * Variations that yield no push down + * + * TC 1.3.1: We do not match T1B1 to the expression T1B+1 in the PARTITION BY clause + * hence no push down. + * + * SELECT * + * FROM (SELECT T1B+1 as T1B1, SUM(T1B * T1A) OVER (PARTITION BY T1B+1) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B1) + * + * TC 1.3.2: With the additional column Exists from the ExistenceJoin that does not exist + * in Window, and we do not add a compensation, the result is + * we don't push down ExistenceJoin under a Window. + * + * SELECT * + * FROM (SELECT T1B, SUM(T1B * T1A) OVER (PARTITION BY T1B) SUM + * FROM T1) TX + * WHERE EXISTS (SELECT 1 FROM T2 WHERE T2B = TX.T1B) + * OR T1B1 > 1 + */ + test("TC 1.3: LeftSemi/LeftAnti over Window") { + + val plan1 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1) tx + | where exists (select 1 + | from t2 + | where t2b = tx.t1b) + """.stripMargin) + + val plan2 = + sql( + """ + | select * + | from (select t1b, sum(t1b * t1a) over (partition by t1b) sum + | from t1 + | where exists (select 1 + | from t2 + | where t2b = t1.t1b)) tx + """.stripMargin) + + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + + /** + * TC 1.4: 1B-2B-3D-4A-5B-6B + * Expected result: LeftAnti below Union + */ + test("TC 1.4: LeftSemi/LeftAnti over Union") { + val plan1 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a) ua + | where t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a, t1b, t1c + | from t1, t3 + | where t1a = t3a + | and t1c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | union all + | select t2a, t2b, t2c + | from t2, t3 + | where t2a = t3a + | and t2c not in (select t4c + | from t5, t4 + | where t5.t5b = t4.t4b) + | ) ua + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + + /** + * TC 1.5: 1B-2B-3E-4B-5A-6B + * Expected result: LeftAnti below Sort + */ + test("TC 1.5: LeftSemi/LeftAnti over other UnaryNode") { + val plan1 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | order by t1b) tx + | where tx.t1a1 not in (select t2a + | from t2 + | where t2b < 3 + | and tx.t3c >= 0) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select t1a+1 t1a1, t1b, t3c + | from t1, t3 + | where t1b = t3b + | and t1a < 3 + | and t1.t1a+1 not in (select t2a + | from t2 + | where t2b < 3 + | and t3c >= 0) + | order by t1b) tx + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + + /** + * LeftSemi/LeftAnti over join + * + * Dimension 1: (A) LeftSemi or (B) LeftAnti + * Dimension 2: Join below is (A) Inner (B) LeftOuter (C) RightOuter (D) FullOuter, or, + * (E) LeftSemi/LeftAnti + * Dimension 3: Subquery correlated to (A) left table (B) right table, (C) both tables, + * or, (D) no correlated predicate + */ + /** + * TC 2.1: 1A-2A-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.1: LeftSemi over inner join") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2a >= 2) + | select * + | from join + | where t1a in (select t3a from t3 where t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where t1a in (select t3a from t3 where t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2a >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.2: 1A-2B-3A + * Expected result: LeftSemi join below LeftOuter join + */ + test("TC 2.2: LeftSemi over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.3: 1B-2B-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.3: LeftAnti over left outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 left join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.4: 1A-2C-3A + * Expected result: LeftSemi join below Inner join + */ + test("TC 2.4: LeftSemi over right outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | inner join t2 + | on t1b = t2b and t2c is null + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.5: 1B-2C-3B + * Expected result: LeftAnti join below RightOuter join + * RightOuter does not convert to Inner because NOT IN can return null. + */ + test("TC 2.5: LeftAnti over right outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where t2a not in (select t3a from t3 where t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where t2a not in (select t3a from t3 where t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.6: 1B-2C-3C + * Expected result: No push down + */ + test("TC 2.6: LeftAnti over right outer join with correlated cols on both left and right tbls") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b > t2b) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | left anti join + | (select t3a, t3b + | from t3 + | where t3a is not null + | and t3b is not null) t3 + | on t3a = t1a and t3b > t2b + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + } + /** + * TC 2.7: 1B-2D-3A + * Expected result: LeftAnti join below LeftOuter join + */ + test("TC 2.7: LeftAnti over full outer join with correlated columns on the left table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from (select * + | from t1 + | where not exists (select 1 from t3 where t3a = t1a and t3b >= 1)) t1 + | left join t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.8: 1A-2D-3B + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.8: LeftSemi over full outer join with correlated columns on the right table") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right join + | (select * + | from t2 + | where exists (select 1 from t3 where t3a = t2a and t3b >= 1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.9: 1A-2E-3A + * Expected result: No push down + */ + test("TC 2.9: LeftSemi over left semi join with correlated columns on the left table") { + import org.apache.spark.sql.catalyst.plans.logical.Union + val plan1 = + sql( + """ + | with join as + | (select * from t1 left semi join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | where exists (select 1 + | from (select * from t3 + | union all + | select * from t4) t3 + | where t3a = t1a and t3c is not null) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * + | from t1 + | left semi join t2 + | on t1b = t2b and t2c >= 0) + | select * + | from join + | left semi join + | (select * from t3 + | union all + | select * from t4) t3 + | on t3a = t1a and t3c is not null + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan.collectFirst { + case j @ Join(_, _, LeftSemiOrAnti(_), _) => j + } + optPlan match { + case Some(j@Join(_, _: Union, LeftSemiOrAnti(_), _)) => + // This is the expected result. + case _ => + fail( + s""" + |== FAIL: The right operand of the top operator must be a Union === + |${optPlan.toString} + """.stripMargin) + } + } + /** + * TC 2.10: 1A-2A-3C + * Expected result: No push down + */ + test("TC 2.10: LeftSemi over inner join with correlated columns on both left and right tables") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 inner join t2 on t1b = t2b and t2c is null) + | select * + | from join + | where exists (select 1 from t3 where t3a = t1a and t3a = t2a) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * + | from t1 + | inner join t2 + | on t1b = t2b and t2c is null) + | select * + | from join + | left semi join t3 + | on t3a = t1a and t3a = t2a + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + } + /** + * TC 2.11: 1B-2C-3D + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.11: LeftAnti over right outer join with no correlated columns") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 right join t2 on t1b = t2b and t2c >= 2) + | select * + | from join + | where not exists (select 1 from t3 where t3b < -1) + """.stripMargin) + val plan2 = + sql( + """ + | select * + | from t1 + | right outer join + | (select * + | from t2 + | where not exists (select 1 from t3 where t3b < -1)) t2 + | on t1b = t2b and t2c >= 2 + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 2.12: 1B-2D-3D + * Expected result: LeftSemi join below RightOuter join + */ + test("TC 2.12: LeftAnti over full outer join with no correlated columns") { + val plan1 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | where not exists (select 1 from t3 where t3b < -1) + | and (t1c = 1 or t1c is null) + """.stripMargin) + val plan2 = + sql( + """ + | with join as + | (select * from t1 full join t2 on t1b = t2b and t2c >= 0) + | select * + | from join + | left anti join t3 + | on t3b < -1 + | where (t1c = 1 or t1c is null) + """.stripMargin) + checkAnswer(plan1, plan2) + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + /** + * TC 3.1: Negative case - LeftSemi over Aggregate + * Expected result: No push down + */ + test("TC 3.1: Negative case - LeftSemi over Aggregate") { + val plan1 = + sql( + """ + | select t1b, min(t1a) as min + | from t1 b + | group by t1b + | having t1b in (select t1b+1 + | from t1 a + | where a.t1a = min(b.t1a) ) + """.stripMargin) + val plan2 = + sql( + """ + | select b.* + | from (select t1b, min(t1a) as min + | from t1 + | group by t1b) b + | left semi join t1 + | on b.t1b = t1.t1b+1 + | and b.min = t1.t1a + | and t1.t1a is not null + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + } + /** + * TC 3.2: Negative case - LeftAnti over Window + * Expected result: No push down + */ + test("TC 3.2: Negative case - LeftAnti over Window") { + + val plan1 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | where not exists (select 1 + | from t1 a + | where a.t1a = b.min + | and a.t1b = b.t1b) + """.stripMargin) + + val plan2 = + sql( + """ + | select b.t1b, b.min + | from (select t1b, min(t1a) over (partition by t1b) min + | from t1) b + | left anti join t1 a + | on a.t1a = b.min + | and a.t1b = b.t1b + """.stripMargin) + + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + } + /** + * TC 3.3: Negative case - LeftSemi over Union + * Expected result: No push down + */ + test("TC 3.3: Negative case - LeftSemi over Union") { + val plan1 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | where exists (select 1 + | from t1 a + | where a.t1b = un.t2b + | and a.t1a = un.t2a + case when rand() < 0 then 1 else 0 end) + """.stripMargin) + val plan2 = + sql( + """ + | select un.t2b, un.t2a + | from (select t2b, t2a + | from t2 + | union all + | select t3b, t3a + | from t3) un + | left semi join t1 a + | on a.t1b = un.t2b + | and a.t1a = un.t2a + """.stripMargin) + checkAnswer(plan1, plan2) + val optPlan = plan1.queryExecution.optimizedPlan + checkLeftSemiOrAntiPlan(optPlan) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 09ad0fdd6636..39666b66434a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -169,7 +169,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(!plan.find(p => - p.isInstanceOf[WholeStageCodegenExec] && + p.isInstanceOf[WholeStageCodegenExec] && p.isInstanceOf[SortMergeJoinExec] && p.asInstanceOf[WholeStageCodegenExec].child.children(0) .isInstanceOf[SortMergeJoinExec]).isDefined) assert(df.collect() === Array(Row(1), Row(2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0f1d08b6af5d..3a820ab95456 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -396,10 +396,10 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastHashJoin(nodeId = 0) + // ... -> BroadcastHashJoin(nodeId = 1) val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") testSparkPlanMetrics(df, 2, Map( - 0L -> (("BroadcastHashJoin", Map( + 1L -> (("BroadcastHashJoin", Map( "number of output rows" -> 2L)))) ) }