diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 12e21faca9f2..9301abc9f225 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -210,6 +210,8 @@ abstract class Optimizer(catalogManager: CatalogManager) // idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once. Batch("Join Reorder", FixedPoint(1), CostBasedJoinReorder) :+ + Batch("Pull Out Complex Join Keys", Once, + PullOutComplexJoinKeys) :+ Batch("Eliminate Sorts", Once, EliminateSorts) :+ Batch("Decimal Optimizations", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala new file mode 100644 index 000000000000..9280f2a71382 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeys.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Alias, And, EqualTo, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN + +/** + * This rule pulls out complex join key expressions if the join cannot be broadcast. + * Example: + * + * +- Join Inner, ((c1 % 2) = c2)) - Project [c1, c2] + * :- Relation default.t1[c1] parquet +- Join Inner, (_complexjoinkey_0 = c2)) + * +- Relation default.t2[c2] parquet => :- Project [c1, (c1 % 2) AS _complexjoinkey_0] + * : +- Relation default.t1[c1] parquet + * +- Relation default.t2[c2] parquet + * + * For shuffle-based joins, the join keys may be evaluated several times: + * - SMJ: always evaluates the join keys during the join, and possibly again if there is a shuffle or sort + * - SHJ: always evaluates the join keys during the join, and possibly again if there is a shuffle + * So this rule can reduce the cost of repetitive evaluation. 
+ */ +object PullOutComplexJoinKeys extends Rule[LogicalPlan] with JoinSelectionHelper { + + private def isComplexExpression(e: Expression): Boolean = + e.deterministic && !e.foldable && e.children.nonEmpty + + private def hasComplexExpression(joinKeys: Seq[Expression]): Boolean = + joinKeys.exists(isComplexExpression) + + private def extractComplexExpression( + joinKeys: Seq[Expression], + startIndex: Int): mutable.LinkedHashMap[Expression, NamedExpression] = { + val map = new mutable.LinkedHashMap[Expression, NamedExpression]() + var i = startIndex + joinKeys.foreach { + case e: Expression if isComplexExpression(e) => + map.put(e.canonicalized, Alias(e, s"_complexjoinkey_$i")()) + i += 1 + case _ => + } + map + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformWithPruning(_.containsPattern(JOIN), ruleId) { + case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, other, _, left, right, joinHint) + if hasComplexExpression(leftKeys) || hasComplexExpression(rightKeys) => + val leftComplexExprs = extractComplexExpression(leftKeys, 0) + val (newLeftKeys, newLeft) = + if ((!canBuildBroadcastLeft(joinType) || !canBroadcastBySize(left, conf)) && + leftComplexExprs.nonEmpty) { + ( + leftKeys.map { e => + if (leftComplexExprs.contains(e.canonicalized)) { + leftComplexExprs(e.canonicalized).toAttribute + } else { + e + } + }, + Project(left.output ++ leftComplexExprs.values.toSeq, left) + ) + } else { + (leftKeys, left) + } + + val rightComplexExprs = extractComplexExpression(rightKeys, leftComplexExprs.size) + val (newRightKeys, newRight) = + if ((!canBuildBroadcastRight(joinType) || !canBroadcastBySize(right, conf)) && + rightComplexExprs.nonEmpty) { + ( + rightKeys.map { e => + if (rightComplexExprs.contains(e.canonicalized)) { + rightComplexExprs(e.canonicalized).toAttribute + } else { + e + } + }, + Project(right.output ++ rightComplexExprs.values.toSeq, right) + ) + } else { + (rightKeys, right) + } + + if (left.eq(newLeft) && 
right.eq(newRight)) { + j + } else { + val newConditions = newLeftKeys.zip(newRightKeys).map { + case (l, r) => EqualTo(l, r) + } ++ other + + Project( + j.output, + Join(newLeft, newRight, joinType, newConditions.reduceOption(And), joinHint)) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 1204fa8c604a..e8ac08cc922d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -129,6 +129,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields":: "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" :: "org.apache.spark.sql.catalyst.optimizer.PruneFilters" :: + "org.apache.spark.sql.catalyst.optimizer.PullOutComplexJoinKeys" :: "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" :: "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" :: "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala new file mode 100644 index 000000000000..ed82c8e4ed67 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutComplexJoinKeysSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.internal.SQLConf + +class PullOutComplexJoinKeysSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("PullOutComplexJoinKeys", FixedPoint(1), + PullOutComplexJoinKeys, + CollapseProject) :: Nil + } + + val testRelation1 = LocalRelation($"a".int, $"b".int) + val testRelation2 = LocalRelation($"x".int, $"y".int) + + test("pull out complex join keys") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // join + // a (complex join key) + // b + val plan1 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x")) + val expected1 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0").join( + testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + .select($"a", $"b", $"x", $"y") + comparePlans(Optimize.execute(plan1.analyze), expected1.analyze) + + // join + // project + // a (complex join key) + // b + val plan2 = testRelation1.select($"a").join( + testRelation2, condition = Some($"a" % 2 === $"x")) + val expected2 = testRelation1.select($"a", ($"a" % 2) as "_complexjoinkey_0") + .join(testRelation2, condition = Some($"_complexjoinkey_0" === $"x")) + .select($"a", $"x", $"y") 
+ comparePlans(Optimize.execute(plan2.analyze), expected2.analyze) + + // join + // a (two complex join keys) + // b + val plan3 = testRelation1.join(testRelation2, + condition = Some($"a" % 2 === $"x" && $"b" % 3 === $"y")) + val expected3 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0", + ($"b" % 3) as "_complexjoinkey_1").join(testRelation2, + condition = Some($"_complexjoinkey_0" === $"x" && $"_complexjoinkey_1" === $"y")) + .select($"a", $"b", $"x", $"y") + comparePlans(Optimize.execute(plan3.analyze), expected3.analyze) + + // join + // a + // b (complex join key) + val plan4 = testRelation1.join(testRelation2, condition = Some($"a" === $"x" % 2)) + val expected4 = testRelation1.join(testRelation2.select($"x", $"y", + ($"x" % 2) as "_complexjoinkey_0"), condition = Some($"a" === $"_complexjoinkey_0")) + .select($"a", $"b", $"x", $"y") + comparePlans(Optimize.execute(plan4.analyze), expected4.analyze) + + // join + // a (complex join key) + // b (complex join key) + val plan5 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x" % 3)) + val expected5 = testRelation1.select($"a", $"b", ($"a" % 2) as "_complexjoinkey_0").join( + testRelation2.select($"x", $"y", ($"x" % 3) as "_complexjoinkey_1"), + condition = Some($"_complexjoinkey_0" === $"_complexjoinkey_1")) + .select($"a", $"b", $"x", $"y") + comparePlans(Optimize.execute(plan5.analyze), expected5.analyze) + } + } + + test("do not pull out complex join keys") { + // can broadcast + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100000") { + val p1 = testRelation1.join(testRelation2, condition = Some($"a" % 2 === $"x")).analyze + comparePlans(Optimize.execute(p1), p1) + + val p2 = testRelation1.join(testRelation2, condition = Some($"a" === $"x" % 2)).analyze + comparePlans(Optimize.execute(p2), p2) + } + + // not contains complex expression + val p1 = testRelation1.subquery("t1").join( + testRelation2.subquery("t2"), condition = Some($"a" === $"x")) + 
 comparePlans(Optimize.execute(p1.analyze), p1.analyze) + + // not an equi-join + val p2 = testRelation1.subquery("t1").join(testRelation2.subquery("t2")) + comparePlans(Optimize.execute(p2.analyze), p2.analyze) + } +}