diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a84bb7653c52..9e921899cbe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -146,7 +146,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, - new ResolveHints.ResolveBroadcastHints(conf), + new ResolveHints.ResolveBroadcastHints(conf, catalog), ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index dbd4ed845e32..05c7c7c9dc0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.IdentifierWithDatabase +import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -47,20 +49,42 @@ object ResolveHints { * * This rule must happen before common table expressions. 
*/ - class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { + class ResolveBroadcastHints(conf: SQLConf, catalog: SessionCatalog) extends Rule[LogicalPlan] { private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver - private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + // Name resolution in hints follows three rules below: + // + // 1. table name matches if the hint table name only has one part + // 2. table name and database name both match if the hint table name has two parts + // 3. no match happens if the hint table name has more than two parts + // + // This means, `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN db2.t` will match both tables, and + // `SELECT /*+ BROADCAST(default.t) */ * FROM t` matches no table. + private def matchedTableIdentifier( + nameParts: Seq[String], + tableIdent: IdentifierWithDatabase): Boolean = nameParts match { + case Seq(tableName) => + resolver(tableIdent.identifier, tableName) + case Seq(dbName, tableName) if tableIdent.database.isDefined => + resolver(tableIdent.database.get, dbName) && resolver(tableIdent.identifier, tableName) + case _ => + false + } + + private def applyBroadcastHint( + plan: LogicalPlan, + toBroadcast: Set[Seq[String]]): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => + case u: UnresolvedRelation + if toBroadcast.exists(matchedTableIdentifier(_, u.tableIdentifier)) => ResolvedHint(plan, HintInfo(broadcast = true)) - case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => + case r: SubqueryAlias if toBroadcast.exists(matchedTableIdentifier(_, r.name)) => ResolvedHint(plan, HintInfo(broadcast = true)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => @@ -94,8 +118,8 @@ object 
ResolveHints { } else { // Otherwise, find within the subtree query plans that should be broadcasted. applyBroadcastHint(h.child, h.parameters.map { - case tableName: String => tableName - case tableId: UnresolvedAttribute => tableId.name + case tableName: String => UnresolvedAttribute.parseAttributeName(tableName) + case tableId: UnresolvedAttribute => tableId.nameParts case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + s"an identifier or string but was $unsupported (${unsupported.getClass}") }.toSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 932c36473724..785f35dd9c97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -168,7 +168,7 @@ package object expressions { // For example, consider an example where "db1" is the database name, "a" is the table name // and "b" is the column name and "c" is the struct field name. 
// If the name parts is db1.a.b.c, then Attribute will match - // Attribute(b, qualifier("db1,"a")) and List("c") will be the second element + // Attribute(b, qualifier("db1","a")) and List("c") will be the second element var matches: (Seq[Attribute], Seq[String]) = nameParts match { case dbPart +: tblPart +: name +: nestedFields => val key = (dbPart.toLowerCase(Locale.ROOT), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index fab1b776a3c7..a9ecb6bd9df1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -41,6 +41,8 @@ trait AnalysisTest extends PlanTest { catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 563e8adf87ed..a5855447e603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -155,4 +155,52 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), Seq(errMsgRepa)) } + + 
test("supports multi-part table names for broadcast hint resolution") { + // local temp table + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("table", "table2"), + table("table").join(table("table2"))), + Join( + ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), + Inner, + None, + JoinHint(None, None)), + caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"), + table("TaBlE").join(table("TaBlE2"))), + Join( + ResolvedHint(testRelation, HintInfo(broadcast = true)), + testRelation2, + Inner, + None, + JoinHint(None, None)), + caseSensitive = true) + + // global temp table + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("global_temp.table4", "GlOBal_TeMP.table5"), + table("global_temp", "table4").join(table("global_temp", "table5"))), + Join( + ResolvedHint(testRelation4, HintInfo(broadcast = true)), + ResolvedHint(testRelation5, HintInfo(broadcast = true)), + Inner, + None, + JoinHint(None, None)), + caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"), + table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), + Join( + ResolvedHint(testRelation4, HintInfo(broadcast = true)), + testRelation5, + Inner, + None, + JoinHint(None, None)), + caseSensitive = true) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index e12e272aedff..33b602907093 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -44,6 +44,8 @@ object TestRelations { AttributeReference("g", StringType)(), AttributeReference("h", MapType(IntegerType, IntegerType))()) + val testRelation5 = LocalRelation(AttributeReference("i", StringType)()) + val 
nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 6bd12cbf0135..b8b4f3b6a6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -191,6 +194,83 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) } + test("SPARK-25121 supports multi-part names for broadcast hint resolution") { + val (table1Name, table2Name) = ("t1", "t2") + + withTempDatabase { dbName => + withTable(table1Name, table2Name) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + spark.range(50).write.saveAsTable(s"$dbName.$table1Name") + spark.range(100).write.saveAsTable(s"$dbName.$table2Name") + + // First, makes sure a join is not broadcastable + val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " + + s"WHERE $table1Name.id = $table2Name.id") + .queryExecution.executedPlan + assert(plan.collect { case p: BroadcastHashJoinExec => p }.isEmpty) + + def checkIfHintApplied(tableName: String, hintTableName: String): Unit = { + val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) 
*/ * " + + s"FROM $tableName, $dbName.$table2Name " + + s"WHERE $tableName.id = $table2Name.id") + .queryExecution.executedPlan + val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.size == 1) + val broadcastExchanges = broadcastHashJoins.head.collect { + case p: BroadcastExchangeExec => p + } + assert(broadcastExchanges.size == 1) + val tables = broadcastExchanges.head.collect { + case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent + } + assert(tables.size == 1) + assert(tables.head === TableIdentifier(table1Name, Some(dbName))) + } + + def checkIfHintNotApplied(tableName: String, hintTableName: String): Unit = { + val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " + + s"FROM $tableName, $dbName.$table2Name " + + s"WHERE $tableName.id = $table2Name.id") + .queryExecution.executedPlan + val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.isEmpty) + } + + sql(s"USE $dbName") + checkIfHintApplied(table1Name, table1Name) + checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name") + checkIfHintApplied(s"$dbName.$table1Name", table1Name) + checkIfHintNotApplied(table1Name, s"$dbName.$table1Name") + checkIfHintNotApplied(s"$dbName.$table1Name", s"$dbName.$table1Name.id") + } + } + } + } + + test("SPARK-25121 the same table name exists in two databases for broadcast hint resolution") { + val (db1Name, db2Name) = ("db1", "db2") + + withDatabase(db1Name, db2Name) { + withTable("t") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + sql(s"CREATE DATABASE $db1Name") + sql(s"CREATE DATABASE $db2Name") + spark.range(1).write.saveAsTable(s"$db1Name.t") + spark.range(1).write.saveAsTable(s"$db2Name.t") + + // Checks if a broadcast hint applied in both sides + val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " + + s"WHERE $db1Name.t.id = $db2Name.t.id" + 
sql(statement).queryExecution.optimizedPlan match { + case Join(_, _, _, _, JoinHint(Some(leftHint), Some(rightHint))) => + assert(leftHint.broadcast && rightHint.broadcast) + case _ => fail("broadcast hint not found in both tables") + } + } + } + } + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 972b47e96fe0..bd85b4b1d2db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -157,6 +158,27 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-25121 broadcast hint on global temp view") { + withGlobalTempView("v1") { + spark.range(10).createGlobalTempView("v1") + withTempView("v2") { + spark.range(10).createTempView("v2") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", + "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" + ).foreach { statement => + sql(statement).queryExecution.optimizedPlan match { + case Join(_, _, _, _, JoinHint(Some(leftHint), None)) => 
assert(leftHint.broadcast) + case _ => fail("broadcast hint not found in a left-side table") + } + } + } + } + } + } + test("public Catalog should recognize global temp view") { withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8269d4d3a285..257cac4d05c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, SubqueryAlias} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} class SimpleSQLViewSuite extends SQLViewSuite with SharedSQLContext @@ -706,4 +709,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } } + + test("SPARK-25121 broadcast hint on temp view") { + withTable("t") { + spark.range(10).write.saveAsTable("t") + withTempView("tv") { + sql("CREATE TEMPORARY VIEW tv AS SELECT * FROM t") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // First, makes sure a join is not broadcastable + val plan1 = sql("SELECT * FROM t, tv WHERE t.id = tv.id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // `MAPJOIN(default.tv)` cannot match the temporary table `tv` + val plan2 = sql("SELECT /*+ MAPJOIN(default.tv) */ * FROM t, tv WHERE t.id = tv.id") + .queryExecution.analyzed + assert(plan2.collect { case h: ResolvedHint => h }.size == 0) + + // `MAPJOIN(tv)` can match 
the temporary table `tv` + val df = sql("SELECT /*+ MAPJOIN(tv) */ * FROM t, tv WHERE t.id = tv.id") + val logicalPlan = df.queryExecution.analyzed + val broadcastData = logicalPlan.collect { + case ResolvedHint(SubqueryAlias(name, _), _) => name + } + assert(broadcastData.size == 1) + assert(broadcastData.head.database === None) + assert(broadcastData.head.identifier === "tv") + + val sparkPlan = df.queryExecution.executedPlan + val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.size == 1) + } + } + } + } }