diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 46eca4a1d480..b913a9618d6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -125,9 +125,9 @@ class Analyzer(
     maxIterations: Int)
   extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
 
-  private val catalog: SessionCatalog = catalogManager.v1SessionCatalog
+  private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
 
-  override def isView(nameParts: Seq[String]): Boolean = catalog.isView(nameParts)
+  override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
 
   // Only for tests.
   def this(catalog: SessionCatalog, conf: SQLConf) = {
@@ -225,7 +225,7 @@ class Analyzer(
       ResolveAggregateFunctions ::
       TimeWindowing ::
       ResolveInlineTables(conf) ::
-      ResolveHigherOrderFunctions(catalog) ::
+      ResolveHigherOrderFunctions(v1SessionCatalog) ::
       ResolveLambdaVariables(conf) ::
       ResolveTimeZone(conf) ::
       ResolveRandomSeed ::
@@ -721,7 +721,7 @@ class Analyzer(
     // have empty defaultDatabase and all the relations in viewText have database part defined.
     def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
       case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident))
-          if catalog.isTemporaryTable(ident) =>
+          if v1SessionCatalog.isTemporaryTable(ident) =>
         resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase))
 
       case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
@@ -778,7 +778,7 @@ class Analyzer(
       val tableIdentWithDb = tableIdentifier.copy(
         database = tableIdentifier.database.orElse(defaultDatabase))
       try {
-        catalog.lookupRelation(tableIdentWithDb)
+        v1SessionCatalog.lookupRelation(tableIdentWithDb)
       } catch {
         case _: NoSuchTableException | _: NoSuchDatabaseException =>
           u
@@ -792,8 +792,9 @@ class Analyzer(
     // Note that we are testing (!db_exists || !table_exists) because the catalog throws
    // an exception from tableExists if the database does not exist.
     private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
-      table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
-        (!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
+      table.database.isDefined && conf.runSQLonFile && !v1SessionCatalog.isTemporaryTable(table) &&
+        (!v1SessionCatalog.databaseExists(table.database.get)
+          || !v1SessionCatalog.tableExists(table))
     }
   }
@@ -1511,13 +1512,14 @@ class Analyzer(
       plan.resolveExpressions {
         case f: UnresolvedFunction
           if externalFunctionNameSet.contains(normalizeFuncName(f.name)) => f
-        case f: UnresolvedFunction if catalog.isRegisteredFunction(f.name) => f
-        case f: UnresolvedFunction if catalog.isPersistentFunction(f.name) =>
+        case f: UnresolvedFunction if v1SessionCatalog.isRegisteredFunction(f.name) => f
+        case f: UnresolvedFunction if v1SessionCatalog.isPersistentFunction(f.name) =>
           externalFunctionNameSet.add(normalizeFuncName(f.name))
           f
         case f: UnresolvedFunction => withPosition(f) {
-          throw new NoSuchFunctionException(f.name.database.getOrElse(catalog.getCurrentDatabase),
+          throw new NoSuchFunctionException(
+            f.name.database.getOrElse(v1SessionCatalog.getCurrentDatabase),
            f.name.funcName)
        }
      }
@@ -1532,7 +1534,7 @@ class Analyzer(
 
       val databaseName = name.database match {
         case Some(a) => formatDatabaseName(a)
-        case None => catalog.getCurrentDatabase
+        case None => v1SessionCatalog.getCurrentDatabase
       }
 
       FunctionIdentifier(funcName, Some(databaseName))
@@ -1557,7 +1559,7 @@ class Analyzer(
           }
         case u @ UnresolvedGenerator(name, children) =>
           withPosition(u) {
-            catalog.lookupFunction(name, children) match {
+            v1SessionCatalog.lookupFunction(name, children) match {
               case generator: Generator => generator
               case other =>
                 failAnalysis(s"$name is expected to be a generator. However, " +
@@ -1566,7 +1568,7 @@ class Analyzer(
 
         case u @ UnresolvedFunction(funcId, children, isDistinct) =>
           withPosition(u) {
-            catalog.lookupFunction(funcId, children) match {
+            v1SessionCatalog.lookupFunction(funcId, children) match {
               // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within
               // the context of a Window clause. They do not need to be wrapped in an
               // AggregateExpression.
@@ -2765,17 +2767,17 @@ class Analyzer(
   private def lookupV2RelationAndCatalog(
       identifier: Seq[String]): Option[(DataSourceV2Relation, CatalogPlugin, Identifier)] =
     identifier match {
-      case AsTemporaryViewIdentifier(ti) if catalog.isTemporaryTable(ti) => None
-      case CatalogObjectIdentifier(Some(v2Catalog), ident) =>
-        CatalogV2Util.loadTable(v2Catalog, ident) match {
-          case Some(table) => Some((DataSourceV2Relation.create(table), v2Catalog, ident))
+      case AsTemporaryViewIdentifier(ti) if v1SessionCatalog.isTemporaryTable(ti) => None
+      case CatalogObjectIdentifier(catalog, ident) if !CatalogV2Util.isSessionCatalog(catalog) =>
+        CatalogV2Util.loadTable(catalog, ident) match {
+          case Some(table) => Some((DataSourceV2Relation.create(table), catalog, ident))
           case None => None
         }
-      case CatalogObjectIdentifier(None, ident) =>
-        CatalogV2Util.loadTable(catalogManager.v2SessionCatalog, ident) match {
+      case CatalogObjectIdentifier(catalog, ident) if CatalogV2Util.isSessionCatalog(catalog) =>
+        CatalogV2Util.loadTable(catalog, ident) match {
           case Some(_: V1Table) => None
           case Some(table) =>
-            Some((DataSourceV2Relation.create(table), catalogManager.v2SessionCatalog, ident))
+            Some((DataSourceV2Relation.create(table), catalog, ident))
           case None => None
         }
       case _ => None
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
index 568944678544..13a79a82a385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala
@@ -177,9 +177,8 @@ class ResolveCatalogs(val catalogManager: CatalogManager)
     case ShowTablesStatement(Some(NonSessionCatalog(catalog, nameParts)), pattern) =>
       ShowTables(catalog.asTableCatalog, nameParts, pattern)
 
-    // TODO (SPARK-29014): we should check if the current catalog is not session catalog here.
-    case ShowTablesStatement(None, pattern) if defaultCatalog.isDefined =>
-      ShowTables(defaultCatalog.get.asTableCatalog, catalogManager.currentNamespace, pattern)
+    case ShowTablesStatement(None, pattern) if !isSessionCatalog(currentCatalog) =>
+      ShowTables(currentCatalog.asTableCatalog, catalogManager.currentNamespace, pattern)
 
     case UseStatement(isNamespaceSet, nameParts) =>
       if (isNamespaceSet) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
index 14ccfd5bfcc9..c9d050768c15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogManager.scala
@@ -53,7 +53,7 @@ class CatalogManager(
     }
   }
 
-  def defaultCatalog: Option[CatalogPlugin] = {
+  private def defaultCatalog: Option[CatalogPlugin] = {
     conf.defaultV2Catalog.flatMap { catalogName =>
       try {
         Some(catalog(catalogName))
@@ -74,9 +74,16 @@ class CatalogManager(
     }
   }
 
-  // If the V2_SESSION_CATALOG_IMPLEMENTATION config is specified, we try to instantiate the
-  // user-specified v2 session catalog. Otherwise, return the default session catalog.
-  def v2SessionCatalog: CatalogPlugin = {
+  /**
+   * If the V2_SESSION_CATALOG config is specified, we try to instantiate the user-specified v2
+   * session catalog. Otherwise, return the default session catalog.
+   *
+   * This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
+   * session catalog is responsible for an identifier, but the source requires the v2 catalog API.
+   * This happens when the source implementation extends the v2 TableProvider API and is not listed
+   * in the fallback configuration, spark.sql.sources.write.useV1SourceList
+   */
+  private def v2SessionCatalog: CatalogPlugin = {
     conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).map { customV2SessionCatalog =>
       try {
         catalogs.getOrElseUpdate(SESSION_CATALOG_NAME, loadV2SessionCatalog())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
index 02585fd5c463..26ba93e57fc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/LookupCatalog.scala
@@ -27,29 +27,11 @@ private[sql] trait LookupCatalog extends Logging {
 
   protected val catalogManager: CatalogManager
 
-  /**
-   * Returns the default catalog. When set, this catalog is used for all identifiers that do not
-   * set a specific catalog. When this is None, the session catalog is responsible for the
-   * identifier.
-   *
-   * If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
-   * be used.
-   */
-  def defaultCatalog: Option[CatalogPlugin] = catalogManager.defaultCatalog
-
   /**
    * Returns the current catalog set.
    */
   def currentCatalog: CatalogPlugin = catalogManager.currentCatalog
 
-  /**
-   * This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
-   * session catalog is responsible for an identifier, but the source requires the v2 catalog API.
-   * This happens when the source implementation extends the v2 TableProvider API and is not listed
-   * in the fallback configuration, spark.sql.sources.write.useV1SourceList
-   */
-  def sessionCatalog: CatalogPlugin = catalogManager.v2SessionCatalog
-
   /**
    * Extract catalog plugin and remaining identifier names.
    *
@@ -69,16 +51,14 @@ private[sql] trait LookupCatalog extends Logging {
     }
   }
 
-  type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
-
   /**
-   * Extract catalog and identifier from a multi-part identifier with the default catalog if needed.
+   * Extract catalog and identifier from a multi-part identifier with the current catalog if needed.
    */
   object CatalogObjectIdentifier {
-    def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
+    def unapply(parts: Seq[String]): Some[(CatalogPlugin, Identifier)] = parts match {
       case CatalogAndIdentifier(maybeCatalog, nameParts) =>
         Some((
-          maybeCatalog.orElse(defaultCatalog),
+          maybeCatalog.getOrElse(currentCatalog),
           Identifier.of(nameParts.init.toArray, nameParts.last)
         ))
     }
   }
@@ -108,7 +88,7 @@ private[sql] trait LookupCatalog extends Logging {
    */
   object AsTableIdentifier {
     def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
-      case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty =>
+      case CatalogAndIdentifier(None, names) if CatalogV2Util.isSessionCatalog(currentCatalog) =>
         names match {
           case Seq(name) =>
             Some(TableIdentifier(name))
@@ -146,8 +126,7 @@ private[sql] trait LookupCatalog extends Logging {
         Some((catalogManager.catalog(nameParts.head), nameParts.tail))
       } catch {
         case _: CatalogNotFoundException =>
-          // TODO (SPARK-29014): use current catalog here.
-          Some((defaultCatalog.getOrElse(sessionCatalog), nameParts))
+          Some((currentCatalog, nameParts))
         }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
index c4a28bb6420c..513f7e0348d0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/LookupCatalogSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.FakeV2SessionCatalog
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -36,6 +37,7 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
   import CatalystSqlParser._
 
   private val catalogs = Seq("prod", "test").map(x => x -> DummyCatalogPlugin(x)).toMap
+  private val sessionCatalog = FakeV2SessionCatalog
 
   override val catalogManager: CatalogManager = {
     val manager = mock(classOf[CatalogManager])
@@ -43,22 +45,22 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
       val name = invocation.getArgument[String](0)
       catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
     })
-    when(manager.defaultCatalog).thenReturn(None)
+    when(manager.currentCatalog).thenReturn(sessionCatalog)
     manager
   }
 
   test("catalog object identifier") {
     Seq(
-      ("tbl", None, Seq.empty, "tbl"),
-      ("db.tbl", None, Seq("db"), "tbl"),
-      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
-      ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
-      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
-      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
-      ("`db.tbl`", None, Seq.empty, "db.tbl"),
-      ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
-      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
+      ("tbl", sessionCatalog, Seq.empty, "tbl"),
+      ("db.tbl", sessionCatalog, Seq("db"), "tbl"),
+      ("prod.func", catalogs("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", sessionCatalog, Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", sessionCatalog, Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", sessionCatalog, Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", sessionCatalog,
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
@@ -135,22 +137,22 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
       val name = invocation.getArgument[String](0)
       catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
     })
-    when(manager.defaultCatalog).thenReturn(catalogs.get("prod"))
+    when(manager.currentCatalog).thenReturn(catalogs("prod"))
     manager
   }
 
   test("catalog object identifier") {
     Seq(
-      ("tbl", catalogs.get("prod"), Seq.empty, "tbl"),
-      ("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
-      ("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"),
-      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
-      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
-      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
-      ("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"),
-      ("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
-      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"),
+      ("tbl", catalogs("prod"), Seq.empty, "tbl"),
+      ("db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("prod.func", catalogs("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", catalogs("prod"), Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", catalogs("prod"), Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", catalogs("prod"), Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs("prod"),
         Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
       case (sql, expectedCatalog, namespace, name) =>
         inside(parseMultipartIdentifier(sql)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3d04a0616e96..4f88cc6daa33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -341,6 +341,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   def insertInto(tableName: String): Unit = {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+    import org.apache.spark.sql.connector.catalog.CatalogV2Util._
 
     assertNotBucketed("insertInto")
 
@@ -354,14 +355,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 
     val session = df.sparkSession
     val canUseV2 = lookupV2Provider().isDefined
-    val sessionCatalog = session.sessionState.analyzer.sessionCatalog
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
-      case CatalogObjectIdentifier(Some(catalog), ident) =>
+      case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
         insertInto(catalog, ident)
 
-      case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
-        insertInto(sessionCatalog, ident)
+      case CatalogObjectIdentifier(catalog, ident)
+          if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+        insertInto(catalog, ident)
 
       case AsTableIdentifier(tableIdentifier) =>
         insertInto(tableIdentifier)
@@ -480,17 +481,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   def saveAsTable(tableName: String): Unit = {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+    import org.apache.spark.sql.connector.catalog.CatalogV2Util._
 
     val session = df.sparkSession
     val canUseV2 = lookupV2Provider().isDefined
-    val sessionCatalog = session.sessionState.analyzer.sessionCatalog
 
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
-      case CatalogObjectIdentifier(Some(catalog), ident) =>
+      case CatalogObjectIdentifier(catalog, ident) if !isSessionCatalog(catalog) =>
         saveAsTable(catalog.asTableCatalog, ident)
 
-      case CatalogObjectIdentifier(None, ident) if canUseV2 && ident.namespace().length <= 1 =>
-        saveAsTable(sessionCatalog.asTableCatalog, ident)
+      case CatalogObjectIdentifier(catalog, ident)
+          if isSessionCatalog(catalog) && canUseV2 && ident.namespace().length <= 1 =>
+        saveAsTable(catalog.asTableCatalog, ident)
 
       case AsTableIdentifier(tableIdentifier) =>
         saveAsTable(tableIdentifier)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index 17782e8ab1f0..9d3ce6fde20a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -51,9 +51,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
 
   private val (catalog, identifier) = {
-    val CatalogObjectIdentifier(maybeCatalog, identifier) = tableName
-    val catalog = maybeCatalog.getOrElse(catalogManager.currentCatalog).asTableCatalog
-    (catalog, identifier)
+    val CatalogObjectIdentifier(catalog, identifier) = tableName
+    (catalog.asTableCatalog, identifier)
   }
 
   private val logicalPlan = df.queryExecution.logical
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index a96533dac97e..d5936891476f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -262,8 +262,7 @@ class ResolveSessionCatalog(
       }
       ShowTablesCommand(Some(nameParts.head), pattern)
 
-    // TODO (SPARK-29014): we should check if the current catalog is session catalog here.
-    case ShowTablesStatement(None, pattern) if defaultCatalog.isEmpty =>
+    case ShowTablesStatement(None, pattern) if isSessionCatalog(currentCatalog) =>
       ShowTablesCommand(None, pattern)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index e27575cecde2..08627e681f9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -84,7 +84,7 @@ class DataSourceV2DataFrameSessionCatalogSuite
     val t1 = "prop_table"
     withTable(t1) {
       spark.range(20).write.format(v2Format).option("path", "abc").saveAsTable(t1)
-      val cat = spark.sessionState.catalogManager.v2SessionCatalog.asInstanceOf[TableCatalog]
+      val cat = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
       val tableInfo = cat.loadTable(Identifier.of(Array.empty, t1))
       assert(tableInfo.properties().get("location") === "abc")
       assert(tableInfo.properties().get("provider") === v2Format)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
index e1a5dbe3351e..27725bcadbcd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSessionCatalogSuite.scala
@@ -44,7 +44,7 @@ class DataSourceV2SQLSessionCatalogSuite
   }
 
   override def getTableMetadata(tableName: String): Table = {
-    val v2Catalog = spark.sessionState.catalogManager.v2SessionCatalog
+    val v2Catalog = spark.sessionState.catalogManager.currentCatalog
     val nameParts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName)
     v2Catalog.asInstanceOf[TableCatalog]
       .loadTable(Identifier.of(Array.empty, nameParts.last))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 3b42c2374f00..5e7e81b88970 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -727,6 +727,23 @@ class DataSourceV2SQLSuite
       expectV2Catalog = false)
   }
 
+  test("ShowTables: change current catalog and namespace with USE statements") {
+    sql("CREATE TABLE testcat.ns1.ns2.table (id bigint) USING foo")
+
+    // Initially, the v2 session catalog (current catalog) is used.
+    runShowTablesSql(
+      "SHOW TABLES", Seq(Row("", "source", true), Row("", "source2", true)),
+      expectV2Catalog = false)
+
+    // Update the current catalog, and no table is matched since the current namespace is Array().
+    sql("USE testcat")
+    runShowTablesSql("SHOW TABLES", Seq())
+
+    // Update the current namespace to match ns1.ns2.table.
+ sql("USE testcat.ns1.ns2") + runShowTablesSql("SHOW TABLES", Seq(Row("ns1.ns2", "table"))) + } + private def runShowTablesSql( sqlText: String, expected: Seq[Row], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 674efa9b8ba4..562e61390a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -96,8 +96,7 @@ class PlanResolutionSuite extends AnalysisTest { throw new CatalogNotFoundException(s"No such catalog: $name") } }) - when(manager.defaultCatalog).thenReturn(Some(testCat)) - when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog) + when(manager.currentCatalog).thenReturn(testCat) when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) manager } @@ -112,8 +111,7 @@ class PlanResolutionSuite extends AnalysisTest { throw new CatalogNotFoundException(s"No such catalog: $name") } }) - when(manager.defaultCatalog).thenReturn(None) - when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog) + when(manager.currentCatalog).thenReturn(v2SessionCatalog) when(manager.v1SessionCatalog).thenReturn(v1SessionCatalog) manager }