diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogManager.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogManager.scala
new file mode 100644
index 000000000000..eda7d3f45dc7
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogManager.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A thread-safe manager for [[CatalogPlugin]]s. It tracks all the registered catalogs, and allows
+ * the caller to look up a catalog by name.
+ */
+class CatalogManager(conf: SQLConf, val v2SessionCatalog: TableCatalog) {
+
+  /**
+   * Tracks all the registered catalogs.
+   */
+  private val catalogs = mutable.HashMap.empty[String, CatalogPlugin]
+
+  /**
+   * Looks up a catalog by name.
+   */
+  def getCatalog(name: String): CatalogPlugin = synchronized {
+    catalogs.getOrElseUpdate(name, Catalogs.load(name, conf))
+  }
+
+  /**
+   * Returns the default catalog specified by config.
+   */
+  def getDefaultCatalog(): Option[CatalogPlugin] = {
+    conf.defaultV2Catalog.map(getCatalog)
+  }
+}
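
For context, a rough sketch of how a caller can exercise the new CatalogManager. The catalog name `my_cat` and the plugin class `com.example.MyCatalogPlugin` are hypothetical; the `spark.sql.catalog.<name>` key is the one `Catalogs.load` already resolves, and `spark.sql.default.catalog` is the key read by `conf.defaultV2Catalog` in the tests below.

```scala
import org.apache.spark.sql.catalog.v2.{CatalogManager, TableCatalog}
import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
// Hypothetical plugin registration; Catalogs.load is assumed to read this key.
conf.setConfString("spark.sql.catalog.my_cat", "com.example.MyCatalogPlugin")
conf.setConfString("spark.sql.default.catalog", "my_cat")

val v2SessionCatalog: TableCatalog = ???  // e.g. the V2SessionCatalog built by the session
val manager = new CatalogManager(conf, v2SessionCatalog)

// The first call loads and caches the plugin; later calls return the cached instance.
val myCat = manager.getCatalog("my_cat")
// Resolves spark.sql.default.catalog through the same cache.
val default = manager.getDefaultCatalog()
```
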
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
index 5f7ee30cdab7..f95818721fa4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalog.v2
 
-import scala.util.control.NonFatal
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -29,10 +27,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 @Experimental
 trait LookupCatalog extends Logging {
 
-  import LookupCatalog._
-
-  protected def defaultCatalogName: Option[String] = None
-  protected def lookupCatalog(name: String): CatalogPlugin
+  protected val catalogManager: CatalogManager
 
   /**
    * Returns the default catalog. When set, this catalog is used for all identifiers that do not
@@ -42,15 +37,7 @@ trait LookupCatalog extends Logging {
    * If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will
    * be used.
    */
-  def defaultCatalog: Option[CatalogPlugin] = {
-    try {
-      defaultCatalogName.map(lookupCatalog)
-    } catch {
-      case NonFatal(e) =>
-        logError(s"Cannot load default v2 catalog: ${defaultCatalogName.get}", e)
-        None
-    }
-  }
+  def defaultCatalog: Option[CatalogPlugin] = catalogManager.getDefaultCatalog()
 
   /**
   * This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the
   * This happens when the source implementation extends the v2 TableProvider API and is not listed
   * in the fallback configuration, spark.sql.sources.write.useV1SourceList
   */
-  def sessionCatalog: Option[CatalogPlugin] = {
-    try {
-      Some(lookupCatalog(SESSION_CATALOG_NAME))
-    } catch {
-      case NonFatal(e) =>
-        logError("Cannot load v2 session catalog", e)
-        None
-    }
-  }
+  def sessionCatalog: TableCatalog = catalogManager.v2SessionCatalog
 
   /**
    * Extract catalog plugin and remaining identifier names.
@@ -79,7 +58,7 @@
       Some((None, parts))
     case Seq(catalogName, tail @ _*) =>
       try {
-        Some((Some(lookupCatalog(catalogName)), tail))
+        Some((Some(catalogManager.getCatalog(catalogName)), tail))
       } catch {
         case _: CatalogNotFoundException =>
           Some((None, parts))
@@ -137,7 +116,3 @@ trait LookupCatalog extends Logging {
     }
   }
 }
-
-object LookupCatalog {
-  val SESSION_CATALOG_NAME: String = "session"
-}
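
A minimal sketch of what mixing in the reworked trait now looks like for a client component (the class name and its helper are illustrative only; the only wiring the trait requires is a CatalogManager, replacing the old `lookupCatalog`/`defaultCatalogName` overrides):

```scala
import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogPlugin, LookupCatalog}

// Hypothetical resolution helper; catalogManager is injected via the constructor.
class MyResolutionHelper(override protected val catalogManager: CatalogManager)
  extends LookupCatalog {

  // Picks the catalog an unqualified identifier should go to: the configured
  // default catalog if one is set, otherwise the v2 session catalog.
  def catalogForUnqualifiedName: CatalogPlugin = defaultCatalog.getOrElse(sessionCatalog)
}
```
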
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1d0dba262c10..d61ee6db274c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import java.util
 import java.util.Locale
 
 import scala.collection.mutable
@@ -24,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
+import org.apache.spark.sql.catalog.v2.{CatalogManager, Identifier, LookupCatalog, TableCatalog, TableChange}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -39,7 +41,9 @@ import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.Table
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
@@ -55,6 +59,27 @@ object SimpleAnalyzer extends Analyzer(
   },
   new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true))
 
+// For test only
+class NoopV2SessionCatalog extends TableCatalog {
+  override def listTables(namespace: Array[String]): Array[Identifier] = Array.empty
+  override def loadTable(ident: Identifier): Table = {
+    throw new UnsupportedOperationException
+  }
+  override def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    throw new UnsupportedOperationException
+  }
+  override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+    throw new UnsupportedOperationException
+  }
+  override def dropTable(ident: Identifier): Boolean = false
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {}
+  override def name(): String = "noop"
+}
+
 /**
  * Provides a way to keep state during the analysis, this enables us to decouple the concerns
  * of analysis environment from the catalog.
@@ -96,18 +121,15 @@ object AnalysisContext {
  */
 class Analyzer(
     catalog: SessionCatalog,
-    conf: SQLConf,
-    maxIterations: Int)
+    val catalogManager: CatalogManager,
+    conf: SQLConf)
   extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
 
   def this(catalog: SessionCatalog, conf: SQLConf) = {
-    this(catalog, conf, conf.optimizerMaxIterations)
+    this(catalog, new CatalogManager(conf, new NoopV2SessionCatalog), conf)
   }
 
-  override protected def defaultCatalogName: Option[String] = conf.defaultV2Catalog
-
-  override protected def lookupCatalog(name: String): CatalogPlugin =
-    throw new CatalogNotFoundException("No catalog lookup function")
+  private val maxIterations: Int = conf.optimizerMaxIterations
 
   def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
     AnalysisHelper.markInAnalyzer {
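
As a quick illustration of the new constructor shape (a sketch only; `mySessionCatalog` and `myV2Catalog` stand in for whatever v1 SessionCatalog and v2 TableCatalog the caller already has):

```scala
import org.apache.spark.sql.catalog.v2.{CatalogManager, TableCatalog}
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
val mySessionCatalog: SessionCatalog = ???  // existing v1 session catalog
val myV2Catalog: TableCatalog = ???         // v2 catalog that fronts it

// Production-style wiring: the CatalogManager is passed in explicitly.
val analyzer = new Analyzer(mySessionCatalog, new CatalogManager(conf, myV2Catalog), conf)

// Test-style wiring: the two-argument constructor falls back to a
// CatalogManager backed by the no-op v2 session catalog.
val testAnalyzer = new Analyzer(mySessionCatalog, conf)
```
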
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e2636d27e353..af67632706df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1833,11 +1833,6 @@
     .stringConf
     .createOptional
 
-  val V2_SESSION_CATALOG = buildConf("spark.sql.catalog.session")
-    .doc("Name of the default v2 catalog, used when a catalog is not identified in queries")
-    .stringConf
-    .createWithDefault("org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog")
-
   val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast")
     .doc("When true, the upcast will be loose and allows string to atomic types.")
     .booleanConf
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
index 52543d16d481..495b962c8fea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
@@ -16,11 +16,14 @@
  */
 package org.apache.spark.sql.catalyst.catalog.v2
 
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
 import org.scalatest.Inside
 import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
+import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -35,8 +38,15 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
 
   private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
 
-  override def lookupCatalog(name: String): CatalogPlugin =
-    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+  override val catalogManager: CatalogManager = {
+    val manager = mock(classOf[CatalogManager])
+    when(manager.getCatalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+      val name = invocation.getArgument[String](0)
+      catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+    })
+    when(manager.getDefaultCatalog()).thenReturn(None)
+    manager
+  }
 
   test("catalog object identifier") {
     Seq(
@@ -120,10 +130,15 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit
 
   private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
 
-  override def defaultCatalogName: Option[String] = Some("prod")
-
-  override def lookupCatalog(name: String): CatalogPlugin =
-    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+  override val catalogManager: CatalogManager = {
+    val manager = mock(classOf[CatalogManager])
+    when(manager.getCatalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+      val name = invocation.getArgument[String](0)
+      catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+    })
+    when(manager.getDefaultCatalog()).thenReturn(catalogs.get("prod"))
+    manager
+  }
 
   test("catalog object identifier") {
     Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 90d1b9205787..e0d0062e976c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -607,12 +607,6 @@ class SparkSession private(
    */
   @transient lazy val catalog: Catalog = new CatalogImpl(self)
 
-  @transient private lazy val catalogs = new mutable.HashMap[String, CatalogPlugin]()
-
-  private[sql] def catalog(name: String): CatalogPlugin = synchronized {
-    catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf))
-  }
-
   /**
    * Returns the specified table/view as a `DataFrame`.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
index 1b7bb169b36f..ffbe445e17ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import scala.collection.mutable
 
 import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
+import org.apache.spark.sql.catalog.v2.{Identifier, LookupCatalog, TableCatalog}
 import org.apache.spark.sql.catalog.v2.expressions.Transform
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.CastSupport
@@ -44,9 +44,6 @@ case class DataSourceResolution(
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
   import lookup._
 
-  lazy val v2SessionCatalog: CatalogPlugin = lookup.sessionCatalog
-    .getOrElse(throw new AnalysisException("No v2 session catalog implementation is available"))
-
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case CreateTableStatement(
         AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties,
@@ -68,7 +65,7 @@ case class DataSourceResolution(
       case _ =>
         // the identifier had no catalog and no default catalog is set, but the source is v2.
         // use the v2 session catalog, which delegates to the global v1 session catalog
-        convertCreateTable(v2SessionCatalog.asTableCatalog, identifier, create)
+        convertCreateTable(lookup.sessionCatalog, identifier, create)
     }
 
     case CreateTableAsSelectStatement(
@@ -91,7 +88,7 @@ case class DataSourceResolution(
      case _ =>
        // the identifier had no catalog and no default catalog is set, but the source is v2.
        // use the v2 session catalog, which delegates to the global v1 session catalog
-        convertCTAS(v2SessionCatalog.asTableCatalog, identifier, create)
+        convertCTAS(lookup.sessionCatalog, identifier, create)
     }
 
     case DropTableStatement(CatalogObjectIdentifier(Some(catalog), ident), ifExists, _) =>
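
For intuition, the routing the resolver applies after this change, sketched as SQL submitted through a session. The catalog name `cat1` and the `v2Provider` class are hypothetical; only the `spark.sql.catalog.<name>` and `spark.sql.default.catalog` keys come from this patch.

```scala
import org.apache.spark.sql.SparkSession

val spark: SparkSession = SparkSession.builder().getOrCreate()
val v2Provider = "com.example.SomeV2TableProvider"  // assumed to implement the v2 TableProvider API

// Explicit catalog prefix: resolved against the "cat1" plugin.
spark.sql("CREATE TABLE cat1.db.t (id BIGINT) USING foo")

// No prefix and spark.sql.default.catalog is set: the default catalog is used.
spark.conf.set("spark.sql.default.catalog", "cat1")
spark.sql("CREATE TABLE db.t (id BIGINT) USING foo")

// No prefix, no default catalog, but a v2 provider: falls through to the v2
// session catalog, which delegates to the built-in v1 session catalog.
spark.conf.unset("spark.sql.default.catalog")
spark.sql(s"CREATE TABLE t2 (id BIGINT) USING $v2Provider")
```
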
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index 4cd0346b57e7..7ea28a1b1ea4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -18,12 +18,10 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import java.util
-import java.util.Locale
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TableChange}
 import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform}
 import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util
@@ -31,7 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog}
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.internal.SessionState
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.{Table, TableCapability}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -39,13 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 /**
  * A [[TableCatalog]] that translates calls to the v1 SessionCatalog.
  */
-class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
-  def this() = {
-    this(SparkSession.active.sessionState)
-  }
-
-  private lazy val catalog: SessionCatalog = sessionState.catalog
-
+class V2SessionCatalog(catalog: SessionCatalog, conf: SQLConf) extends TableCatalog {
   private var _name: String = _
 
   override def name: String = _name
@@ -85,7 +77,7 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
       properties: util.Map[String, String]): Table = {
     val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(partitions)
 
-    val provider = properties.getOrDefault("provider", sessionState.conf.defaultDataSourceName)
+    val provider = properties.getOrDefault("provider", conf.defaultDataSourceName)
     val tableProperties = properties.asScala
     val location = Option(properties.get("location"))
     val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap)
@@ -100,7 +92,7 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
       partitionColumnNames = partitionColumns,
       bucketSpec = maybeBucketSpec,
       properties = tableProperties.toMap,
-      tracksPartitionsInCatalog = sessionState.conf.manageFilesourcePartitions,
+      tracksPartitionsInCatalog = conf.manageFilesourcePartitions,
       comment = Option(properties.get("comment")))
 
     try {
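
A brief sketch of using the re-plumbed constructor, mirroring the wiring added to BaseSessionStateBuilder and the test suite below (`sessionState` stands for any SessionState already in hand; the catalog name passed to `initialize` is illustrative):

```scala
import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// The catalog no longer reaches for SparkSession.active; its two collaborators
// are passed in explicitly, which also makes it trivial to construct in tests.
val v2SessionCatalog = new V2SessionCatalog(sessionState.catalog, sessionState.conf)
v2SessionCatalog.initialize("test", CaseInsensitiveStringMap.empty())
```
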
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index b05a5dfea3ff..46c1a6152de6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{Experimental, Unstable}
 import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
-import org.apache.spark.sql.catalog.v2.CatalogPlugin
+import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogPlugin}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck}
+import org.apache.spark.sql.execution.datasources.v2.{V2SessionCatalog, V2StreamingScanSupportCheck, V2WriteSupportCheck}
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 
@@ -152,6 +152,8 @@ abstract class BaseSessionStateBuilder(
     catalog
   }
 
+  protected lazy val catalogManager = new CatalogManager(conf, new V2SessionCatalog(catalog, conf))
+
   /**
    * Interface exposed to the user for registering user-defined functions.
    *
@@ -165,7 +167,7 @@ abstract class BaseSessionStateBuilder(
    *
    * Note: this depends on the `conf` and `catalog` fields.
    */
-  protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
+  protected def analyzer: Analyzer = new Analyzer(catalog, catalogManager, conf) {
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
@@ -186,8 +188,6 @@ abstract class BaseSessionStateBuilder(
       V2WriteSupportCheck +:
       V2StreamingScanSupportCheck +:
       customCheckRules
-
-    override protected def lookupCatalog(name: String): CatalogPlugin = session.catalog(name)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 7df0dabd67f8..58890672cfeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.command
 import java.net.URI
 import java.util.Locale
 
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+
 import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog, TableCatalog, TestTableCatalog}
+import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogNotFoundException, Identifier, LookupCatalog, TableCatalog, TestTableCatalog}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -50,28 +54,36 @@ class PlanResolutionSuite extends AnalysisTest {
   }
 
   private val lookupWithDefault: LookupCatalog = new LookupCatalog {
-    override protected def defaultCatalogName: Option[String] = Some("testcat")
-
-    override protected def lookupCatalog(name: String): CatalogPlugin = name match {
-      case "testcat" =>
-        testCat
-      case "session" =>
-        v2SessionCatalog
-      case _ =>
-        throw new CatalogNotFoundException(s"No such catalog: $name")
+    override protected val catalogManager = {
+      val manager = mock(classOf[CatalogManager])
+      when(manager.getCatalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+        invocation.getArgument[String](0) match {
+          case "testcat" =>
+            testCat
+          case name =>
+            throw new CatalogNotFoundException(s"No such catalog: $name")
+        }
+      })
+      when(manager.getDefaultCatalog()).thenReturn(Some(testCat))
+      when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
+      manager
     }
   }
 
   private val lookupWithoutDefault: LookupCatalog = new LookupCatalog {
-    override protected def defaultCatalogName: Option[String] = None
-
-    override protected def lookupCatalog(name: String): CatalogPlugin = name match {
-      case "testcat" =>
-        testCat
-      case "session" =>
-        v2SessionCatalog
-      case _ =>
-        throw new CatalogNotFoundException(s"No such catalog: $name")
+    override protected val catalogManager = {
+      val manager = mock(classOf[CatalogManager])
+      when(manager.getCatalog(any())).thenAnswer((invocation: InvocationOnMock) => {
+        invocation.getArgument[String](0) match {
+          case "testcat" =>
+            testCat
+          case name =>
+            throw new CatalogNotFoundException(s"No such catalog: $name")
+        }
+      })
+      when(manager.getDefaultCatalog()).thenReturn(None)
+      when(manager.v2SessionCatalog).thenReturn(v2SessionCatalog)
+      manager
    }
  }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
index 3822882cc91c..844703f0fc4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
@@ -62,23 +62,13 @@ class V2SessionCatalogSuite
   }
 
   private def newCatalog(): TableCatalog = {
-    val newCatalog = new V2SessionCatalog(spark.sessionState)
+    val newCatalog = new V2SessionCatalog(spark.sessionState.catalog, spark.sessionState.conf)
     newCatalog.initialize("test", CaseInsensitiveStringMap.empty())
     newCatalog
   }
 
   private val testIdent = Identifier.of(Array("db"), "test_table")
 
-  test("Catalogs can load the catalog") {
-    val catalog = newCatalog()
-
-    val conf = new SQLConf
-    conf.setConfString("spark.sql.catalog.test", catalog.getClass.getName)
-
-    val loaded = Catalogs.load("test", conf)
-    assert(loaded.getClass == catalog.getClass)
-  }
-
   test("listTables") {
     val catalog = newCatalog()
     val ident1 = Identifier.of(Array("ns"), "test_table_1")
spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING $orc2") - val testCatalog = spark.catalog("session").asTableCatalog - val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + val v2SessionCatalog = getV2SessionCatalog() + val table = v2SessionCatalog.loadTable(Identifier.of(Array(), "table_name")) - assert(table.name == "session.table_name") + assert(table.name == "default.table_name") assert(table.partitioning.isEmpty) assert(table.properties == Map("provider" -> orc2).asJava) assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) - val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) - checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + checkAnswer(spark.table("table_name"), Nil) } test("CreateTable: fail if table exists") { spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") - val testCatalog = spark.catalog("testcat").asTableCatalog - + val testCatalog = getTestCatalog() val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) assert(table.name == "testcat.table_name") assert(table.partitioning.isEmpty) @@ -115,7 +122,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn spark.sql( "CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING foo") - val testCatalog = spark.catalog("testcat").asTableCatalog + val testCatalog = getTestCatalog() val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) assert(table.name == "testcat.table_name") @@ -137,13 +144,12 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq.empty) } - test("CreateTable: use default catalog for v2 sources when default catalog is set") { - val sparkSession = spark.newSession() - sparkSession.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) - sparkSession.conf.set("spark.sql.default.catalog", "testcat") - sparkSession.sql(s"CREATE TABLE table_name (id bigint, data string) USING foo") + test("CreateTable: use default catalog when default catalog is set") { + spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + spark.conf.set("spark.sql.default.catalog", "testcat") + spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING foo") - val testCatalog = sparkSession.catalog("testcat").asTableCatalog + val testCatalog = getTestCatalog() val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) assert(table.name == "testcat.table_name") @@ -152,14 +158,14 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) // check that the table is empty - val rdd = sparkSession.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) } test("CreateTableAsSelect: use v2 plan because catalog is set") { spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") - val testCatalog = spark.catalog("testcat").asTableCatalog + val testCatalog = getTestCatalog() val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) assert(table.name == "testcat.table_name") @@ -174,26 +180,17 @@ class DataSourceV2SQLSuite extends QueryTest 
@@ -174,26 +180,17 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
   }
 
   test("CreateTableAsSelect: use v2 plan and session catalog when provider is v2") {
-    spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source")
-
-    val testCatalog = spark.catalog("session").asTableCatalog
-    val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
-
-    assert(table.name == "session.table_name")
-    assert(table.partitioning.isEmpty)
-    assert(table.properties == Map("provider" -> orc2).asJava)
-    assert(table.schema == new StructType()
-      .add("id", LongType, nullable = false)
-      .add("data", StringType))
-
-    val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
-    checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+    // TODO: support write table from v2 session catalog.
+    val e = intercept[Exception] {
+      spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source")
+    }
+    assert(e.getMessage.contains("Table implementation does not support writes: table_name"))
   }
 
   test("CreateTableAsSelect: fail if table exists") {
     spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source")
 
-    val testCatalog = spark.catalog("testcat").asTableCatalog
+    val testCatalog = getTestCatalog()
     val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
 
     assert(table.name == "testcat.table_name")
@@ -231,7 +228,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
     spark.sql(
       "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source")
 
-    val testCatalog = spark.catalog("testcat").asTableCatalog
+    val testCatalog = getTestCatalog()
     val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
 
     assert(table.name == "testcat.table_name")
@@ -253,18 +250,17 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
   }
 
   test("CreateTableAsSelect: use default catalog for v2 sources when default catalog is set") {
-    val sparkSession = spark.newSession()
-    sparkSession.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
-    sparkSession.conf.set("spark.sql.default.catalog", "testcat")
+    spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
+    spark.conf.set("spark.sql.default.catalog", "testcat")
 
-    val df = sparkSession.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
     df.createOrReplaceTempView("source")
 
     // setting the default catalog breaks the reference to source because the default catalog is
     // used and AsTableIdentifier no longer matches
-    sparkSession.sql(s"CREATE TABLE table_name USING foo AS SELECT id, data FROM source")
+    spark.sql(s"CREATE TABLE table_name USING foo AS SELECT id, data FROM source")
 
-    val testCatalog = sparkSession.catalog("testcat").asTableCatalog
+    val testCatalog = getTestCatalog()
     val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
 
     assert(table.name == "testcat.table_name")
@@ -274,30 +270,25 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
       .add("id", LongType, nullable = false)
       .add("data", StringType))
 
-    val rdd = sparkSession.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
-    checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), sparkSession.table("source"))
+    val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+    checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
   }
 
   test("CreateTableAsSelect: v2 session catalog can load v1 source table") {
-    val sparkSession = spark.newSession()
-    sparkSession.conf.set("spark.sql.catalog.session", classOf[V2SessionCatalog].getName)
-
-    val df = sparkSession.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+    val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
     df.createOrReplaceTempView("source")
-
-    sparkSession.sql(s"CREATE TABLE table_name USING parquet AS SELECT id, data FROM source")
-
-    // use the catalog name to force loading with the v2 catalog
-    checkAnswer(sparkSession.sql(s"TABLE session.table_name"), sparkSession.table("source"))
+    spark.sql(s"CREATE TABLE table_name USING parquet AS SELECT id, data FROM source")
+    checkAnswer(spark.table(s"table_name"), spark.table("source"))
   }
 
   test("DropTable: basic") {
     val tableName = "testcat.ns1.ns2.tbl"
     val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+    val testCatalog = getTestCatalog()
     sql(s"CREATE TABLE $tableName USING foo AS SELECT id, data FROM source")
-    assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === true)
+    assert(testCatalog.tableExists(ident) === true)
     sql(s"DROP TABLE $tableName")
-    assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === false)
+    assert(testCatalog.tableExists(ident) === false)
   }
 
   test("DropTable: if exists") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 2fa108825982..d2caf0165d44 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -68,7 +68,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
   /**
    * A logical query plan `Analyzer` with rules specific to Hive.
    */
-  override protected def analyzer: Analyzer = new Analyzer(catalog, conf) {
+  override protected def analyzer: Analyzer = new Analyzer(catalog, catalogManager, conf) {
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new ResolveHiveSerdeTable(session) +:
       new FindDataSourceTable(session) +:
@@ -92,8 +92,6 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       V2WriteSupportCheck +:
       V2StreamingScanSupportCheck +:
       customCheckRules
-
-    override protected def lookupCatalog(name: String): CatalogPlugin = session.catalog(name)
   }
 
   /**
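
Taken together, the end-to-end flow the suites above exercise looks roughly like this (a sketch only; TestInMemoryTableCatalog is the test-only plugin used in DataSourceV2SQLSuite, and the catalog/table names are the ones from that suite):

```scala
// Register a v2 catalog plugin under the name "testcat" and make it the default.
spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
spark.conf.set("spark.sql.default.catalog", "testcat")

// Statements with an explicit prefix resolve through CatalogManager.getCatalog("testcat");
// unprefixed statements now also land there because of the default-catalog setting.
spark.sql("CREATE TABLE testcat.ns1.ns2.tbl USING foo AS SELECT id, data FROM source")
spark.sql("DROP TABLE testcat.ns1.ns2.tbl")

// The analyzer's CatalogManager is the single lookup point, replacing the
// removed SparkSession.catalog(name) helper.
val testCatalog = spark.sessionState.analyzer.catalogManager.getCatalog("testcat")
```
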