From 24de79929737bcedcfdae8f81173f807a27fbdf2 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 23 Aug 2018 16:20:51 +0900 Subject: [PATCH 01/11] Fix --- .../sql/catalyst/analysis/ResolveHints.scala | 21 ++++++++--- .../sql/catalyst/expressions/package.scala | 2 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 37 +++++++++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index dbd4ed845e32..2208b6016b86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.IdentifierWithDatabase import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -52,15 +53,25 @@ object ResolveHints { def resolver: Resolver = conf.resolver - private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + private def matchedTableIdentifier( + nameParts: Seq[String], + tableIdent: IdentifierWithDatabase): Boolean = { + val identifierList = tableIdent.database.map(_ :: Nil).getOrElse(Nil) :+ tableIdent.identifier + nameParts.corresponds(identifierList)(resolver) + } + + private def applyBroadcastHint( + plan: LogicalPlan, + toBroadcast: Set[Seq[String]]): LogicalPlan = { // Whether to continue recursing down the tree var recurse = true val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => + case u: UnresolvedRelation + if 
toBroadcast.exists(matchedTableIdentifier(_, u.tableIdentifier)) => ResolvedHint(plan, HintInfo(broadcast = true)) - case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => + case r: SubqueryAlias if toBroadcast.exists(matchedTableIdentifier(_, r.name)) => ResolvedHint(plan, HintInfo(broadcast = true)) case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => @@ -94,8 +105,8 @@ object ResolveHints { } else { // Otherwise, find within the subtree query plans that should be broadcasted. applyBroadcastHint(h.child, h.parameters.map { - case tableName: String => tableName - case tableId: UnresolvedAttribute => tableId.name + case tableName: String => UnresolvedAttribute.parseAttributeName(tableName) + case tableId: UnresolvedAttribute => tableId.nameParts case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + s"an identifier or string but was $unsupported (${unsupported.getClass}") }.toSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 932c36473724..785f35dd9c97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -168,7 +168,7 @@ package object expressions { // For example, consider an example where "db1" is the database name, "a" is the table name // and "b" is the column name and "c" is the struct field name. 
// If the name parts is db1.a.b.c, then Attribute will match - // Attribute(b, qualifier("db1,"a")) and List("c") will be the second element + // Attribute(b, qualifier("db1","a")) and List("c") will be the second element var matches: (Seq[Attribute], Seq[String]) = nameParts match { case dbPart +: tblPart +: name +: nestedFields => val key = (dbPart.toLowerCase(Locale.ROOT), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 6bd12cbf0135..39f4026932de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -191,6 +195,39 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) } + test("SPARK-25121 Supports multi-part names for broadcast hint resolution") { + val tableName = "t" + withTempDatabase { dbName => + withTable(tableName) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + spark.range(100).write.saveAsTable(s"$dbName.$tableName") + // First, makes sure a join is not broadcastable + val plan1 = spark.range(3) + .join(spark.table(s"$dbName.$tableName"), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // Uses 
multi-part table names for broadcast hints + val plan2 = spark.range(3) + .join(spark.table(s"$dbName.$tableName"), "id") + .hint("broadcast", s"$dbName.$tableName") + .queryExecution.executedPlan + val broadcastHashJoin = plan2.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoin.size == 1) + val broadcastExchange = broadcastHashJoin.head.collect { + case p: BroadcastExchangeExec => p + } + assert(broadcastExchange.size == 1) + val table = broadcastExchange.head.collect { + case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent + } + assert(table.size == 1) + assert(table.head === TableIdentifier(tableName, Some(dbName))) + } + } + } + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") From a6e4e40ad039fa3dcc522c628ace2968e62ade4c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 25 Aug 2018 23:42:24 +0900 Subject: [PATCH 02/11] Fix --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/ResolveHints.scala | 16 ++++-- .../apache/spark/sql/DataFrameJoinSuite.scala | 49 +++++++++++-------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a84bb7653c52..9e921899cbe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -146,7 +146,7 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, - new ResolveHints.ResolveBroadcastHints(conf), + new ResolveHints.ResolveBroadcastHints(conf, catalog), ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 2208b6016b86..eaacaa57d08b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.IdentifierWithDatabase +import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions.IntegerLiteral import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule @@ -48,16 +49,25 @@ object ResolveHints { * * This rule must happen before common table expressions. */ - class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { + class ResolveBroadcastHints(conf: SQLConf, catalog: SessionCatalog) extends Rule[LogicalPlan] { private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") def resolver: Resolver = conf.resolver + private def namePartsWithDatabase(nameParts: Seq[String]): Seq[String] = { + if (nameParts.size == 1) { + catalog.getCurrentDatabase +: nameParts + } else { + nameParts + } + } + private def matchedTableIdentifier( nameParts: Seq[String], tableIdent: IdentifierWithDatabase): Boolean = { - val identifierList = tableIdent.database.map(_ :: Nil).getOrElse(Nil) :+ tableIdent.identifier - nameParts.corresponds(identifierList)(resolver) + val identifierList = + tableIdent.database.getOrElse(catalog.getCurrentDatabase) :: tableIdent.identifier :: Nil + namePartsWithDatabase(nameParts).corresponds(identifierList)(resolver) } private def applyBroadcastHint( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 
39f4026932de..63b4601e9111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -196,33 +196,42 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } test("SPARK-25121 Supports multi-part names for broadcast hint resolution") { - val tableName = "t" + val (table1Name, table2Name) = ("t1", "t2") withTempDatabase { dbName => - withTable(tableName) { + withTable(table1Name, table2Name) { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { - spark.range(100).write.saveAsTable(s"$dbName.$tableName") + spark.range(50).write.saveAsTable(s"$dbName.$table1Name") + spark.range(100).write.saveAsTable(s"$dbName.$table2Name") // First, makes sure a join is not broadcastable - val plan1 = spark.range(3) - .join(spark.table(s"$dbName.$tableName"), "id") + val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " + + s"WHERE $table1Name.id = $table2Name.id") .queryExecution.executedPlan - assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + assert(plan.collect { case p: BroadcastHashJoinExec => p }.size == 0) // Uses multi-part table names for broadcast hints - val plan2 = spark.range(3) - .join(spark.table(s"$dbName.$tableName"), "id") - .hint("broadcast", s"$dbName.$tableName") - .queryExecution.executedPlan - val broadcastHashJoin = plan2.collect { case p: BroadcastHashJoinExec => p } - assert(broadcastHashJoin.size == 1) - val broadcastExchange = broadcastHashJoin.head.collect { - case p: BroadcastExchangeExec => p - } - assert(broadcastExchange.size == 1) - val table = broadcastExchange.head.collect { - case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent + def checkIfHintApplied(tableName: String, hintTableName: String): Unit = { + val p = sql(s"SELECT /*+ BROADCASTJOIN($tableName) */ * " + + s"FROM $tableName, $dbName.$table2Name " + + s"WHERE $tableName.id = $table2Name.id") + 
.queryExecution.executedPlan + val broadcastHashJoin = p.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoin.size == 1) + val broadcastExchange = broadcastHashJoin.head.collect { + case p: BroadcastExchangeExec => p + } + assert(broadcastExchange.size == 1) + val table = broadcastExchange.head.collect { + case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent + } + assert(table.size == 1) + assert(table.head === TableIdentifier(table1Name, Some(dbName))) } - assert(table.size == 1) - assert(table.head === TableIdentifier(tableName, Some(dbName))) + + sql(s"USE $dbName") + checkIfHintApplied(table1Name, table1Name) + checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name") + checkIfHintApplied(table1Name, s"$dbName.$table1Name") + checkIfHintApplied(s"$dbName.$table1Name", table1Name) } } } From f0217702ea8656f3a5572b928eda627e7774d211 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 25 Aug 2018 17:12:48 -0700 Subject: [PATCH 03/11] Add test cases --- .../catalyst/analysis/ResolveHintsSuite.scala | 9 +++++++++ .../sql/execution/GlobalTempViewSuite.scala | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 563e8adf87ed..2f392f649ac0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -155,4 +155,13 @@ class ResolveHintsSuite extends AnalysisTest { UnresolvedHint("REPARTITION", Seq(Literal(true)), table("TaBlE")), Seq(errMsgRepa)) } + + test("Supports multi-part table names for broadcast hint resolution") { + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("default.table", "default.table2"), + table("table").join(table("table2"))), + 
Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), + caseSensitive = false) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 972b47e96fe0..837d2847181a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -157,6 +157,26 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } } + test("broadcast hint on global temp view") { + import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} + + withGlobalTempView("v1") { + spark.range(10).createGlobalTempView("v1") + withTempView("v2") { + spark.range(10).createTempView("v2") + + Seq( + "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", + "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" + ).foreach { statement => + val plan = sql(statement).queryExecution.optimizedPlan + assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + } + } + } + } + test("public Catalog should recognize global temp view") { withGlobalTempView("src") { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 2") From d434ba7fde52b8578324e394478c67037e9ee1b4 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 27 Aug 2018 09:10:28 +0900 Subject: [PATCH 04/11] Fix --- .../test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 63b4601e9111..6c28ba0953f5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -210,7 +210,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // Uses multi-part table names for broadcast hints def checkIfHintApplied(tableName: String, hintTableName: String): Unit = { - val p = sql(s"SELECT /*+ BROADCASTJOIN($tableName) */ * " + + val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " + s"FROM $tableName, $dbName.$table2Name " + s"WHERE $tableName.id = $table2Name.id") .queryExecution.executedPlan From c138b81023b19a6aaa7bfb438226b9f72405f557 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 27 Aug 2018 11:05:45 +0900 Subject: [PATCH 05/11] Fix --- .../sql/execution/GlobalTempViewSuite.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 837d2847181a..315ab26f0e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -165,14 +165,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { withTempView("v2") { spark.range(10).createTempView("v2") - Seq( - "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", - "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" - ).foreach { statement => - val plan = sql(statement).queryExecution.optimizedPlan - assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) - } + val plan1 = sql("SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id") + .queryExecution.optimizedPlan + assert(plan1.collectFirst { case h: ResolvedHint => h 
}.size == 0) + + val plan2 = sql("SELECT /*+ MAPJOIN(global_temp.v1) */ * " + + "FROM global_temp.v1, v2 WHERE v1.id = v2.id") + .queryExecution.optimizedPlan + assert(plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) } } } From 6a202f2292f8614a1a9a3bcd5d7e2a1b069f7b21 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 27 Aug 2018 13:08:43 +0900 Subject: [PATCH 06/11] Fix --- .../sql/catalyst/analysis/ResolveHints.scala | 22 ++++++++++++++---- .../sql/execution/GlobalTempViewSuite.scala | 23 ++++++++----------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index eaacaa57d08b..a3e155216811 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -54,20 +54,32 @@ object ResolveHints { def resolver: Resolver = conf.resolver - private def namePartsWithDatabase(nameParts: Seq[String]): Seq[String] = { + private def namePartsWithDatabase(nameParts: Seq[String], database: String): Seq[String] = { if (nameParts.size == 1) { - catalog.getCurrentDatabase +: nameParts + database +: nameParts } else { nameParts } } + private def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) + } + private def matchedTableIdentifier( nameParts: Seq[String], tableIdent: IdentifierWithDatabase): Boolean = { - val identifierList = - tableIdent.database.getOrElse(catalog.getCurrentDatabase) :: tableIdent.identifier :: Nil - namePartsWithDatabase(nameParts).corresponds(identifierList)(resolver) + tableIdent.database match { + case Some(db) if catalog.globalTempViewManager.database == formatDatabaseName(db) => + val 
identifierList = db :: tableIdent.identifier :: Nil + namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database) + .corresponds(identifierList)(resolver) + case _ => + val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase) + val identifierList = db :: tableIdent.identifier :: Nil + namePartsWithDatabase(nameParts, catalog.getCurrentDatabase) + .corresponds(identifierList)(resolver) + } } private def applyBroadcastHint( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 315ab26f0e03..5998fb78fa2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{Join, ResolvedHint} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -157,23 +157,20 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { } } - test("broadcast hint on global temp view") { - import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} - + test("SPARK-25121 broadcast hint on global temp view") { withGlobalTempView("v1") { spark.range(10).createGlobalTempView("v1") withTempView("v2") { spark.range(10).createTempView("v2") - val plan1 = sql("SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id") - .queryExecution.optimizedPlan - assert(plan1.collectFirst { case h: ResolvedHint => h }.size == 0) - - val plan2 = sql("SELECT /*+ MAPJOIN(global_temp.v1) */ * " + - "FROM global_temp.v1, v2 WHERE v1.id = 
v2.id") - .queryExecution.optimizedPlan - assert(plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + Seq( + "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", + "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" + ).foreach { statement => + val plan = sql(statement).queryExecution.optimizedPlan + assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + } } } } From 545148b499ac1b9f552d5095edea8ef3ef92b4eb Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 27 Aug 2018 12:43:17 -0700 Subject: [PATCH 07/11] fix --- .../apache/spark/sql/catalyst/analysis/ResolveHints.scala | 6 +----- .../spark/sql/catalyst/analysis/ResolveHintsSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index a3e155216811..a1543e1fd5df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -62,15 +62,11 @@ object ResolveHints { } } - private def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) - } - private def matchedTableIdentifier( nameParts: Seq[String], tableIdent: IdentifierWithDatabase): Boolean = { tableIdent.database match { - case Some(db) if catalog.globalTempViewManager.database == formatDatabaseName(db) => + case Some(db) if resolver(catalog.globalTempViewManager.database, db) => val identifierList = db :: tableIdent.identifier :: Nil namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database) .corresponds(identifierList)(resolver) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 2f392f649ac0..839b831ae9a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -163,5 +163,11 @@ class ResolveHintsSuite extends AnalysisTest { Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("default.TaBlE", "default.table2", "DEfault.TaBlE2"), + table("TaBlE").join(table("TaBlE2"))), + Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), testRelation2, Inner, None), + caseSensitive = true) } } From bc29a11ad5cc827e6d028c8657ffbf40829d13e9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 29 Aug 2018 10:53:51 +0900 Subject: [PATCH 08/11] Fix --- .../sql/catalyst/analysis/ResolveHints.scala | 2 + .../apache/spark/sql/DataFrameJoinSuite.scala | 2 +- .../sql/execution/GlobalTempViewSuite.scala | 17 +++++---- .../spark/sql/execution/SQLViewSuite.scala | 38 +++++++++++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index a1543e1fd5df..1551a48cedf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -70,6 +70,8 @@ object ResolveHints { val identifierList = db :: tableIdent.identifier :: Nil namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database) .corresponds(identifierList)(resolver) + case None if 
catalog.getTempView(tableIdent.identifier).isDefined => + nameParts.size == 1 && resolver(nameParts.head, tableIdent.identifier) case _ => val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase) val identifierList = db :: tableIdent.identifier :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 6c28ba0953f5..8aa66dc3f9fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -199,7 +199,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val (table1Name, table2Name) = ("t1", "t2") withTempDatabase { dbName => withTable(table1Name, table2Name) { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { spark.range(50).write.saveAsTable(s"$dbName.$table1Name") spark.range(100).write.saveAsTable(s"$dbName.$table2Name") // First, makes sure a join is not broadcastable diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 5998fb78fa2f..97c0c4c432ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Join, ResolvedHint} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -163,13 +164,15 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { withTempView("v2") { 
spark.range(10).createTempView("v2") - Seq( - "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", - "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" - ).foreach { statement => - val plan = sql(statement).queryExecution.optimizedPlan - assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", + "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" + ).foreach { statement => + val plan = sql(statement).queryExecution.optimizedPlan + assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 8269d4d3a285..9d56b6c1f1ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, SubqueryAlias} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} class SimpleSQLViewSuite extends SQLViewSuite with SharedSQLContext @@ -706,4 +709,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } } + + test("SPARK-25121 broadcast hint on temp view") { + withTable("t") { + spark.range(10).write.saveAsTable("t") 
+ withTempView("tv") { + sql("CREATE TEMPORARY VIEW tv AS SELECT * FROM t") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + // First, makes sure a join is not broadcastable + val plan1 = sql("SELECT * FROM t, tv WHERE t.id = tv.id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // `MAPJOIN(default.tv)` cannot match the temporary table `tv` + val plan2 = sql("SELECT /*+ MAPJOIN(default.tv) */ * FROM t, tv WHERE t.id = tv.id") + .queryExecution.analyzed + assert(plan2.collect { case h: ResolvedHint => h }.size == 0) + + // `MAPJOIN(tv)` can match the temporary table `tv` + val df = sql("SELECT /*+ MAPJOIN(tv) */ * FROM t, tv WHERE t.id = tv.id") + val logicalPlan = df.queryExecution.analyzed + val broadcastData = logicalPlan.collect { + case ResolvedHint(SubqueryAlias(name, _), _) => name + } + assert(broadcastData.size == 1) + assert(broadcastData.head.database === None) + assert(broadcastData.head.identifier === "tv") + + val sparkPlan = df.queryExecution.executedPlan + val broadcastHashJoin = sparkPlan.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoin.size == 1) + } + } + } + } } From 59e60d43a68e0641fa5a66028bd5dcc01d5b0804 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 30 Aug 2018 09:42:19 +0900 Subject: [PATCH 09/11] Fix --- .../sql/catalyst/analysis/AnalysisTest.scala | 2 ++ .../catalyst/analysis/ResolveHintsSuite.scala | 19 +++++++++++++++++-- .../sql/catalyst/analysis/TestRelations.scala | 2 ++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index fab1b776a3c7..a9ecb6bd9df1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -41,6 +41,8 @@ trait AnalysisTest extends PlanTest { catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true) + catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 839b831ae9a3..1014c49cf8d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -157,17 +157,32 @@ class ResolveHintsSuite extends AnalysisTest { } test("Supports multi-part table names for broadcast hint resolution") { + // local temp table checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("default.table", "default.table2"), + UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), caseSensitive = false) checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("default.TaBlE", "default.table2", "DEfault.TaBlE2"), + UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"), table("TaBlE").join(table("TaBlE2"))), Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), testRelation2, Inner, None), caseSensitive = true) + + // global temp table + checkAnalysis( + UnresolvedHint("MAPJOIN", 
Seq("global_temp.table4", "GlOBal_TeMP.table5"), + table("global_temp", "table4").join(table("global_temp", "table5"))), + Join(ResolvedHint(testRelation4, HintInfo(broadcast = true)), + ResolvedHint(testRelation5, HintInfo(broadcast = true)), Inner, None), + caseSensitive = false) + + checkAnalysis( + UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"), + table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), + Join(ResolvedHint(testRelation4, HintInfo(broadcast = true)), testRelation5, Inner, None), + caseSensitive = true) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index e12e272aedff..33b602907093 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -44,6 +44,8 @@ object TestRelations { AttributeReference("g", StringType)(), AttributeReference("h", MapType(IntegerType, IntegerType))()) + val testRelation5 = LocalRelation(AttributeReference("i", StringType)()) + val nestedRelation = LocalRelation( AttributeReference("top", StructType( StructField("duplicateField", StringType) :: From 5b2b27266f8e34cb2e6e51ebd8a13a2e2c45f8f8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 4 Feb 2019 11:37:13 +0900 Subject: [PATCH 10/11] Fix --- .../catalyst/analysis/ResolveHintsSuite.scala | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 1014c49cf8d2..aa5a51cbbf75 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -161,28 +161,46 @@ class ResolveHintsSuite extends AnalysisTest { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), - ResolvedHint(testRelation2, HintInfo(broadcast = true)), Inner, None), + Join( + ResolvedHint(testRelation, HintInfo(broadcast = true)), + ResolvedHint(testRelation2, HintInfo(broadcast = true)), + Inner, + None, + JoinHint(None, None)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"), table("TaBlE").join(table("TaBlE2"))), - Join(ResolvedHint(testRelation, HintInfo(broadcast = true)), testRelation2, Inner, None), + Join( + ResolvedHint(testRelation, HintInfo(broadcast = true)), + testRelation2, + Inner, + None, + JoinHint(None, None)), caseSensitive = true) // global temp table checkAnalysis( UnresolvedHint("MAPJOIN", Seq("global_temp.table4", "GlOBal_TeMP.table5"), table("global_temp", "table4").join(table("global_temp", "table5"))), - Join(ResolvedHint(testRelation4, HintInfo(broadcast = true)), - ResolvedHint(testRelation5, HintInfo(broadcast = true)), Inner, None), + Join( + ResolvedHint(testRelation4, HintInfo(broadcast = true)), + ResolvedHint(testRelation5, HintInfo(broadcast = true)), + Inner, + None, + JoinHint(None, None)), caseSensitive = false) checkAnalysis( UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"), table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))), - Join(ResolvedHint(testRelation4, HintInfo(broadcast = true)), testRelation5, Inner, None), + Join( + ResolvedHint(testRelation4, HintInfo(broadcast = true)), + testRelation5, + Inner, + None, + JoinHint(None, None)), caseSensitive = true) } } From b6b9f656ea04d7adb7036a7e276ce997efdf446e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Wed, 13 Feb 2019 23:03:23 +0900 Subject: [PATCH 11/11] Fix --- 
.../sql/catalyst/analysis/ResolveHints.scala | 37 +++++-------
 .../catalyst/analysis/ResolveHintsSuite.scala | 2 +-
 .../apache/spark/sql/DataFrameJoinSuite.scala | 60 +++++++++++++++----
 .../sql/execution/GlobalTempViewSuite.scala | 9 +--
 .../spark/sql/execution/SQLViewSuite.scala | 4 +-
 5 files changed, 70 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 1551a48cedf1..05c7c7c9dc0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -54,30 +54,23 @@ object ResolveHints {
 
     def resolver: Resolver = conf.resolver
 
-    private def namePartsWithDatabase(nameParts: Seq[String], database: String): Seq[String] = {
-      if (nameParts.size == 1) {
-        database +: nameParts
-      } else {
-        nameParts
-      }
-    }
-
+    // Name resolution in hints follows the three rules below:
+    //
+    // 1. table name matches if the hint table name only has one part
+    // 2. table name and database name both match if the hint table name has two parts
+    // 3. no match happens if the hint table name has more than two parts
+    //
+    // This means, `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN db2.t` will match both tables, and
+    // `SELECT /*+ BROADCAST(default.t) */ * FROM t` matches no table.
private def matchedTableIdentifier( nameParts: Seq[String], - tableIdent: IdentifierWithDatabase): Boolean = { - tableIdent.database match { - case Some(db) if resolver(catalog.globalTempViewManager.database, db) => - val identifierList = db :: tableIdent.identifier :: Nil - namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database) - .corresponds(identifierList)(resolver) - case None if catalog.getTempView(tableIdent.identifier).isDefined => - nameParts.size == 1 && resolver(nameParts.head, tableIdent.identifier) - case _ => - val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase) - val identifierList = db :: tableIdent.identifier :: Nil - namePartsWithDatabase(nameParts, catalog.getCurrentDatabase) - .corresponds(identifierList)(resolver) - } + tableIdent: IdentifierWithDatabase): Boolean = nameParts match { + case Seq(tableName) => + resolver(tableIdent.identifier, tableName) + case Seq(dbName, tableName) if tableIdent.database.isDefined => + resolver(tableIdent.database.get, dbName) && resolver(tableIdent.identifier, tableName) + case _ => + false } private def applyBroadcastHint( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index aa5a51cbbf75..a5855447e603 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -156,7 +156,7 @@ class ResolveHintsSuite extends AnalysisTest { Seq(errMsgRepa)) } - test("Supports multi-part table names for broadcast hint resolution") { + test("supports multi-part table names for broadcast hint resolution") { // local temp table checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table", "table2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 8aa66dc3f9fd..b8b4f3b6a6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.execution.FileSourceScanExec -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.functions._ @@ -195,43 +194,78 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) } - test("SPARK-25121 Supports multi-part names for broadcast hint resolution") { + test("SPARK-25121 supports multi-part names for broadcast hint resolution") { val (table1Name, table2Name) = ("t1", "t2") + withTempDatabase { dbName => withTable(table1Name, table2Name) { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { spark.range(50).write.saveAsTable(s"$dbName.$table1Name") spark.range(100).write.saveAsTable(s"$dbName.$table2Name") + // First, makes sure a join is not broadcastable val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " + s"WHERE $table1Name.id = $table2Name.id") .queryExecution.executedPlan - assert(plan.collect { case p: BroadcastHashJoinExec => p }.size == 0) + assert(plan.collect { case p: BroadcastHashJoinExec => p }.isEmpty) - // Uses multi-part table names for broadcast hints def checkIfHintApplied(tableName: String, hintTableName: String): Unit = { val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) 
*/ * " + s"FROM $tableName, $dbName.$table2Name " + s"WHERE $tableName.id = $table2Name.id") .queryExecution.executedPlan - val broadcastHashJoin = p.collect { case p: BroadcastHashJoinExec => p } - assert(broadcastHashJoin.size == 1) - val broadcastExchange = broadcastHashJoin.head.collect { + val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.size == 1) + val broadcastExchanges = broadcastHashJoins.head.collect { case p: BroadcastExchangeExec => p } - assert(broadcastExchange.size == 1) - val table = broadcastExchange.head.collect { + assert(broadcastExchanges.size == 1) + val tables = broadcastExchanges.head.collect { case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent } - assert(table.size == 1) - assert(table.head === TableIdentifier(table1Name, Some(dbName))) + assert(tables.size == 1) + assert(tables.head === TableIdentifier(table1Name, Some(dbName))) + } + + def checkIfHintNotApplied(tableName: String, hintTableName: String): Unit = { + val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " + + s"FROM $tableName, $dbName.$table2Name " + + s"WHERE $tableName.id = $table2Name.id") + .queryExecution.executedPlan + val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.isEmpty) } sql(s"USE $dbName") checkIfHintApplied(table1Name, table1Name) checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name") - checkIfHintApplied(table1Name, s"$dbName.$table1Name") checkIfHintApplied(s"$dbName.$table1Name", table1Name) + checkIfHintNotApplied(table1Name, s"$dbName.$table1Name") + checkIfHintNotApplied(s"$dbName.$table1Name", s"$dbName.$table1Name.id") + } + } + } + } + + test("SPARK-25121 the same table name exists in two databases for broadcast hint resolution") { + val (db1Name, db2Name) = ("db1", "db2") + + withDatabase(db1Name, db2Name) { + withTable("t") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + 
sql(s"CREATE DATABASE $db1Name") + sql(s"CREATE DATABASE $db2Name") + spark.range(1).write.saveAsTable(s"$db1Name.t") + spark.range(1).write.saveAsTable(s"$db2Name.t") + + // Checks if a broadcast hint applied in both sides + val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " + + s"WHERE $db1Name.t.id = $db2Name.t.id" + sql(statement).queryExecution.optimizedPlan match { + case Join(_, _, _, _, JoinHint(Some(leftHint), Some(rightHint))) => + assert(leftHint.broadcast && rightHint.broadcast) + case _ => fail("broadcast hint not found in both tables") + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 97c0c4c432ed..bd85b4b1d2db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.Table import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{Join, ResolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -169,9 +169,10 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id", "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id" ).foreach { statement => - val plan = sql(statement).queryExecution.optimizedPlan - assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) - assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + sql(statement).queryExecution.optimizedPlan 
match { + case Join(_, _, _, _, JoinHint(Some(leftHint), None)) => assert(leftHint.broadcast) + case _ => fail("broadcast hint not found in a left-side table") + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 9d56b6c1f1ee..257cac4d05c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -738,8 +738,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { assert(broadcastData.head.identifier === "tv") val sparkPlan = df.queryExecution.executedPlan - val broadcastHashJoin = sparkPlan.collect { case p: BroadcastHashJoinExec => p } - assert(broadcastHashJoin.size == 1) + val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p } + assert(broadcastHashJoins.size == 1) } } }