diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 4e07e729ab3bb..2bedf84271585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -140,14 +140,11 @@ object SubExprUtils extends PredicateHelper { * Given a logical plan, returns TRUE if it has an outer reference and false otherwise. */ def hasOuterReferences(plan: LogicalPlan): Boolean = { - plan.find { - case f: Filter => containsOuter(f.condition) - case other => false - }.isDefined + plan.find(_.expressions.exists(containsOuter)).isDefined } /** - * Given a list of expressions, returns the expressions which have outer references. Aggregate + * Given an expression, returns the expressions which have outer references. Aggregate * expressions are treated in a special way. If the children of aggregate expression contains an * outer reference, then the entire aggregate expression is marked as an outer reference. * Example (SQL): @@ -183,18 +180,18 @@ object SubExprUtils extends PredicateHelper { * }}} * The code below needs to change when we support the above cases. */ - def getOuterReferences(conditions: Seq[Expression]): Seq[Expression] = { + def getOuterReferences(expr: Expression): Seq[Expression] = { val outerExpressions = ArrayBuffer.empty[Expression] - conditions foreach { expr => - expr transformDown { - case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => - val newExpr = stripOuterReference(a) - outerExpressions += newExpr - newExpr - case OuterReference(e) => - outerExpressions += e - e - } + expr transformDown { + case a: AggregateExpression if a.collectLeaves.forall(_.isInstanceOf[OuterReference]) => + // Collect and update the sub-tree so that outer references inside this aggregate + // expression will not be collected. For example: min(outer(a)) -> min(a). + val newExpr = stripOuterReference(a) + outerExpressions += newExpr + newExpr + case OuterReference(e) => + outerExpressions += e + e } outerExpressions.toSeq } @@ -204,8 +201,7 @@ object SubExprUtils extends PredicateHelper { * Filter operator can host outer references. */ def getOuterReferences(plan: LogicalPlan): Seq[Expression] = { - val conditions = plan.collect { case Filter(cond, _) => cond } - getOuterReferences(conditions) + plan.flatMap(_.expressions.flatMap(getOuterReferences)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala new file mode 100644 index 0000000000000..377dcd6666d64 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala @@ -0,0 +1,484 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Decorrelate the inner query by eliminating outer references and create domain joins. + * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann + * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418. + * + * A correlated subquery can be viewed as a "dependent" nested loop join between the outer and + * the inner query. For each row produced by the outer query, we bind the [[OuterReference]]s in + * in the inner query with the corresponding values in the row, and then evaluate the inner query. + * + * Dependent Join + * :- Outer Query + * +- Inner Query + * + * If the [[OuterReference]]s are bound to the same value, the inner query will return the same + * result. Based on this, we can reduce the times to evaluate the inner query by first getting + * all distinct values of the [[OuterReference]]s. + * + * Normal Join + * :- Outer Query + * +- Dependent Join + * :- Inner Query + * +- Distinct Aggregate (outer_ref1, outer_ref2, ...) + * +- Outer Query + * + * The distinct aggregate of the outer references is called a "domain", and the dependent join + * between the inner query and the domain is called a "domain join". We need to push down the + * domain join through the inner query until there is no outer reference in the sub-tree and + * the domain join will turn into a normal join. + * + * The decorrelation function returns a new query plan with optional placeholder [[DomainJoins]]s + * added and a list of join conditions with the outer query. [[DomainJoin]]s need to be rewritten + * into actual inner join between the inner query sub-tree and the outer query. + * + * E.g. decorrelate an inner query with equality predicates: + * + * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c = t1.a) FROM t2 + * + * Aggregate [] [min(b)] Aggregate [a] [min(b), a] + * +- Filter (outer(c) = a) => +- Relation [t1] + * +- Relation [t1] + * + * Join conditions: [c = a] + * + * E.g. decorrelate an inner query with non-equality predicates: + * + * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c > t1.a) FROM t2 + * + * Aggregate [] [min(b)] Aggregate [c'] [min(b), c'] + * +- Filter (outer(c) > a) => +- Filter (c' > a) + * +- Relation [t1] +- DomainJoin [c'] + * +- Relation [t1] + * + * Join conditions: [c <=> c'] + */ +object DecorrelateInnerQuery extends PredicateHelper { + + /** + * Check if an expression contains any attribute. Note OuterReference is a + * leaf node and will not be found here. + */ + private def containsAttribute(expression: Expression): Boolean = { + expression.find(_.isInstanceOf[Attribute]).isDefined + } + + /** + * Check if an expression can be pulled up over an [[Aggregate]] without changing the + * semantics of the plan. The expression must be an equality predicate that guarantees + * one-to-one mapping between inner and outer attributes. More specifically, one side + * of the predicate must be an attribute and another side of the predicate must not + * contain other attributes from the inner query. + * For example: + * (a = outer(c)) -> true + * (a > outer(c)) -> false + * (a + b = outer(c)) -> false + * (a = outer(c) - b) -> false + */ + private def canPullUpOverAgg(expression: Expression): Boolean = expression match { + case Equality(_: Attribute, b) => !containsAttribute(b) + case Equality(a, _: Attribute) => !containsAttribute(a) + case o => !containsAttribute(o) + } + + /** + * Collect outer references in an expressions that are in the output attributes of the outer plan. + */ + private def collectOuterReferences(expression: Expression): AttributeSet = { + AttributeSet(expression.collect { case o: OuterReference => o.toAttribute }) + } + + /** + * Collect outer references in a sequence of expressions that are in the output attributes + * of the outer plan. + */ + private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = { + AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences)) + } + + /** + * Build a mapping between outer references with equivalent inner query attributes. + * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y} + */ + private def collectEquivalentOuterReferences( + expressions: Seq[Expression]): Map[Attribute, Attribute] = { + expressions.collect { + case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute) + case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute) + }.toMap + } + + /** + * Replace all outer references using the expressions in the given outer reference map. + */ + private def replaceOuterReference[E <: Expression]( + expression: E, + outerReferenceMap: Map[Attribute, Attribute]): E = { + expression.transform { + case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o) + }.asInstanceOf[E] + } + + /** + * Replace all outer references in the given expressions using the expressions in the + * outer reference map. + */ + private def replaceOuterReferences[E <: Expression]( + expressions: Seq[E], + outerReferenceMap: Map[Attribute, Attribute]): Seq[E] = { + expressions.map(replaceOuterReference(_, outerReferenceMap)) + } + + /** + * Return all references that are presented in the join conditions but not in the output + * of the given named expressions. + */ + private def missingReferences( + namedExpressions: Seq[NamedExpression], + joinCond: Seq[Expression]): AttributeSet = { + val output = namedExpressions.map(_.toAttribute) + AttributeSet(joinCond.flatMap(_.references)) -- AttributeSet(output) + } + + /** + * Deduplicate the inner and the outer query attributes and return an aliased + * subquery plan and join conditions if duplicates are found. Duplicated attributes + * can break the structural integrity when joining the inner and outer plan together. + */ + def deduplicate( + innerPlan: LogicalPlan, + conditions: Seq[Expression], + outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = { + val duplicates = innerPlan.outputSet.intersect(outerOutputSet) + if (duplicates.nonEmpty) { + val aliasMap = AttributeMap(duplicates.map { dup => + dup -> Alias(dup, dup.toString)() + }.toSeq) + val aliasedExpressions = innerPlan.output.map { ref => + aliasMap.getOrElse(ref, ref) + } + val aliasedProjection = Project(aliasedExpressions, innerPlan) + val aliasedConditions = conditions.map(_.transform { + case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute + }) + (aliasedProjection, aliasedConditions) + } else { + (innerPlan, conditions) + } + } + + /** + * Build a mapping between domain attributes and corresponding outer query expressions + * using the join conditions. + */ + private def buildDomainAttrMap( + conditions: Seq[Expression], + domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = { + val domainAttrSet = AttributeSet(domainAttrs) + conditions.collect { + // When we build the join conditions between the domain attributes and outer references, + // the left hand side is always the domain attribute used in the inner query and the right + // hand side is the attribute from the outer query. Note here the right hand side of a + // condition is not necessarily an attribute, for example it can be a literal (if foldable) + // or a cast expression after the optimization. + case EqualNullSafe(left: Attribute, right: Expression) if domainAttrSet.contains(left) => + left -> right + }.toMap + } + + /** + * Rewrite all [[DomainJoin]]s in the inner query to actual inner joins with the outer query. + */ + def rewriteDomainJoins( + outerPlan: LogicalPlan, + innerPlan: LogicalPlan, + conditions: Seq[Expression]): LogicalPlan = { + innerPlan transform { + case d @ DomainJoin(domainAttrs, child) => + val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs) + // We should only rewrite a domain join when all corresponding outer plan attributes + // can be found from the join condition. + if (domainAttrMap.size == domainAttrs.size) { + val groupingExprs = domainAttrs.map(domainAttrMap) + val aggregateExprs = groupingExprs.zip(domainAttrs).map { + // Rebuild the aliases. + case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId) + } + // Construct a domain with the outer query plan. + // DomainJoin [a', b'] => Aggregate [a, b] [a AS a', b AS b'] + // +- Relation [a, b] + val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan) + child match { + // A special optimization for OneRowRelation. + // TODO: add a more general rule to optimize join with OneRowRelation. + case _: OneRowRelation => domain + // Construct a domain join. + // Join Inner + // :- Inner Query + // +- Domain + case _ => Join(child, domain, Inner, None, JoinHint.NONE) + } + } else { + throw new UnsupportedOperationException( + s"Unable to rewrite domain join with conditions: $conditions\n$d") + } + } + } + + def apply( + innerPlan: LogicalPlan, + outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + apply(innerPlan, Seq(outerPlan)) + } + + def apply( + innerPlan: LogicalPlan, + outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet)) + + // The return type of the recursion. + // The first parameter is a new logical plan with correlation eliminated. + // The second parameter is a list of join conditions with the outer query. + // The third parameter is a mapping between the outer references and equivalent + // expressions from the inner query that is used to replace outer references. + type ReturnType = (LogicalPlan, Seq[Expression], Map[Attribute, Attribute]) + + // Decorrelate the input plan with a set of parent outer references and a boolean flag + // indicating whether the result of the plan will be aggregated. Steps: + // 1. Recursively collects outer references from the inner query until it reaches a node + // that does not contain correlated value. + // 2. Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is + // needed between the outer query and the specific sub-tree of the inner query. + // 3. Returns a list of join conditions with the outer query and a mapping between outer + // references with references inside the inner query. The parent nodes need to preserve + // the references inside the join conditions and substitute all outer references using + // the mapping. + def decorrelate( + plan: LogicalPlan, + parentOuterReferences: AttributeSet, + aggregated: Boolean = false): ReturnType = { + val isCorrelated = hasOuterReferences(plan) + if (!isCorrelated) { + // We have reached a plan without correlation to the outer plan. + if (parentOuterReferences.isEmpty) { + // If there is no outer references from the parent nodes, it means all outer + // attributes can be substituted by attributes from the inner plan. So no + // domain join is needed. + (plan, Nil, Map.empty[Attribute, Attribute]) + } else { + // Build the domain join with the parent outer references. + val attributes = parentOuterReferences.toSeq + val domains = attributes.map(_.newInstance()) + // A placeholder to be rewritten into domain join. + val domainJoin = DomainJoin(domains, plan) + val outerReferenceMap = attributes.zip(domains).toMap + // Build join conditions between domain attributes and outer references. + // EqualNullSafe is used to make sure null key can be joined together. Note + // outer referenced attributes can be changed during the outer query optimization. + // The equality conditions will also serve as an attribute mapping between new + // outer references and domain attributes when rewriting the domain joins. + // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a) + // will become a' <=> a1, and we can construct the aliases based on the condition: + // DomainJoin [a'] Join Inner + // +- InnerQuery => :- InnerQuery + // +- Aggregate [a1] [a1 AS a'] + // +- OuterQuery + val conditions = outerReferenceMap.map { + case (o, a) => EqualNullSafe(a, OuterReference(o)) + } + (domainJoin, conditions.toSeq, outerReferenceMap) + } + } else { + plan match { + case Filter(condition, child) => + val conditions = splitConjunctivePredicates(condition) + val (correlated, uncorrelated) = conditions.partition(containsOuter) + // Find outer references that can be substituted by attributes from the inner + // query using the equality predicates. + val equivalences = collectEquivalentOuterReferences(correlated) + // Correlated predicates can be removed from the Filter's condition and used as + // join conditions with the outer query. However, if the results of the sub-tree + // is aggregated, only certain correlated equality predicates can be used, because + // the references in the join conditions need to be preserved in both the grouping + // and aggregate expressions of an Aggregate, which may change the semantics of the + // plan and lead to incorrect results. Here is an example: + // Relations: + // t1(a, b): [(1, 1)] + // t2(c, d): [(1, 1), (2, 0)] + // + // Query: + // SELECT * FROM t1 WHERE a = (SELECT MAX(c) FROM t2 WHERE b >= d) + // + // Subquery plan transformation if correlated predicates are used as join conditions: + // Aggregate [max(c)] Aggregate [d] [max(c), d] + // +- Filter (outer(b) >= d) => +- Relation [c, d] + // +- Relation [c, d] + // + // Plan after rewrite: + // Project [a, b] -- [(1, 1)] + // +- Join LeftOuter (b >= d AND a = max(c)) + // :- Relation [a, b] + // +- Aggregate [d] [max(c), d] -- [(1, 1), (2, 0)] + // +- Relation [c, d] + // + // The result of the original query should be an empty set but the transformed + // query will output an incorrect result of (1, 1). The correct transformation + // with domain join is illustrated below: + // Aggregate [max(c)] Aggregate [b'] [max(c), b'] + // +- Filter (outer(b) >= d) => +- Filter (b' >= d) + // +- Relation [c, d] +- DomainJoin [b'] + // +- Relation [c, d] + // Plan after rewrite: + // Project [a, b] + // +- Join LeftOuter (b <=> b' AND a = max(c)) -- [] + // :- Relation [a, b] + // +- Aggregate [b'] [max(c), b'] -- [(2, 1)] + // +- Join Inner (b' >= d) -- [(1, 1, 1), (2, 0, 1)] (DomainJoin) + // :- Relation [c, d] + // +- Aggregate [b] [b AS b'] -- [(1)] (Domain) + // +- Relation [a, b] + if (aggregated) { + // Split the correlated predicates into predicates that can and cannot be directly + // used as join conditions with the outer query depending on whether they can + // be pulled up over an Aggregate without changing the semantics of the plan. + val (equalityCond, predicates) = correlated.partition(canPullUpOverAgg) + val outerReferences = collectOuterReferences(predicates) + val newOuterReferences = + parentOuterReferences ++ outerReferences -- equivalences.keySet + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Add the outer references mapping collected from the equality conditions. + val newOuterReferenceMap = outerReferenceMap ++ equivalences + // Replace all outer references in the non-equality predicates. + val newCorrelated = replaceOuterReferences(predicates, newOuterReferenceMap) + // The new filter condition is the original filter condition with correlated + // equality predicates removed. + val newFilterCond = newCorrelated ++ uncorrelated + val newFilter = newFilterCond match { + case Nil => newChild + case conditions => Filter(conditions.reduce(And), newChild) + } + // Equality predicates are used as join conditions with the outer query. + val newJoinCond = joinCond ++ equalityCond + (newFilter, newJoinCond, newOuterReferenceMap) + } else { + // Results of this sub-tree is not aggregated, so all correlated predicates + // can be directly used as outer query join conditions. + val newOuterReferences = parentOuterReferences -- equivalences.keySet + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Add the outer references mapping collected from the equality conditions. + val newOuterReferenceMap = outerReferenceMap ++ equivalences + val newFilter = uncorrelated match { + case Nil => newChild + case conditions => Filter(conditions.reduce(And), newChild) + } + val newJoinCond = joinCond ++ correlated + (newFilter, newJoinCond, newOuterReferenceMap) + } + + case Project(projectList, child) => + val outerReferences = collectOuterReferences(projectList) + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated) + // Replace all outer references in the original project list. + val newProjectList = replaceOuterReferences(projectList, outerReferenceMap) + // Preserve required domain attributes in the join condition by adding the missing + // references to the new project list. + val referencesToAdd = missingReferences(newProjectList, joinCond) + val newProject = Project(newProjectList ++ referencesToAdd, newChild) + (newProject, joinCond, outerReferenceMap) + + case a @ Aggregate(groupingExpressions, aggregateExpressions, child) => + val outerReferences = collectOuterReferences(a.expressions) + val newOuterReferences = parentOuterReferences ++ outerReferences + val (newChild, joinCond, outerReferenceMap) = + decorrelate(child, newOuterReferences, aggregated = true) + // Replace all outer references in grouping and aggregate expressions. + val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap) + val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap) + // Add all required domain attributes to both grouping and aggregate expressions. + val referencesToAdd = missingReferences(newAggExpr, joinCond) + val newAggregate = a.copy( + groupingExpressions = newGroupingExpr ++ referencesToAdd, + aggregateExpressions = newAggExpr ++ referencesToAdd, + child = newChild) + (newAggregate, joinCond, outerReferenceMap) + + case j @ Join(left, right, joinType, condition, _) => + val outerReferences = collectOuterReferences(j.expressions) + // Join condition containing outer references is not supported. + assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j") + val newOuterReferences = parentOuterReferences ++ outerReferences + val shouldPushToLeft = joinType match { + case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true + case _ => hasOuterReferences(left) + } + val shouldPushToRight = joinType match { + case RightOuter | FullOuter => true + case _ => hasOuterReferences(right) + } + val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) { + decorrelate(left, newOuterReferences, aggregated) + } else { + (left, Nil, Map.empty[Attribute, Attribute]) + } + val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) { + decorrelate(right, newOuterReferences, aggregated) + } else { + (right, Nil, Map.empty[Attribute, Attribute]) + } + val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap + val newJoinCond = leftJoinCond ++ rightJoinCond + // If we push the dependent join to both sides, we can augment the join condition + // such that both sides are matched on the domain attributes. For example, + // - Left Map: {outer(c1) = c1} + // - Right Map: {outer(c1) = 10 - c1} + // Then the join condition can be augmented with (c1 <=> 10 - c1). + val augmentedConditions = leftOuterReferenceMap.flatMap { + case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _)) + } + val newCondition = (condition ++ augmentedConditions).reduceOption(And) + val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition) + (newJoin, newJoinCond, newOuterReferenceMap) + + case u: UnaryNode => + val outerReferences = collectOuterReferences(u.expressions) + assert(outerReferences.isEmpty, s"Correlated column is not allowed in $u") + decorrelate(u.child, parentOuterReferences, aggregated) + + case o => + throw new UnsupportedOperationException( + s"Decorrelate inner query through ${o.nodeName} is not supported.") + } + } + } + val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty) + val (plan, conditions) = deduplicate(newChild, joinCond, outputSet) + (plan, stripOuterReferences(conditions)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index ef73e58645a89..9381796d3d06b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /* @@ -272,22 +273,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper val baseConditions = predicateMap.values.flatten.toSeq val (newPlan, newCond) = if (outer.nonEmpty) { val outputSet = outer.map(_.outputSet).reduce(_ ++ _) - val duplicates = transformed.outputSet.intersect(outputSet) - val (plan, deDuplicatedConditions) = if (duplicates.nonEmpty) { - val aliasMap = AttributeMap(duplicates.map { dup => - dup -> Alias(dup, dup.toString)() - }.toSeq) - val aliasedExpressions = transformed.output.map { ref => - aliasMap.getOrElse(ref, ref) - } - val aliasedProjection = Project(aliasedExpressions, transformed) - val aliasedConditions = baseConditions.map(_.transform { - case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute - }) - (aliasedProjection, aliasedConditions) - } else { - (transformed, baseConditions) - } + val (plan, deDuplicatedConditions) = + DecorrelateInnerQuery.deduplicate(transformed, baseConditions, outputSet) (plan, stripOuterReferences(deDuplicatedConditions)) } else { (transformed, stripOuterReferences(baseConditions)) @@ -308,9 +295,17 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper if (newCond.isEmpty) oldCond else newCond } + def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = { + if (SQLConf.get.decorrelateInnerQueryEnabled) { + DecorrelateInnerQuery(sub, outer) + } else { + pullOutCorrelatedPredicates(sub, outer) + } + } + plan transformExpressions { case ScalarSubquery(sub, children, exprId) if children.nonEmpty => - val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) + val (newPlan, newCond) = decorrelate(sub, outerPlans) ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId) case Exists(sub, children, exprId) if children.nonEmpty => val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans) @@ -379,7 +374,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe bindings: Map[ExprId, Expression]): Expression = { val rewrittenExpr = expr transform { case r: AttributeReference => - bindings.getOrElse(r.exprId, Literal.default(NullType)) + bindings.getOrElse(r.exprId, Literal.create(null, r.dataType)) } tryEvalExpr(rewrittenExpr) @@ -394,9 +389,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe // Also replace attribute refs (for example, for grouping columns) with NULL. val rewrittenExpr = expr transform { case a @ AggregateExpression(aggFunc, _, _, resultId, _) => - aggFunc.defaultResult.getOrElse(Literal.default(NullType)) + aggFunc.defaultResult.getOrElse(Literal.create(null, aggFunc.dataType)) - case _: AttributeReference => Literal.default(NullType) + case a: AttributeReference => Literal.create(null, a.dataType) } tryEvalExpr(rewrittenExpr) @@ -525,7 +520,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(query, conditions, _)) => + case (currentChild, ScalarSubquery(sub, conditions, _)) => + val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 21e87b4c62606..49e3e3c9ee999 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1393,3 +1393,14 @@ case class CollectMetrics( override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics = copy(child = newChild) } + +/** + * A placeholder for domain join that can be added when decorrelating subqueries. + * It should be rewritten during the optimization phase. + */ +case class DomainJoin(domainAttrs: Seq[Attribute], child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ domainAttrs + override def producedAttributes: AttributeSet = AttributeSet(domainAttrs) + override protected def withNewChildInternal(newChild: LogicalPlan): DomainJoin = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f4c236c68dfe9..04e740039f005 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2427,6 +2427,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECORRELATE_INNER_QUERY_ENABLED = + buildConf("spark.sql.optimizer.decorrelateInnerQuery.enabled") + .internal() + .doc("Decorrelate inner query by eliminating correlated references and build domain joins.") + .version("3.2.0") + .booleanConf + .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = buildConf("spark.sql.execution.topKSortFallbackThreshold") .internal() @@ -3829,6 +3837,8 @@ class SQLConf extends Serializable with Logging { def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED) + def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala new file mode 100644 index 0000000000000..93b27035aca33 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + +class DecorrelateInnerQuerySuite extends PlanTest { + + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + val x = AttributeReference("x", IntegerType)() + val y = AttributeReference("y", IntegerType)() + val z = AttributeReference("z", IntegerType)() + val testRelation = LocalRelation(a, b, c) + val testRelation2 = LocalRelation(x, y, z) + + private def hasOuterReferences(plan: LogicalPlan): Boolean = { + plan.find(_.expressions.exists(SubExprUtils.containsOuter)).isDefined + } + + private def check( + innerPlan: LogicalPlan, + outerPlan: LogicalPlan, + correctAnswer: LogicalPlan, + conditions: Seq[Expression]): Unit = { + val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan) + assert(!hasOuterReferences(outputPlan)) + comparePlans(outputPlan, correctAnswer) + assert(joinCond.length == conditions.length) + joinCond.zip(conditions).foreach(e => compareExpressions(e._1, e._2)) + } + + test("filter with correlated equality predicates only") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(OuterReference(x) === a, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) + } + + test("filter with local and correlated equality predicates") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(And(OuterReference(x) === a, b === 3), + testRelation)) + val correctAnswer = + Project(Seq(a, b), + Filter(b === 3, + testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) + } + + test("filter with correlated non-equality predicates") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a, b), + Filter(OuterReference(x) > a, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x > a)) + } + + test("duplicated output attributes") { + val outerPlan = testRelation + val innerPlan = + Project(Seq(a), + Filter(OuterReference(a) === a, + testRelation)) + val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan) + val a1 = outputPlan.output.head + val correctAnswer = + Project(Seq(Alias(a, a1.name)(a1.exprId)), + Project(Seq(a), + testRelation)) + comparePlans(outputPlan, correctAnswer) + assert(joinCond == Seq(a === a1)) + } + + test("filter with equality predicates with correlated values on both sides") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(a), + Filter(OuterReference(x) === OuterReference(y) + b, + testRelation)) + val correctAnswer = Project(Seq(a, b), testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === y + b)) + } + + test("aggregate with correlated equality predicates that can be pulled up") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(And(OuterReference(x) === a, b === 3), + testRelation)) + val correctAnswer = + Aggregate(Seq(a), Seq(minB, a), + Filter(b === 3, + testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a)) + } + + test("aggregate with correlated equality predicates that cannot be pulled up") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) === OuterReference(y) + a, + testRelation)) + val correctAnswer = + Aggregate(Seq(x, y), Seq(minB, x, y), + Filter(x === y + a, + DomainJoin(Seq(x, y), testRelation))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) + } + + test("aggregate with correlated equality predicates that has no attribute") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) === OuterReference(y), + testRelation)) + val correctAnswer = + Aggregate(Nil, Seq(minB), + testRelation) + check(innerPlan, outerPlan, correctAnswer, Seq(x === y)) + } + + test("aggregate with correlated non-equality predicates") { + val outerPlan = testRelation2 + val minB = Alias(min(b), "min_b")() + val innerPlan = + Aggregate(Nil, Seq(minB), + Filter(OuterReference(x) > a, + testRelation)) + val correctAnswer = + Aggregate(Seq(x), Seq(minB, x), + Filter(x > a, + DomainJoin(Seq(x), testRelation))) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x)) + } + + test("join with correlated equality predicates") { + val outerPlan = testRelation2 + val joinCondition = Some($"t1.b" === $"t2.b") + val left = + Project(Seq(b), + Filter(OuterReference(x) === b, + testRelation)).as("t1") + val right = + Project(Seq(b), + Filter(OuterReference(x) === a, + testRelation)).as("t2") + Seq(Inner, LeftOuter, LeftSemi, LeftAnti, RightOuter, FullOuter, Cross).foreach { joinType => + val innerPlan = Join(left, right, joinType, joinCondition, JoinHint.NONE).analyze + val newLeft = Project(Seq(b), testRelation).as("t1") + val newRight = Project(Seq(b, a), testRelation).as("t2") + // Since the left-hand side has outer(x) = b, and the right-hand side has outer(x) = a, the + // join condition will be augmented with b <=> a. + val newCond = Some(And($"t1.b" <=> $"t2.a", $"t1.b" === $"t2.b")) + val correctAnswer = Join(newLeft, newRight, joinType, newCond, JoinHint.NONE).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(x === b, x === a)) + } + } + + test("correlated values inside join condition") { + val outerPlan = testRelation2 + val innerPlan = + Join( + testRelation.as("t1"), + Filter(OuterReference(y) === 3, testRelation), + Inner, + Some(OuterReference(x) === a), + JoinHint.NONE) + val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, outerPlan) } + assert(error.getMessage.contains("Correlated column is not allowed in join")) + } + + test("correlated values in project") { + val outerPlan = testRelation2 + val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation()) + val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation())) + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) + } + + test("correlated values in project with alias") { + val outerPlan = testRelation2 + val innerPlan = + Project(Seq(OuterReference(x), 'y1, 'sum), + Project(Seq( + OuterReference(x), + OuterReference(y).as("y1"), + Add(OuterReference(x), OuterReference(y)).as("sum")), + testRelation)).analyze + val correctAnswer = + Project(Seq(x, 'y1, 'sum, y), + Project(Seq(x, y.as("y1"), (x + y).as("sum"), y), + DomainJoin(Seq(x, y), testRelation))).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y)) + } + + test("correlated values in project with correlated equality conditions in filter") { + val outerPlan = testRelation2 + val innerPlan = + Project( + Seq(OuterReference(x)), + Filter( + And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), + testRelation + ) + ) + val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation)) + check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c)) + } + + test("correlated values in project without correlated equality conditions in filter") { + val outerPlan = testRelation2 + val innerPlan = + Project( + Seq(OuterReference(y)), + Filter( + And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)), + testRelation + ) + ) + val correctAnswer = + Project(Seq(y, a, c), + Filter(b === 1, + DomainJoin(Seq(y), testRelation) + ) + ) + check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, x + y === c)) + } + + test("correlated values in project with aggregate") { + val outerPlan = testRelation2 + val innerPlan = + Aggregate( + Seq('x1), Seq(min('y1).as("min_y1")), + Project( + Seq(a, OuterReference(x).as("x1"), OuterReference(y).as("y1")), + Filter( + And(OuterReference(x) === a, OuterReference(y) === OuterReference(z)), + testRelation + ) + ) + ).analyze + val correctAnswer = + Aggregate( + Seq('x1, y, a), Seq(min('y1).as("min_y1"), y, a), + Project( + Seq(a, a.as("x1"), y.as("y1"), y), + DomainJoin(Seq(y), testRelation) + ) + ).analyze + check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, y === z)) + } +}