diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5e1046293a20..ce98a3fc6fb6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -338,7 +338,7 @@ querySpecification (RECORDREADER recordReader=STRING)? fromClause? (WHERE where=booleanExpression)?) - | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause? | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) lateralView* (WHERE where=booleanExpression)? @@ -347,6 +347,16 @@ querySpecification windows?) ; +hint + : '/*+' hintStatement '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=identifier parameters+=identifier ')' + | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')' + ; + fromClause : FROM relation (',' relation)* lateralView* ; @@ -945,8 +955,12 @@ SIMPLE_COMMENT : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) ; +BRACKETED_EMPTY_COMMENT + : '/**/' -> channel(HIDDEN) + ; + BRACKETED_COMMENT - : '/*' .*? '*/' -> channel(HIDDEN) + : '/*' ~[+] .*? '*/' -> channel(HIDDEN) ; WS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2efa997ff22d..350f87b002ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -84,7 +84,8 @@ class Analyzer( Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, - EliminateUnions), + EliminateUnions, + SubstituteHints), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -1786,6 +1787,63 @@ class Analyzer( } } + /** + * Substitute Hints. + * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters. + * + * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations` + * rule is applied. Here are two reasons. + * - To support `MetastoreRelation` in Hive module. + * - To reduce the effect of `Hint` on the other rules. + * + * After this rule, it is guaranteed that there exists no unknown `Hint` in the plan. + * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here. + */ + object SubstituteHints extends Rule[LogicalPlan] { + val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + + import scala.collection.mutable.Set + private def appendAllDescendant(set: Set[LogicalPlan], plan: LogicalPlan): Unit = { + set += plan + plan.children.foreach { child => appendAllDescendant(set, child) } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformDown { + case h @ Hint(name, parameters, child) if BROADCAST_HINT_NAMES.contains(name.toUpperCase) => + var resolvedChild = child + for (table <- parameters) { + var stop = false + val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan] + resolvedChild = resolvedChild.transformDown { + case n if skipNodeSet.contains(n) => + skipNodeSet -= n + n + case p @ Project(_, _) if p != resolvedChild => + appendAllDescendant(skipNodeSet, p) + skipNodeSet -= p + p + case r @ BroadcastHint(UnresolvedRelation(t, _)) + if !stop && resolver(t.table, table) => + stop = true + r + case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) => + stop = true + if (alias.isDefined) { + SubqueryAlias(alias.get, BroadcastHint(r.copy(alias = None))) + } else { + BroadcastHint(r) + } + } + } + resolvedChild + + // Remove unrecognized hints + case Hint(name, _, child) => child + } + } + } + /** * Check and add proper window frames for all window functions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 8b87a4e41c23..286320e6163c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -372,6 +372,10 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) + case Hint(_, _, _) => + throw new IllegalStateException( + "logical hint operator should have been removed by analyzer") + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f2cc8d362478..13aa45772f2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -377,7 +377,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } // Window - withDistinct.optionalMap(windows)(withWindows) + val withWindow = withDistinct.optionalMap(windows)(withWindows) + + // Hint + withWindow.optionalMap(ctx.hint)(withHints) } } @@ -508,6 +511,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Add a Hint to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val stmt = ctx.hintStatement + Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + } + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b31f5aa11c22..35f80d6a6f46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -354,6 +354,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } +/** + * A general hint for the child. + * A pair of (name, parameters). + */ +case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode { + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} + case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3acb261800c0..0f059b959146 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -32,6 +32,7 @@ trait AnalysisTest extends PlanTest { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) + catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala new file mode 100644 index 000000000000..64e85111c43d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class SubstituteHintsSuite extends AnalysisTest { + import org.apache.spark.sql.catalyst.analysis.TestRelations._ + + val a = testRelation.output(0) + val b = testRelation2.output(0) + + test("case-sensitive or insensitive parameters") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation)) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + testRelation) + } + + test("single hint") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").select(a)), + BroadcastHint(testRelation).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + testRelation.join(BroadcastHint(testRelation2)).select(a)) + } + + test("single hint with multiple parameters") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE", "TaBlE"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE", "TaBlE2"), + table("TaBlE").as("t").join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).join(BroadcastHint(testRelation2)).select(a)) + } + + test("duplicated nested hints are transformed into one") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").select('a)) + .join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).select(a).join(testRelation2).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + table("TaBlE").as("t").select(a) + .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), + testRelation.select(a).join(BroadcastHint(testRelation2).select(b)).select(a)) + } + + test("distinct nested two hints are handled separately") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE2"), + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE").as("t").select(a)) + .join(table("TaBlE2").as("u")).select(a)), + BroadcastHint(testRelation).select(a).join(BroadcastHint(testRelation2)).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").as("t") + .join(Hint("MAPJOIN", Seq("TaBlE2"), table("TaBlE2").as("u").select(b))).select(a)), + BroadcastHint(testRelation).join(BroadcastHint(testRelation2).select(b)).select(a)) + } + + test("deep self join") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").join(table("TaBlE")).join(table("TaBlE")).join(table("TaBlE")).select(a)), + BroadcastHint(testRelation).join(testRelation).join(testRelation).join(testRelation) + .select(a)) + } + + test("subquery should be ignored") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").select(a).as("x").join(table("TaBlE")).select(a)), + testRelation.select(a).join(BroadcastHint(testRelation)).select(a)) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), + table("TaBlE").as("t").select(a).as("x") + .join(table("TaBlE2").as("t2")).select(a)), + testRelation.select(a).join(testRelation2).select(a)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fbe236e19626..1b2f3885a034 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -447,4 +447,46 @@ class PlanParserSuite extends PlanTest { assertEqual("select a, b from db.c where x !> 1", table("db", "c").where('x <= 1).select('a, 'b)) } + + test("select hint syntax") { + // Hive compatibility: Missing parameter raises ParseException. + val m = intercept[ParseException] { + parsePlan("SELECT /*+ HINT() */ * FROM t") + }.getMessage + assert(m.contains("no viable alternative at input")) + + // Hive compatibility: No database. + val m2 = intercept[ParseException] { + parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t") + }.getMessage + assert(m2.contains("no viable alternative at input")) + + comparePlans( + parsePlan("SELECT /*+ HINT */ * FROM t"), + Hint("HINT", Seq.empty, table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + Hint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + Hint("MAPJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ INDEX(t emp_job_ix) */ * FROM t"), + Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 5d93419f357e..15656979f665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -172,6 +172,10 @@ class SQLBuilder private ( toSQL(p.right), p.condition.map(" ON " + _.sql).getOrElse("")) + // Hint on aliased table should be matched directly. Otherwise, this Hint will be propagate up. + case h @ Hint(_, _, s @ SubqueryAlias(alias, p @ Project(_, _: SQLTable))) => + build("(" + toSQL(p.copy(child = h.copy(child = p.child))) + ")", "AS", s.alias) + case SQLTable(database, table, _, sample) => val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" sample.map { case (lowerBound, upperBound) => @@ -208,6 +212,9 @@ class SQLBuilder private ( case OneRowRelation => "" + case Hint(_, _, child) => + toSQL(child) + case _ => throw new UnsupportedOperationException(s"unsupported plan $node") } @@ -220,14 +227,24 @@ class SQLBuilder private ( private def build(segments: String*): String = segments.map(_.trim).filter(_.nonEmpty).mkString(" ") - private def projectToSQL(plan: Project, isDistinct: Boolean): String = { - build( - "SELECT", - if (isDistinct) "DISTINCT" else "", - plan.projectList.map(_.sql).mkString(", "), - if (plan.child == OneRowRelation) "" else "FROM", - toSQL(plan.child) - ) + private def projectToSQL(plan: Project, isDistinct: Boolean): String = plan match { + case p @ Project(projectList, Hint("BROADCAST", tables, child)) => + build( + "SELECT", + if (tables.nonEmpty) s"/*+ MAPJOIN(${tables.mkString(", ")}) */" else "", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (child == OneRowRelation) "" else "FROM", + toSQL(child) + ) + case _ => + build( + "SELECT", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) } private def scriptTransformationToSQL(plan: ScriptTransformation): String = { @@ -425,7 +442,9 @@ class SQLBuilder private ( // Insert sub queries on top of operators that need to appear after FROM clause. AddSubquery, // Reconstruct subquery expressions. - ConstructSubqueryExpressions + ConstructSubqueryExpressions, + // Normalize BroadcastHints to reconstruct hint comments. + NormalizeBroadcastHint ) ) @@ -438,6 +457,46 @@ class SQLBuilder private ( } } + /** + * Merge and move upward to the nearest Project. + * A broadcast hint comment is scattered into multiple nodes inside the plan, and the + * information of BroadcastHint resides its current position inside the plan. In order to + * reconstruct broadcast hint comment, we need to pack the information of BroadcastHint into + * Hint("BROADCAST", _, _) and collect them up by moving upward to the nearest Project node. + */ + object NormalizeBroadcastHint extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + // Capture the broadcasted information and store it in Hint. + case BroadcastHint(child @ SubqueryAlias(_, Project(_, SQLTable(database, table, _, _)))) => + Hint("BROADCAST", Seq(table), child) + + // Nearest Project is found. + case p @ Project(_, Hint(_, _, _)) => p + + // Merge BROADCAST hints up to the nearest Project. + case Hint("BROADCAST", params1, h @ Hint("BROADCAST", params2, _)) => + h.copy(parameters = params1 ++ params2) + case j @ Join(h1 @ Hint("BROADCAST", p1, left), h2 @ Hint("BROADCAST", p2, right), _, _) => + h1.copy(parameters = p1 ++ p2, child = j.copy(left = left, right = right)) + + // Bubble up BROADCAST hints to the nearest Project. + case j @ Join(h @ Hint("BROADCAST", _, hintChild), _, _, _) => + h.copy(child = j.copy(left = hintChild)) + case j @ Join(_, h @ Hint("BROADCAST", _, hintChild), _, _) => + h.copy(child = j.copy(right = hintChild)) + + // Other UnaryNodes are bypassed. + case u: UnaryNode + if u.child.isInstanceOf[Hint] && u.child.asInstanceOf[Hint].name.equals("BROADCAST") => + val hint = u.child.asInstanceOf[Hint] + hint.copy(child = u.withNewChildren(Seq(hint.child))) + + // Other binary(CoGroup/Intersect/Except) and Union are ignored. + // - CoGroup is not used in SQL. + // - Intersect/Except/Union have Project nodes inside. + } + } + object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case SubqueryAlias(_, t @ ExtractSQLTable(_)) => t @@ -569,6 +628,8 @@ class SQLBuilder private ( case _: SQLTable => plan case _: Generate => plan case OneRowRelation => plan + case _: BroadcastHint => plan + case _: Hint => plan case _ => addSubquery(plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 97adffa8ce10..1e43fc093e99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -153,4 +153,94 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { cases.foreach(assertBroadcastJoin) } } + + test("Broadcast Hint") { + import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + + spark.range(10).createOrReplaceTempView("t") + spark.range(10).createOrReplaceTempView("u") + + for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) { + val plan1 = sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan2 = sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + + assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + } + } + + test("Broadcast Hint matches the nearest one") { + val tbl_a = spark.range(10) + val tbl_b = spark.range(20) + val tbl_c = spark.range(30) + + tbl_a.createOrReplaceTempView("tbl_a") + tbl_b.createOrReplaceTempView("tbl_b") + tbl_c.createOrReplaceTempView("tbl_c") + + val plan = sql( + """SELECT /*+ MAPJOIN(tbl_b) */ + | * + |FROM tbl_a A + | JOIN tbl_b B + | ON B.id = A.id + | JOIN (SELECT XA.id + | FROM tbl_b XA + | LEFT SEMI JOIN tbl_c XB + | ON XB.id = XA.id) C + | ON C.id = A.id + """.stripMargin).queryExecution.analyzed + + val correct_answer = + tbl_a.as("tbl_a").as("A") + .join(broadcast(tbl_b.as("tbl_b")).as("B"), $"B.id" === $"A.id", "inner") + .join(tbl_b.as("tbl_b").as("XA") + .join(tbl_c.as("tbl_c").as("XB"), $"XB.id" === $"XA.id", "leftsemi") + .select("XA.id").as("C"), $"C.id" === $"A.id", "inner") + .select(col("*")).logicalPlan + + comparePlans(plan, correct_answer) + } + + test("Nested Broadcast Hint") { + val tbl_a = spark.range(10) + val tbl_b = spark.range(20) + val tbl_c = spark.range(30) + + tbl_a.createOrReplaceTempView("tbl_a") + tbl_b.createOrReplaceTempView("tbl_b") + tbl_c.createOrReplaceTempView("tbl_c") + + val plan = sql( + """SELECT /*+ MAPJOIN(tbl_a, tbl_a) */ + | * + |FROM tbl_a A + | JOIN tbl_b B + | ON B.id = A.id + | JOIN (SELECT /*+ MAPJOIN(tbl_c) */ + | XA.id + | FROM tbl_b XA + | LEFT SEMI JOIN tbl_c XB + | ON XB.id = XA.id) C + | ON C.id = A.id + """.stripMargin).queryExecution.analyzed + + val correct_answer = + broadcast(tbl_a.as("tbl_a")).as("A") + .join(tbl_b.as("tbl_b").as("B"), $"B.id" === $"A.id", "inner") + .join(tbl_b.as("tbl_b").as("XA") + .join(broadcast(tbl_c.as("tbl_c")).as("XB"), $"XB.id" === $"XA.id", "leftsemi") + .select("XA.id").as("C"), $"C.id" === $"A.id", "inner") + .select(col("*")).logicalPlan + + comparePlans(plan, correct_answer) + } } diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql new file mode 100644 index 000000000000..dbf8ff55dae9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_generator.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM (SELECT /*+ MAPJOIN(parquet_t0) */ EXPLODE(ARRAY(1,2,3)) FROM parquet_t0) T +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col` FROM (SELECT `gen_attr_0` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_1 AS `gen_attr_0`) AS T) AS T diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql new file mode 100644 index 000000000000..211c6360926d --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupby_having_orderby.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * +FROM parquet_t0 +WHERE id > 0 +GROUP BY id +HAVING count(*) > 0 +ORDER BY id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` > CAST(0 AS BIGINT)) GROUP BY `gen_attr_0` HAVING (`gen_attr_1` > CAST(0 AS BIGINT))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql new file mode 100644 index 000000000000..9d670dd2b169 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_groupingset.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t1) */ + count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 +FROM parquet_t1 +GROUP BY key % 5, key - 5 +GROUPING SETS (key % 5, key - 5) +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `k3` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `key` AS `gen_attr_7`, `value` AS `gen_attr_8` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT))), ((`gen_attr_7` - CAST(5 AS BIGINT))))) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql new file mode 100644 index 000000000000..999d209db752 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0, parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql new file mode 100644 index 000000000000..ce855749672a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_multiple_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t1) */ * FROM parquet_t0, parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql new file mode 100644 index 000000000000..f40fcb6731b1 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_rollup.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t1) */ + count(*) as cnt, key%5, grouping_id() +FROM parquet_t1 +GROUP BY key % 5 WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `cnt`, `gen_attr_3` AS `(key % CAST(5 AS BIGINT))`, `gen_attr_4` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, grouping_id() AS `gen_attr_4` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `key` AS `gen_attr_5`, `value` AS `gen_attr_6` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_5` % CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_5` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql new file mode 100644 index 000000000000..6a4c16470d1f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql new file mode 100644 index 000000000000..8ef91e82b518 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0, parquet_t0) */ * FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql new file mode 100644 index 000000000000..9cb48ff62f7d --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_single_table_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 as a +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM ((SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql new file mode 100644 index 000000000000..8ed728e8c97e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_window.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t1) */ + x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) +FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `max(key) OVER (PARTITION BY (key % CAST(5 AS BIGINT)) ORDER BY key ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_2.`gen_attr_0`, gen_subquery_2.`gen_attr_2`, gen_subquery_2.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT /*+ MAPJOIN(parquet_t1) */ `gen_attr_0`, `gen_attr_2`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM ((SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS x INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`)) AS gen_subquery_2) AS gen_subquery_3) AS x diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql new file mode 100644 index 000000000000..a413f25b27d0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` < CAST(10 AS BIGINT))) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql new file mode 100644 index 000000000000..2671b60c60e7 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_hint_with_filter_limit.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 LIMIT 10 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT /*+ MAPJOIN(parquet_t0) */ `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (`gen_attr_0` < CAST(10 AS BIGINT)) LIMIT 10) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql index 3e2111d58a3c..408501bbe378 100644 --- a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql +++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql @@ -5,4 +5,4 @@ FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) ORDER BY subq.key1, z.value -------------------------------------------------------------------------------- -SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = "2008-04-08")) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 +SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT /*+ MAPJOIN(srcpart) */ `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN ((SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2) AS z ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = "2008-04-08")) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql b/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql new file mode 100644 index 000000000000..757f0567f47f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/multiple_broadcast_hints.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ * FROM parquet_t0, parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `key`, `gen_attr_2` AS `value` FROM (SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index d8ab864ca6fc..bb6fb1a174a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -1103,4 +1103,95 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSQL("select * from orc_t", "select_orc_table") } } + + test("broadcast hint on single table") { + checkSQL("SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0", + "broadcast_hint_single_table_1") + + checkSQL("SELECT /*+ MAPJOIN(parquet_t0, parquet_t0) */ * FROM parquet_t0", + "broadcast_hint_single_table_2") + + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 as a", + "broadcast_hint_single_table_3") + } + + test("broadcast hint on multiple tables") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0, parquet_t1", + "broadcast_hint_multiple_table_1") + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t1) */ * FROM parquet_t0, parquet_t1", + "broadcast_hint_multiple_table_2") + } + + test("multiple broadcast hints on multiple tables") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0, parquet_t1) */ * FROM parquet_t0, parquet_t1", + "multiple_broadcast_hints") + } + + test("broadcast hint with filter") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10", + "broadcast_hint_with_filter") + } + + test("broadcast hint with filter/limit") { + checkSQL( + "SELECT /*+ MAPJOIN(parquet_t0) */ * FROM parquet_t0 WHERE id < 10 LIMIT 10", + "broadcast_hint_with_filter_limit") + } + + test("broadcast hint with generator") { + checkSQL( + "SELECT * FROM (SELECT /*+ MAPJOIN(parquet_t0) */ EXPLODE(ARRAY(1,2,3)) FROM parquet_t0) T", + "broadcast_hint_generator") + } + + test("broadcast hint with groupby/having/orderby") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t0) */ * + |FROM parquet_t0 + |WHERE id > 0 + |GROUP BY id + |HAVING count(*) > 0 + |ORDER BY id + """.stripMargin, + "broadcast_hint_groupby_having_orderby") + } + + test("broadcast hint with window") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t1) */ + | x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key + """.stripMargin, + "broadcast_hint_window") + } + + test("broadcast hint with rollup") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(parquet_t1) */ + | count(*) as cnt, key%5, grouping_id() + |FROM parquet_t1 + |GROUP BY key % 5 WITH ROLLUP + """.stripMargin, + "broadcast_hint_rollup") + } + + test("broadcast hint with grouping sets") { + checkSQL( + s""" + |SELECT /*+ MAPJOIN(parquet_t1) */ + | count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 + |FROM parquet_t1 + |GROUP BY key % 5, key - 5 + |GROUPING SETS (key % 5, key - 5) + """.stripMargin, + "broadcast_hint_groupingset") + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala new file mode 100644 index 000000000000..928064a95fec --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/BroadcastHintSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class BroadcastHintSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("broadcast hint on Hive table") { + withTable("hive_t", "hive_u") { + spark.sql("CREATE TABLE hive_t(a int)") + spark.sql("CREATE TABLE hive_u(b int)") + + val hive_t = spark.table("hive_t").queryExecution.analyzed + val hive_u = spark.table("hive_u").queryExecution.analyzed + + val plan = spark.sql("SELECT /*+ MAPJOIN(hive_t) */ * FROM hive_t, hive_u") + .queryExecution.analyzed + + assert(plan.collectFirst { + case BroadcastHint(MetastoreRelation(_, "hive_t")) => true + }.isDefined) + assert(plan.collectFirst { + case Join(_, MetastoreRelation(_, "hive_u"), _, _) => true + }.isDefined) + + val plan2 = spark.sql("SELECT /*+ MAPJOIN(hive_u) */ a FROM hive_t, hive_u") + .queryExecution.analyzed + + assert(plan2.collectFirst { + case BroadcastHint(MetastoreRelation(_, "hive_u")) => true + }.isDefined) + assert(plan2.collectFirst { + case Join(MetastoreRelation(_, "hive_t"), _, _, _) => true + }.isDefined) + } + } +}