diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index ac242ba3d135..9426c9708a8d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicInteger import scala.reflect.ClassTag +import scala.util.Try import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.clients.producer.internals.DefaultPartitioner @@ -500,7 +501,7 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { TestUtils.assertExceptionMsg(ex, "null topic present in the data") } - protected def testUnsupportedSaveModes(msg: (SaveMode) => String): Unit = { + protected def testUnsupportedSaveModes(msg: (SaveMode) => Seq[String]): Unit = { val topic = newTopic() testUtils.createTopic(topic) val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") @@ -513,7 +514,10 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { .mode(mode) .save() } - TestUtils.assertExceptionMsg(ex, msg(mode)) + val errorChecks = msg(mode).map(m => Try(TestUtils.assertExceptionMsg(ex, m))) + if (!errorChecks.exists(_.isSuccess)) { + fail("Error messages not found in exception trace") + } } } @@ -541,7 +545,7 @@ class KafkaSinkBatchSuiteV1 extends KafkaSinkBatchSuiteBase { .set(SQLConf.USE_V1_SOURCE_LIST, "kafka") test("batch - unsupported save modes") { - testUnsupportedSaveModes((mode) => s"Save mode ${mode.name} not allowed for Kafka") + testUnsupportedSaveModes((mode) => s"Save mode ${mode.name} not allowed for Kafka" :: Nil) } } @@ -552,7 +556,8 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase { .set(SQLConf.USE_V1_SOURCE_LIST, "") test("batch - unsupported save modes") { - testUnsupportedSaveModes((mode) => s"cannot be written with ${mode.name} mode") + testUnsupportedSaveModes((mode) => + Seq(s"cannot be written with ${mode.name} mode", "does not support truncate")) } test("generic - write big data with small producer buffer") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java new file mode 100644 index 000000000000..5225b12788c4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * An interface, which TableProviders can implement, to support table existence checks and creation + * through a catalog, without having to use table identifiers. For example, when file based data + * sources use the `DataFrameWriter.save(path)` method, the option `path` can translate to a + * PathIdentifier. A catalog can then use this PathIdentifier to check the existence of a table, or + * whether a table can be created at a given directory. + */ +@Evolving +public interface SupportsCatalogOptions extends TableProvider { + /** + * Return a {@link Identifier} instance that can identify a table for a DataSource given + * DataFrame[Reader|Writer] options. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + Identifier extractIdentifier(CaseInsensitiveStringMap options); + + /** + * Return the name of a catalog that can be used to check the existence of, load, and create + * a table for this DataSource given the identifier that will be extracted by + * {@link #extractIdentifier(CaseInsensitiveStringMap) extractIdentifier}. A `null` value can + * be used to defer to the V2SessionCatalog. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + default String extractCatalog(CaseInsensitiveStringMap options) { + return CatalogManager.SESSION_CATALOG_NAME(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 2f4914dd7db3..671beb3ab150 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] object CatalogV2Util { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -315,4 +316,14 @@ private[sql] object CatalogV2Util { val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident) AlterTable(tableCatalog, ident, unresolved, changes) } + + def getTableProviderCatalog( + provider: SupportsCatalogOptions, + catalogManager: CatalogManager, + options: CaseInsensitiveStringMap): TableCatalog = { + Option(provider.extractCatalog(options)) + .map(catalogManager.catalog) + .getOrElse(catalogManager.v2SessionCatalog) + .asTableCatalog + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 8570e4640fee..30d0c851964d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, Univocit import 
org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser -import org.apache.spark.sql.connector.catalog.SupportsRead +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -215,9 +215,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) - case _ => provider.getTable(dsOptions) + val table = provider match { + case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => + throw new IllegalArgumentException( + s"$source does not support user specified schema. Please don't specify the schema.") + case hasCatalog: SupportsCatalogOptions => + val ident = hasCatalog.extractIdentifier(dsOptions) + val catalog = CatalogV2Util.getTableProviderCatalog( + hasCatalog, + sparkSession.sessionState.catalogManager, + dsOptions) + catalog.loadTable(ident) + case _ => + userSpecifiedSchema match { + case Some(schema) => provider.getTable(dsOptions, schema) + case _ => provider.getTable(dsOptions) + } } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2b124ae260ca..998ec9ebdff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsWrite, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -258,37 +258,77 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { - case table: SupportsWrite if table.supports(BATCH_WRITE) => - if (partitioningColumns.nonEmpty) { - throw new AnalysisException("Cannot write data to TableProvider implementation " + - "if partition columns are specified.") - } - lazy val relation = DataSourceV2Relation.create(table, dsOptions) - mode match { - case SaveMode.Append => - runCommand(df.sparkSession, 
"save") { - AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) + mode match { + case SaveMode.Append | SaveMode.Overwrite => + val table = provider match { + case supportsExtract: SupportsCatalogOptions => + val ident = supportsExtract.extractIdentifier(dsOptions) + val sessionState = df.sparkSession.sessionState + val catalog = CatalogV2Util.getTableProviderCatalog( + supportsExtract, sessionState.catalogManager, dsOptions) + + catalog.loadTable(ident) + case tableProvider: TableProvider => + val t = tableProvider.getTable(dsOptions) + if (t.supports(BATCH_WRITE)) { + t + } else { + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we + // fall back to saving as though it's a V1 source. + return saveToV1Source() } + } + + val relation = DataSourceV2Relation.create(table, dsOptions) + checkPartitioningMatchesV2Table(table) + if (mode == SaveMode.Append) { + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) + } + } else { + // Truncate the table. TableCapabilityCheck will throw a nice exception if this + // isn't supported + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), extraOptions.toMap) + } + } + + case createMode => + provider match { + case supportsExtract: SupportsCatalogOptions => + val ident = supportsExtract.extractIdentifier(dsOptions) + val sessionState = df.sparkSession.sessionState + val catalog = CatalogV2Util.getTableProviderCatalog( + supportsExtract, sessionState.catalogManager, dsOptions) + + val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _) - case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => - // truncate the table runCommand(df.sparkSession, "save") { - OverwriteByExpression.byName( - relation, df.logicalPlan, Literal(true), extraOptions.toMap) + CreateTableAsSelect( + catalog, + ident, + partitioningAsV2, + df.queryExecution.analyzed, + Map(TableCatalog.PROP_PROVIDER -> source) ++ location, + extraOptions.toMap, + ignoreIfExists = createMode == SaveMode.Ignore) + } + case tableProvider: TableProvider => + if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) { + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $createMode mode, please use Append or Overwrite " + + "modes instead.") + } else { + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we + // fallback to saving as though it's a V1 source. + saveToV1Source() } - - case other => - throw new AnalysisException(s"TableProvider implementation $source cannot be " + - s"written with $other mode, please use Append or Overwrite " + - "modes instead.") } - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. 
- case _ => saveToV1Source() } + } else { saveToV1Source() } @@ -504,14 +544,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable(catalog: TableCatalog, ident: Identifier): Unit = { - val partitioning = partitioningColumns.map { colNames => - colNames.map(name => IdentityTransform(FieldReference(name))) - }.getOrElse(Seq.empty[Transform]) - val bucketing = bucketColumnNames.map { cols => - Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) - }.getOrElse(Seq.empty[Transform]) - val partitionTransforms = partitioning ++ bucketing - val tableOpt = try Option(catalog.loadTable(ident)) catch { case _: NoSuchTableException => None } @@ -526,13 +558,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption)) case (SaveMode.Append, Some(table)) => + checkPartitioningMatchesV2Table(table) AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap) case (SaveMode.Overwrite, _) => ReplaceTableAsSelect( catalog, ident, - partitionTransforms, + partitioningAsV2, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -545,7 +578,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTableAsSelect( catalog, ident, - partitionTransforms, + partitioningAsV2, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -623,6 +656,29 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTable(tableDesc, mode, Some(df.logicalPlan))) } + /** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */ + private def partitioningAsV2: Seq[Transform] = { + val partitioning = partitioningColumns.map { colNames => + colNames.map(name => IdentityTransform(FieldReference(name))) + }.getOrElse(Seq.empty[Transform]) + val bucketing = + getBucketSpec.map(spec => CatalogV2Implicits.BucketSpecHelper(spec).asTransform).toSeq + partitioning ++ bucketing + } + + /** + * For V2 DataSources, checks that the provided partitioning matches that of the table. + * Partitioning information is not required when appending data to V2 tables. + */ + private def checkPartitioningMatchesV2Table(existingTable: Table): Unit = { + val v2Partitions = partitioningAsV2 + if (v2Partitions.isEmpty) return + require(v2Partitions.sameElements(existingTable.partitioning()), + "The provided partitioning does not match that of the table.\n" + + s" - provided: ${v2Partitions.mkString(", ")}\n" + + s" - table: ${existingTable.partitioning().mkString(", ")}") + } + + /** * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the * table already exists in the external database, behavior of this function depends on the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala new file mode 100644 index 000000000000..0148bb07ee96 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util + +import scala.language.implicitConversions +import scala.util.Try + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog} +import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME +import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + import testImplicits._ + + private val catalogName = "testcat" + private val format = classOf[CatalogSupportingInMemoryTableProvider].getName + + private def catalog(name: String): TableCatalog = { + spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog] + } + + private implicit def stringToIdentifier(value: String): Identifier = { + Identifier.of(Array.empty, value) + } + + before { + spark.conf.set( + V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) + spark.conf.set( + s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName) + } + + override def afterEach(): Unit = { + super.afterEach() + Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables()) + catalog(catalogName).listTables(Array.empty).foreach( + catalog(catalogName).dropTable(_)) + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + spark.conf.unset(s"spark.sql.catalog.$catalogName") + } + + private def testCreateAndRead( + saveMode: SaveMode, + withCatalogOption: Option[String], + partitionBy: Seq[String]): Unit = { + val df = spark.range(10).withColumn("part", 'id % 5) + val dfw = df.write.format(format).mode(saveMode).option("name", "t1") + withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.partitionBy(partitionBy: _*).save() + + val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") + val namespace = withCatalogOption.getOrElse("default") + assert(table.name() === s"$namespace.t1", "Table identifier was wrong") + assert(table.partitioning().length === partitionBy.length, "Partitioning did not match") + if (partitionBy.nonEmpty) { + table.partitioning.head match { + case IdentityTransform(FieldReference(field)) => + assert(field === Seq(partitionBy.head), "Partitioning column did not match") + case otherTransform => + fail(s"Unexpected partitioning ${otherTransform.describe()} received") + } + } + 
assert(table.partitioning().map(_.references().head.fieldNames().head) === partitionBy, + "Partitioning was incorrect") + assert(table.schema() === df.schema.asNullable, "Schema did not match") + + checkAnswer(load("t1", withCatalogOption), df.toDF()) + } + + test(s"save works with ErrorIfExists - no table, no partitioning, session catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, None, Nil) + } + + test(s"save works with ErrorIfExists - no table, with partitioning, session catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, None, Seq("part")) + } + + test(s"save works with Ignore - no table, no partitioning, testcat catalog") { + testCreateAndRead(SaveMode.Ignore, Some(catalogName), Nil) + } + + test(s"save works with Ignore - no table, with partitioning, testcat catalog") { + testCreateAndRead(SaveMode.Ignore, Some(catalogName), Seq("part")) + } + + test("save fails with ErrorIfExists if table exists - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val df = spark.range(10) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).option("name", "t1") + dfw.save() + } + } + + test("save fails with ErrorIfExists if table exists - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val df = spark.range(10) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName) + dfw.save() + } + } + + test("Ignore mode if table exists - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val df = spark.range(10).withColumn("part", 'id % 5) + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.save() + + val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + assert(load("t1", None).count() === 0) + } + + test("Ignore mode if table exists - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val df = spark.range(10).withColumn("part", 'id % 5) + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.option("catalog", catalogName).save() + + val table = catalog(catalogName).loadTable("t1") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + assert(load("t1", Some(catalogName)).count() === 0) + } + + test("append and overwrite modes - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val df = spark.range(10) + df.write.format(format).option("name", "t1").mode(SaveMode.Append).save() + + checkAnswer(load("t1", None), df.toDF()) + + val df2 = spark.range(10, 20) + df2.write.format(format).option("name", "t1").mode(SaveMode.Overwrite).save() + + checkAnswer(load("t1", None), df2.toDF()) + } + + test("append and overwrite modes - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val df = spark.range(10) + df.write.format(format).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Append).save() + + checkAnswer(load("t1", Some(catalogName)), df.toDF()) + + val df2 = spark.range(10, 20) + df2.write.format(format).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Overwrite).save() + + checkAnswer(load("t1", Some(catalogName)), df2.toDF()) + } + + test("fail on user 
specified schema when reading - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val e = intercept[IllegalArgumentException] { + spark.read.format(format).option("name", "t1").schema("id bigint").load() + } + assert(e.getMessage.contains("not support user specified schema")) + } + + test("fail on user specified schema when reading - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val e = intercept[IllegalArgumentException] { + spark.read.format(format).option("name", "t1").option("catalog", catalogName) + .schema("id bigint").load() + } + assert(e.getMessage.contains("not support user specified schema")) + } + + private def load(name: String, catalogOpt: Option[String]): DataFrame = { + val dfr = spark.read.format(format).option("name", name) + catalogOpt.foreach(cName => dfr.option("catalog", cName)) + dfr.load() + } +} + +class CatalogSupportingInMemoryTableProvider + extends InMemoryTableProvider + with SupportsCatalogOptions { + + override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = { + val name = options.get("name") + assert(name != null, "The name should be provided for this table") + Identifier.of(Array.empty, name) + } + + override def extractCatalog(options: CaseInsensitiveStringMap): String = { + options.get("catalog") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index d03294cb4067..3f6ac0b7f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -74,6 +74,11 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating t } + override def dropTable(ident: Identifier): Boolean = { + tables.remove(fullIdentifier(ident)) + super.dropTable(ident) + } + def clearTables(): Unit = { assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?") tables.keySet().asScala.foreach(super.dropTable)
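
To illustrate the new hook introduced by this patch, here is a minimal sketch of a connector adopting SupportsCatalogOptions. The MyTableProvider base class and the "table"/"catalog" option names are assumptions made for the example and are not part of the patch; extractIdentifier, extractCatalog, Identifier, CatalogManager.SESSION_CATALOG_NAME and CaseInsensitiveStringMap are the APIs added or used above.

// Illustrative sketch only, not part of this patch: a catalog-aware TableProvider.
// `MyTableProvider` is a hypothetical base class assumed to already implement
// TableProvider#getTable; the "table" and "catalog" option names are also assumptions.
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, SupportsCatalogOptions}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class MyCatalogAwareProvider extends MyTableProvider with SupportsCatalogOptions {

  // Translate the user-supplied "table" option into a V2 Identifier that the catalog
  // can use for existence checks, loadTable, and CreateTableAsSelect.
  override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
    val name = options.get("table")
    require(name != null, "Option 'table' must be set to identify the table")
    Identifier.of(Array("default"), name)
  }

  // Route to a named catalog when the "catalog" option is set; otherwise fall back to
  // the session catalog, mirroring the interface's default behavior.
  override def extractCatalog(options: CaseInsensitiveStringMap): String = {
    Option(options.get("catalog")).getOrElse(CatalogManager.SESSION_CATALOG_NAME)
  }
}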
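
With those hooks in place, the DataFrameReader/DataFrameWriter changes above resolve the table through the extracted catalog rather than calling TableProvider.getTable directly. A rough usage sketch, under the same assumptions as the provider above plus an active SparkSession named `spark` and a V2 catalog registered as "testcat":

// Usage sketch under the assumptions stated in the lead-in.
import org.apache.spark.sql.SaveMode

val fmt = classOf[MyCatalogAwareProvider].getName

// ErrorIfExists and Ignore now become CreateTableAsSelect against the extracted catalog,
// instead of failing with "cannot be written with ... mode" as plain TableProviders do.
spark.range(5).write.format(fmt)
  .option("table", "t1").option("catalog", "testcat")
  .mode(SaveMode.ErrorIfExists).save()

// Append and Overwrite load the existing table through the catalog first, so a
// partitioning mismatch is rejected by checkPartitioningMatchesV2Table.
spark.range(5, 10).write.format(fmt)
  .option("table", "t1").option("catalog", "testcat")
  .mode(SaveMode.Append).save()

// Reads resolve the table through the catalog as well; a user-specified schema is rejected.
val df = spark.read.format(fmt)
  .option("table", "t1").option("catalog", "testcat")
  .load()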