From 6c916c2714add2926e182c506fd203e38b10c925 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 16 Dec 2019 11:02:57 -0800 Subject: [PATCH 01/12] Interface definition --- .../catalog/SupportsCatalogOptions.java | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java new file mode 100644 index 000000000000..b0b52ae34c74 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * An interface, which TableProviders can implement, to support table existence checks and creation + * through a catalog, without having to use table identifiers. For example, when file based data + * sources use the `DataFrameWriter.save(path)` method, the option `path` can translate to a + * PathIdentifier. A catalog can then use this PathIdentifier to check the existence of a table, or + * whether a table can be created at a given directory. + */ +@Evolving +public interface SupportsCatalogOptions extends TableProvider { + /** + * Return a {@link Identifier} instance that can identify a table for a DataSource given + * DataFrame[Reader|Writer] options. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + Identifier extractIdentifier(CaseInsensitiveStringMap options); + + /** + * Return the name of a catalog that can be used to check the existence of, load, and create + * a table for this DataSource given the identifier that will be extracted by + * {@see extractIdentifier}. A `null` value can be used to defer to the V2SessionCatalog. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. 
+ */ + default String extractCatalog(CaseInsensitiveStringMap options) { + return null; + } +} From ed9adc8b0cafb97fb61b20fe7f61c8a5bd70079e Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 16 Dec 2019 12:24:36 -0800 Subject: [PATCH 02/12] save implementation --- .../apache/spark/sql/DataFrameWriter.scala | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 2b124ae260ca..5aea2a032e81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, SupportsWrite, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -278,6 +278,28 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { relation, df.logicalPlan, Literal(true), extraOptions.toMap) } + case other if classOf[SupportsCatalogOptions].isAssignableFrom(provider.getClass) => + val catalogOptions = provider.asInstanceOf[SupportsCatalogOptions] + val ident = catalogOptions.extractIdentifier(dsOptions) + val sessionState = df.sparkSession.sessionState + val catalog = Option(catalogOptions.extractCatalog(dsOptions)) + .map(Catalogs.load(_, sessionState.conf)) + .getOrElse(sessionState.catalogManager.v2SessionCatalog) + .asInstanceOf[TableCatalog] + + val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _) + + runCommand(df.sparkSession, "save") { + CreateTableAsSelect( + catalog, + ident, + getV2Transforms, + df.queryExecution.analyzed, + Map(TableCatalog.PROP_PROVIDER -> source) ++ location, + extraOptions.toMap, + ignoreIfExists = other == SaveMode.Ignore) + } + case other => throw new AnalysisException(s"TableProvider implementation $source cannot be " + s"written with $other mode, please use Append or Overwrite " + @@ -504,14 +526,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def saveAsTable(catalog: TableCatalog, ident: Identifier): Unit = { - val partitioning = partitioningColumns.map { colNames => - colNames.map(name => IdentityTransform(FieldReference(name))) - }.getOrElse(Seq.empty[Transform]) - val bucketing = bucketColumnNames.map { cols => - Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) - }.getOrElse(Seq.empty[Transform]) - val partitionTransforms = partitioning ++ bucketing - val tableOpt = try Option(catalog.loadTable(ident)) catch { case _: NoSuchTableException => None } @@ -532,7 +546,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { ReplaceTableAsSelect( catalog, ident, - partitionTransforms, + 
getV2Transforms, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -545,7 +559,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTableAsSelect( catalog, ident, - partitionTransforms, + getV2Transforms, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -623,6 +637,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTable(tableDesc, mode, Some(df.logicalPlan))) } + private def getV2Transforms: Seq[Transform] = { + val partitioning = partitioningColumns.map { colNames => + colNames.map(name => IdentityTransform(FieldReference(name))) + }.getOrElse(Seq.empty[Transform]) + val bucketing = bucketColumnNames.map { cols => + Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) + }.getOrElse(Seq.empty[Transform]) + partitioning ++ bucketing + } + /** * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the * table already exists in the external database, behavior of this function depends on the From 0a87228ee44773c300b44b3b2e8c98738fd4723d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 16 Dec 2019 14:37:02 -0800 Subject: [PATCH 03/12] Added partitioning checks --- .../apache/spark/sql/DataFrameWriter.scala | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5aea2a032e81..b81d2bad44b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -260,18 +260,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { case table: SupportsWrite if table.supports(BATCH_WRITE) => - if (partitioningColumns.nonEmpty) { - throw new AnalysisException("Cannot write data to TableProvider implementation " + - "if partition columns are specified.") - } lazy val relation = DataSourceV2Relation.create(table, dsOptions) mode match { case SaveMode.Append => + verifyV2Partitioning(table) runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) } case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => + verifyV2Partitioning(table) // truncate the table 
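For context on what the new getV2Transforms helper (and the partitioning check above) actually compares, a writer configured with partitionBy and bucketBy is translated into V2 transforms along the following lines. This is only an illustrative sketch; the column names and bucket count are made up, but the constructors mirror the ones used in getV2Transforms.

import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform}
import org.apache.spark.sql.types.IntegerType

// df.write.partitionBy("part").bucketBy(4, "id") becomes, roughly:
val transforms: Seq[Transform] = Seq(
  IdentityTransform(FieldReference("part")),                                  // from partitionBy("part")
  BucketTransform(LiteralValue(4, IntegerType), Seq(FieldReference("id"))))   // from bucketBy(4, "id")

It is this Seq[Transform] that gets compared against the existing table's partitioning() when appending.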
runCommand(df.sparkSession, "save") { OverwriteByExpression.byName( @@ -540,6 +538,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption)) case (SaveMode.Append, Some(table)) => + verifyV2Partitioning(table) AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap) case (SaveMode.Overwrite, _) => @@ -637,6 +636,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTable(tableDesc, mode, Some(df.logicalPlan))) } + /** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */ private def getV2Transforms: Seq[Transform] = { val partitioning = partitioningColumns.map { colNames => colNames.map(name => IdentityTransform(FieldReference(name))) @@ -647,6 +647,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitioning ++ bucketing } + /** + * For V2 DataSources, performs if the provided partitioning matches that of the table. + * Partitioning information is not required when appending data to V2 tables. + */ + private def verifyV2Partitioning(existingTable: Table): Unit = { + val v2Partitions = getV2Transforms + if (v2Partitions.isEmpty) return + require(v2Partitions.sameElements(existingTable.partitioning()), + "The provided partitioning does not match of the table.\n" + + s" - provided: ${v2Partitions.mkString(", ")}\n" + + s" - table: ${existingTable.partitioning().mkString(", ")}") + } + /** * Saves the content of the `DataFrame` to an external database table via JDBC. In the case the * table already exists in the external database, behavior of this function depends on the From 1578f6c4cf072ba79bae74a02c1a48b16d7d92f0 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 17 Dec 2019 09:20:14 -0800 Subject: [PATCH 04/12] Update SupportsCatalogOptions.java --- .../spark/sql/connector/catalog/SupportsCatalogOptions.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java index b0b52ae34c74..2c55c7360641 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java @@ -41,7 +41,8 @@ public interface SupportsCatalogOptions extends TableProvider { /** * Return the name of a catalog that can be used to check the existence of, load, and create * a table for this DataSource given the identifier that will be extracted by - * {@see extractIdentifier}. A `null` value can be used to defer to the V2SessionCatalog. + * {@link #extractIdentifier(CaseInsensitiveStringMap) extractIdentifier}. A `null` value can + * be used to defer to the V2SessionCatalog. * * @param options the user-specified options that can identify a table, e.g. file path, Kafka * topic name, etc. It's an immutable case-insensitive string-to-string map. 
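Before the next patch wires this into the read and write paths, it may help to see what a provider implementation looks like. The sketch below mirrors the CatalogSupportingInMemoryTableProvider test class added later in this series: the class name here is illustrative, the option keys "name" and "catalog" are the ones the test provider uses, and getTable is stubbed out since only the extraction methods matter for this interface.

import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, Table, TableProvider}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class ExampleCatalogOptionsProvider extends TableProvider with SupportsCatalogOptions {

  // Derive the table identifier from the user-provided "name" option.
  override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = {
    val name = options.get("name")
    assert(name != null, "The name should be provided for this table")
    Identifier.of(Array.empty, name)
  }

  // Returning null defers to the V2 session catalog.
  override def extractCatalog(options: CaseInsensitiveStringMap): String =
    options.get("catalog")

  // The usual TableProvider contract still applies; a real provider returns its Table here.
  override def getTable(options: CaseInsensitiveStringMap): Table =
    throw new UnsupportedOperationException("sketch only")
}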
From a4416045827dece6128afb253c30d5c0862d594a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 19 Dec 2019 18:02:53 -0800 Subject: [PATCH 05/12] Added first set of tests --- .../sql/connector/catalog/CatalogV2Util.scala | 11 ++ .../apache/spark/sql/DataFrameReader.scala | 18 ++- .../apache/spark/sql/DataFrameWriter.scala | 26 ++-- .../SupportsCatalogOptionsSuite.scala | 134 ++++++++++++++++++ 4 files changed, 171 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 2f4914dd7db3..671beb3ab150 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] object CatalogV2Util { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -315,4 +316,14 @@ private[sql] object CatalogV2Util { val unresolved = UnresolvedV2Relation(originalNameParts, tableCatalog, ident) AlterTable(tableCatalog, ident, unresolved, changes) } + + def getTableProviderCatalog( + provider: SupportsCatalogOptions, + catalogManager: CatalogManager, + options: CaseInsensitiveStringMap): TableCatalog = { + Option(provider.extractCatalog(options)) + .map(catalogManager.catalog) + .getOrElse(catalogManager.v2SessionCatalog) + .asTableCatalog + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 8570e4640fee..ab3bbccb721e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, Univocit import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser -import org.apache.spark.sql.connector.catalog.SupportsRead +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -215,9 +215,19 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = userSpecifiedSchema match { - case Some(schema) => provider.getTable(dsOptions, schema) - case _ => provider.getTable(dsOptions) + val table = provider match { + case hasCatalog: SupportsCatalogOptions => + val ident = hasCatalog.extractIdentifier(dsOptions) + val catalog = CatalogV2Util.getTableProviderCatalog( + hasCatalog, + sparkSession.sessionState.catalogManager, + dsOptions) + 
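In user-facing terms, and assuming the CatalogSupportingInMemoryTableProvider test source defined later in this series (with its "name" and "catalog" option keys) plus the test catalog classes from the same suite, the read path now resolves roughly as sketched below.

// A registered catalog named "testcat", mirroring the suite's before block:
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

val format = classOf[CatalogSupportingInMemoryTableProvider].getName

// "catalog" option present: extractCatalog returns "testcat", so the identifier from
// extractIdentifier is loaded through that catalog's loadTable.
spark.read.format(format).option("name", "t1").option("catalog", "testcat").load()

// No "catalog" option: extractCatalog returns null and the V2 session catalog is used.
spark.read.format(format).option("name", "t1").load()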
catalog.loadTable(ident) + case other => + userSpecifiedSchema match { + case Some(schema) => provider.getTable(dsOptions, schema) + case _ => provider.getTable(dsOptions) + } } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b81d2bad44b3..d8fb71cf3082 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Catalogs, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -263,13 +263,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { lazy val relation = DataSourceV2Relation.create(table, dsOptions) mode match { case SaveMode.Append => - verifyV2Partitioning(table) + checkPartitioningMatchesV2Table(table) runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) } case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => - verifyV2Partitioning(table) + checkPartitioningMatchesV2Table(table) // truncate the table runCommand(df.sparkSession, "save") { OverwriteByExpression.byName( @@ -280,10 +280,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val catalogOptions = provider.asInstanceOf[SupportsCatalogOptions] val ident = catalogOptions.extractIdentifier(dsOptions) val sessionState = df.sparkSession.sessionState - val catalog = Option(catalogOptions.extractCatalog(dsOptions)) - .map(Catalogs.load(_, sessionState.conf)) - .getOrElse(sessionState.catalogManager.v2SessionCatalog) - .asInstanceOf[TableCatalog] + val catalog = CatalogV2Util.getTableProviderCatalog( + catalogOptions, sessionState.catalogManager, dsOptions) val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _) @@ -291,7 +289,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTableAsSelect( catalog, ident, - getV2Transforms, + partitioningAsV2, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ location, extraOptions.toMap, @@ -538,14 +536,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption)) case (SaveMode.Append, Some(table)) => - verifyV2Partitioning(table) + checkPartitioningMatchesV2Table(table) AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap) case (SaveMode.Overwrite, _) => ReplaceTableAsSelect( catalog, ident, - getV2Transforms, + partitioningAsV2, 
df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -558,7 +556,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { CreateTableAsSelect( catalog, ident, - getV2Transforms, + partitioningAsV2, df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ getLocationIfExists, extraOptions.toMap, @@ -637,7 +635,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } /** Converts the provided partitioning and bucketing information to DataSourceV2 Transforms. */ - private def getV2Transforms: Seq[Transform] = { + private def partitioningAsV2: Seq[Transform] = { val partitioning = partitioningColumns.map { colNames => colNames.map(name => IdentityTransform(FieldReference(name))) }.getOrElse(Seq.empty[Transform]) @@ -651,8 +649,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * For V2 DataSources, performs if the provided partitioning matches that of the table. * Partitioning information is not required when appending data to V2 tables. */ - private def verifyV2Partitioning(existingTable: Table): Unit = { - val v2Partitions = getV2Transforms + private def checkPartitioningMatchesV2Table(existingTable: Table): Unit = { + val v2Partitions = partitioningAsV2 if (v2Partitions.isEmpty) return require(v2Partitions.sameElements(existingTable.partitioning()), "The provided partitioning does not match of the table.\n" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala new file mode 100644 index 000000000000..0a77c1710761 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog} +import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME +import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + import testImplicits._ + + private val catalogName = "testcat" + + private def catalog(name: String): InMemoryTableSessionCatalog = { + spark.sessionState.catalogManager.catalog(name).asInstanceOf[InMemoryTableSessionCatalog] + } + + private implicit def stringToIdentifier(value: String): Identifier = { + Identifier.of(Array.empty, value) + } + + before { + spark.conf.set( + V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) + spark.conf.set( + s"spark.sql.catalog.$catalogName", classOf[InMemoryTableSessionCatalog].getName) + } + + override def afterEach(): Unit = { + super.afterEach() + catalog(SESSION_CATALOG_NAME).clearTables() + catalog(catalogName).clearTables() + spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) + spark.conf.unset(s"spark.sql.catalog.$catalogName") + } + + def dataFrameWriterTests(withCatalogOption: Option[String]): Unit = { + Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode => + test(s"save works with $saveMode - no table, no partitioning, session catalog") { + val df = spark.range(10) + val dfw = df.write.mode(saveMode).option("name", "t1") + withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.save() + + val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + assert(table.name() === "t1", "Table identifier was wrong") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === df.schema.asNullable, "Schema did not match") + } + + test(s"save works with $saveMode - no table, with partitioning, session catalog") { + val df = spark.range(10).withColumn("part", 'id % 5) + val dfw = df.write.mode(saveMode).option("name", "t1").partitionBy("part") + withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.save() + + val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + assert(table.name() === "t1", "Table identifier was wrong") + assert(table.partitioning().length === 1, "Partitioning should not be empty") + assert(table.partitioning().head.references().head.fieldNames().head === "part", + "Partitioning was incorrect") + assert(table.schema() === df.schema.asNullable, "Schema did not match") + } + } + + test("save fails with ErrorIfExists if table exists") { + sql("create table t1 (id bigint) using foo") + val df = spark.range(10) + intercept[TableAlreadyExistsException] { + val dfw = df.write.option("name", "t1") + withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.save() + } + } + + test("Ignore mode if table exists") { + sql("create table t1 (id bigint) using foo") + val df = spark.range(10).withColumn("part", 'id % 5) + intercept[TableAlreadyExistsException] { + val dfw = df.write.mode(SaveMode.Ignore).option("name", "t1") + 
withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.save() + } + + val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + } + } + + dataFrameWriterTests(None) + + dataFrameWriterTests(Some(catalogName)) +} + +class CatalogSupportingInMemoryTableProvider + extends InMemoryTableProvider + with SupportsCatalogOptions { + + override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier = { + val name = options.get("name") + assert(name != null, "The name should be provided for this table") + Identifier.of(Array.empty, name) + } + + override def extractCatalog(options: CaseInsensitiveStringMap): String = { + options.get("catalog") + } +} From 5c11b9427e102fda919b5f206eb4a93312f34b06 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 19 Dec 2019 18:12:42 -0800 Subject: [PATCH 06/12] Added more tests --- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../SupportsCatalogOptionsSuite.scala | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ab3bbccb721e..c67104384be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -223,7 +223,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { sparkSession.sessionState.catalogManager, dsOptions) catalog.loadTable(ident) - case other => + case _ => userSpecifiedSchema match { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 0a77c1710761..1381811d7e81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -35,6 +35,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with import testImplicits._ private val catalogName = "testcat" + private val format = classOf[CatalogSupportingInMemoryTableProvider].getName private def catalog(name: String): InMemoryTableSessionCatalog = { spark.sessionState.catalogManager.catalog(name).asInstanceOf[InMemoryTableSessionCatalog] @@ -59,32 +60,40 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with spark.conf.unset(s"spark.sql.catalog.$catalogName") } - def dataFrameWriterTests(withCatalogOption: Option[String]): Unit = { + def testWithDifferentCatalogs(withCatalogOption: Option[String]): Unit = { Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode => test(s"save works with $saveMode - no table, no partitioning, session catalog") { val df = spark.range(10) - val dfw = df.write.mode(saveMode).option("name", "t1") + val dfw = df.write.format(format).mode(saveMode).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.save() - val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") assert(table.name() === "t1", "Table identifier was wrong") 
assert(table.partitioning().isEmpty, "Partitioning should be empty") assert(table.schema() === df.schema.asNullable, "Schema did not match") + + val dfr = spark.read.format(format).option("name", "t1") + withCatalogOption.foreach(cName => dfr.option("catalog", cName)) + checkAnswer(dfr.load(), df.toDF()) } test(s"save works with $saveMode - no table, with partitioning, session catalog") { val df = spark.range(10).withColumn("part", 'id % 5) - val dfw = df.write.mode(saveMode).option("name", "t1").partitionBy("part") + val dfw = df.write.format(format).mode(saveMode).option("name", "t1").partitionBy("part") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.save() - val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") assert(table.name() === "t1", "Table identifier was wrong") assert(table.partitioning().length === 1, "Partitioning should not be empty") assert(table.partitioning().head.references().head.fieldNames().head === "part", "Partitioning was incorrect") assert(table.schema() === df.schema.asNullable, "Schema did not match") + + val dfr = spark.read.format(format).option("name", "t1") + withCatalogOption.foreach(cName => dfr.option("catalog", cName)) + checkAnswer(dfr.load(), df.toDF()) } } @@ -92,7 +101,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with sql("create table t1 (id bigint) using foo") val df = spark.range(10) intercept[TableAlreadyExistsException] { - val dfw = df.write.option("name", "t1") + val dfw = df.write.format(format).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.save() } @@ -102,7 +111,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with sql("create table t1 (id bigint) using foo") val df = spark.range(10).withColumn("part", 'id % 5) intercept[TableAlreadyExistsException] { - val dfw = df.write.mode(SaveMode.Ignore).option("name", "t1") + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.save() } @@ -113,9 +122,9 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } } - dataFrameWriterTests(None) + testWithDifferentCatalogs(None) - dataFrameWriterTests(Some(catalogName)) + testWithDifferentCatalogs(Some(catalogName)) } class CatalogSupportingInMemoryTableProvider From b94bfc57d68a80ead60de8bb36b8315eb994e31f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 20 Dec 2019 08:29:14 -0800 Subject: [PATCH 07/12] Update SupportsCatalogOptionsSuite.scala --- .../sql/connector/SupportsCatalogOptionsSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 1381811d7e81..1bb6771318fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -62,7 +62,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with def testWithDifferentCatalogs(withCatalogOption: Option[String]): Unit = { Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode => - test(s"save works with $saveMode - no table, no partitioning, session catalog") { + test(s"save 
works with $saveMode - no table, no partitioning, session catalog, " + + s"withCatalog: ${withCatalogOption.isDefined}") { val df = spark.range(10) val dfw = df.write.format(format).mode(saveMode).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) @@ -78,7 +79,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with checkAnswer(dfr.load(), df.toDF()) } - test(s"save works with $saveMode - no table, with partitioning, session catalog") { + test(s"save works with $saveMode - no table, with partitioning, session catalog, " + + s"withCatalog: ${withCatalogOption.isDefined}") { val df = spark.range(10).withColumn("part", 'id % 5) val dfw = df.write.format(format).mode(saveMode).option("name", "t1").partitionBy("part") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) @@ -97,7 +99,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } } - test("save fails with ErrorIfExists if table exists") { + test(s"save fails with ErrorIfExists if table exists, withCatalog: ${withCatalogOption.isDefined}") { sql("create table t1 (id bigint) using foo") val df = spark.range(10) intercept[TableAlreadyExistsException] { @@ -107,7 +109,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } } - test("Ignore mode if table exists") { + test(s"Ignore mode if table exists, withCatalog: ${withCatalogOption.isDefined}") { sql("create table t1 (id bigint) using foo") val df = spark.range(10).withColumn("part", 'id % 5) intercept[TableAlreadyExistsException] { From d8fd371ac4c64c17cde00cdb8bc840bd1ee537dd Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 20 Dec 2019 12:34:59 -0800 Subject: [PATCH 08/12] Address comments --- .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../SupportsCatalogOptionsSuite.scala | 129 ++++++++++-------- 2 files changed, 75 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index d8fb71cf3082..80a9c868f733 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -277,11 +277,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } case other if classOf[SupportsCatalogOptions].isAssignableFrom(provider.getClass) => - val catalogOptions = provider.asInstanceOf[SupportsCatalogOptions] - val ident = catalogOptions.extractIdentifier(dsOptions) + val supportsExtract = provider.asInstanceOf[SupportsCatalogOptions] + val ident = supportsExtract.extractIdentifier(dsOptions) val sessionState = df.sparkSession.sessionState val catalog = CatalogV2Util.getTableProviderCatalog( - catalogOptions, sessionState.catalogManager, dsOptions) + supportsExtract, sessionState.catalogManager, dsOptions) val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 1381811d7e81..bf91a190dcd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -60,71 +60,86 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with 
spark.conf.unset(s"spark.sql.catalog.$catalogName") } - def testWithDifferentCatalogs(withCatalogOption: Option[String]): Unit = { - Seq(SaveMode.ErrorIfExists, SaveMode.Ignore).foreach { saveMode => - test(s"save works with $saveMode - no table, no partitioning, session catalog") { - val df = spark.range(10) - val dfw = df.write.format(format).mode(saveMode).option("name", "t1") - withCatalogOption.foreach(cName => dfw.option("catalog", cName)) - dfw.save() - - val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") - assert(table.name() === "t1", "Table identifier was wrong") - assert(table.partitioning().isEmpty, "Partitioning should be empty") - assert(table.schema() === df.schema.asNullable, "Schema did not match") - - val dfr = spark.read.format(format).option("name", "t1") - withCatalogOption.foreach(cName => dfr.option("catalog", cName)) - checkAnswer(dfr.load(), df.toDF()) - } - - test(s"save works with $saveMode - no table, with partitioning, session catalog") { - val df = spark.range(10).withColumn("part", 'id % 5) - val dfw = df.write.format(format).mode(saveMode).option("name", "t1").partitionBy("part") - withCatalogOption.foreach(cName => dfw.option("catalog", cName)) - dfw.save() - - val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") - assert(table.name() === "t1", "Table identifier was wrong") - assert(table.partitioning().length === 1, "Partitioning should not be empty") - assert(table.partitioning().head.references().head.fieldNames().head === "part", - "Partitioning was incorrect") - assert(table.schema() === df.schema.asNullable, "Schema did not match") - - val dfr = spark.read.format(format).option("name", "t1") - withCatalogOption.foreach(cName => dfr.option("catalog", cName)) - checkAnswer(dfr.load(), df.toDF()) - } + private def testCreateAndRead( + saveMode: SaveMode, + withCatalogOption: Option[String], + partitionBy: Seq[String]): Unit = { + val df = spark.range(10).withColumn("part", 'id % 5) + val dfw = df.write.format(format).mode(saveMode).option("name", "t1") + withCatalogOption.foreach(cName => dfw.option("catalog", cName)) + dfw.partitionBy(partitionBy: _*).save() + + val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") + assert(table.name() === "t1", "Table identifier was wrong") + assert(table.partitioning().length === partitionBy.length, "Partitioning did not match") + assert(table.partitioning().map(_.references().head.fieldNames().head) === partitionBy, + "Partitioning was incorrect") + assert(table.schema() === df.schema.asNullable, "Schema did not match") + + val dfr = spark.read.format(format).option("name", "t1") + withCatalogOption.foreach(cName => dfr.option("catalog", cName)) + checkAnswer(dfr.load(), df.toDF()) + } + + test(s"save works with ErrorIfExists - no table, no partitioning, session catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, None, Nil) + } + + test(s"save works with ErrorIfExists - no table, with partitioning, session catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, None, Seq("part")) + } + + test(s"save works with Ignore - no table, no partitioning, testcat catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, Some(catalogName), Nil) + } + + test(s"save works with Ignore - no table, with partitioning, testcat catalog") { + testCreateAndRead(SaveMode.ErrorIfExists, Some(catalogName), Seq("part")) + } + + test("save fails with ErrorIfExists if table exists - session catalog") { + sql("create table t1 (id bigint) 
using foo") + val df = spark.range(10) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).option("name", "t1") + dfw.save() } + } - test("save fails with ErrorIfExists if table exists") { - sql("create table t1 (id bigint) using foo") - val df = spark.range(10) - intercept[TableAlreadyExistsException] { - val dfw = df.write.format(format).option("name", "t1") - withCatalogOption.foreach(cName => dfw.option("catalog", cName)) - dfw.save() - } + test("save fails with ErrorIfExists if table exists - testcat catalog") { + sql("create table t1 (id bigint) using foo") + val df = spark.range(10) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName) + dfw.save() } + } - test("Ignore mode if table exists") { - sql("create table t1 (id bigint) using foo") - val df = spark.range(10).withColumn("part", 'id % 5) - intercept[TableAlreadyExistsException] { - val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") - withCatalogOption.foreach(cName => dfw.option("catalog", cName)) - dfw.save() - } - - val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") - assert(table.partitioning().isEmpty, "Partitioning should be empty") - assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + test("Ignore mode if table exists - session catalog") { + sql("create table t1 (id bigint) using foo") + val df = spark.range(10).withColumn("part", 'id % 5) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.save() } + + val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") } - testWithDifferentCatalogs(None) + test("Ignore mode if table exists - testcat catalog") { + sql("create table t1 (id bigint) using foo") + val df = spark.range(10).withColumn("part", 'id % 5) + intercept[TableAlreadyExistsException] { + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.option("catalog", catalogName).save() + } - testWithDifferentCatalogs(Some(catalogName)) + val table = catalog(catalogName).loadTable("t1") + assert(table.partitioning().isEmpty, "Partitioning should be empty") + assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + } } class CatalogSupportingInMemoryTableProvider From 746e0d1d3a11af21604dc8acc9d098fc721a0bb7 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 20 Dec 2019 16:56:05 -0800 Subject: [PATCH 09/12] implement for append and overwrite as well --- .../apache/spark/sql/DataFrameWriter.scala | 76 ++++++++++++------- .../SupportsCatalogOptionsSuite.scala | 16 ++-- 2 files changed, 57 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 80a9c868f733..0c55f5ca52bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -258,26 +258,42 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new CaseInsensitiveStringMap(options.asJava) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ - provider.getTable(dsOptions) match { - case table: SupportsWrite if 
table.supports(BATCH_WRITE) => - lazy val relation = DataSourceV2Relation.create(table, dsOptions) - mode match { - case SaveMode.Append => - checkPartitioningMatchesV2Table(table) - runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) - } + mode match { + case SaveMode.Append | SaveMode.Overwrite => + val table = provider match { + case supportsExtract: SupportsCatalogOptions => + val ident = supportsExtract.extractIdentifier(dsOptions) + val sessionState = df.sparkSession.sessionState + val catalog = CatalogV2Util.getTableProviderCatalog( + supportsExtract, sessionState.catalogManager, dsOptions) - case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => - checkPartitioningMatchesV2Table(table) - // truncate the table - runCommand(df.sparkSession, "save") { - OverwriteByExpression.byName( - relation, df.logicalPlan, Literal(true), extraOptions.toMap) - } + catalog.loadTable(ident) + case tableProvider: TableProvider => tableProvider.getTable(dsOptions) + case _ => + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we fall + // back to saving as though it's a V1 source. + return saveToV1Source() + } + + val relation = DataSourceV2Relation.create(table, dsOptions) + checkPartitioningMatchesV2Table(table) + if (mode == SaveMode.Append) { + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan, extraOptions.toMap) + } + } else { + // Truncate the table. TableCapabilityCheck will throw a nice exception if this + // isn't supported + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), extraOptions.toMap) + } + } - case other if classOf[SupportsCatalogOptions].isAssignableFrom(provider.getClass) => - val supportsExtract = provider.asInstanceOf[SupportsCatalogOptions] + case create => + provider match { + case supportsExtract: SupportsCatalogOptions => val ident = supportsExtract.extractIdentifier(dsOptions) val sessionState = df.sparkSession.sessionState val catalog = CatalogV2Util.getTableProviderCatalog( @@ -293,20 +309,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ location, extraOptions.toMap, - ignoreIfExists = other == SaveMode.Ignore) + ignoreIfExists = create == SaveMode.Ignore) + } + case tableProvider: TableProvider => + if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) { + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $create mode, please use Append or Overwrite " + + "modes instead.") + } else { + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we + // fallback to saving as though it's a V1 source. + saveToV1Source() } - - case other => - throw new AnalysisException(s"TableProvider implementation $source cannot be " + - s"written with $other mode, please use Append or Overwrite " + - "modes instead.") } - - // Streaming also uses the data source V2 API. So it may be that the data source implements - // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving - // as though it's a V1 source. 
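Put in terms of the user-facing API, and assuming the CatalogSupportingInMemoryTableProvider test source below with its "name"/"catalog" options, each save mode now maps to a V2 plan roughly as in this sketch (table names are illustrative).

import org.apache.spark.sql.SaveMode

val format = classOf[CatalogSupportingInMemoryTableProvider].getName

// Table t1 already exists: it is loaded through the catalog and appended to (AppendData).
spark.range(10).write.format(format).option("name", "t1")
  .mode(SaveMode.Append).save()

// Table t1 already exists: truncate-and-write via OverwriteByExpression(true).
spark.range(10, 20).write.format(format).option("name", "t1")
  .mode(SaveMode.Overwrite).save()

// Table t2 does not exist yet: ErrorIfExists and Ignore go through CreateTableAsSelect,
// with ignoreIfExists set only for Ignore.
spark.range(10).write.format(format).option("name", "t2")
  .mode(SaveMode.ErrorIfExists).save()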
- case _ => saveToV1Source() } + } else { saveToV1Source() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 079fa21a3a58..51a5a3bf15c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.connector +import java.util + import scala.language.implicitConversions +import scala.util.Try import org.scalatest.BeforeAndAfter @@ -25,6 +28,7 @@ import org.apache.spark.sql.{QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StructType} @@ -54,8 +58,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with override def afterEach(): Unit = { super.afterEach() - catalog(SESSION_CATALOG_NAME).clearTables() - catalog(catalogName).clearTables() + Try(catalog(SESSION_CATALOG_NAME).clearTables()) + Try(catalog(catalogName).clearTables()) spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) spark.conf.unset(s"spark.sql.catalog.$catalogName") } @@ -98,7 +102,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test("save fails with ErrorIfExists if table exists - session catalog") { - sql("create table t1 (id bigint) using foo") + sql(s"create table t1 (id bigint) using $format") val df = spark.range(10) intercept[TableAlreadyExistsException] { val dfw = df.write.format(format).option("name", "t1") @@ -107,7 +111,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test("save fails with ErrorIfExists if table exists - testcat catalog") { - sql("create table t1 (id bigint) using foo") + sql(s"create table testcat.t1 (id bigint) using $format") val df = spark.range(10) intercept[TableAlreadyExistsException] { val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName) @@ -116,7 +120,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test("Ignore mode if table exists - session catalog") { - sql("create table t1 (id bigint) using foo") + sql(s"create table t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) intercept[TableAlreadyExistsException] { val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") @@ -129,7 +133,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test("Ignore mode if table exists - testcat catalog") { - sql("create table t1 (id bigint) using foo") + sql(s"create table testcat.t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) intercept[TableAlreadyExistsException] { val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") From 8827d9364c1774d6606523eb74582aa9665eb72b Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sat, 21 Dec 2019 13:46:20 -0800 Subject: [PATCH 10/12] more tests --- 
.../spark/sql/kafka010/KafkaSinkSuite.scala | 2 +- .../SupportsCatalogOptionsSuite.scala | 78 +++++++++++++------ .../connector/TestV2SessionCatalogBase.scala | 5 ++ 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index ac242ba3d135..cbba6d38739e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -505,7 +505,7 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { testUtils.createTopic(topic) val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") - Seq(SaveMode.Ignore, SaveMode.Overwrite).foreach { mode => + Seq(SaveMode.Overwrite).foreach { mode => val ex = intercept[AnalysisException] { df.write .format("kafka") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 51a5a3bf15c3..23a3d18fc57e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -24,9 +24,9 @@ import scala.util.Try import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{QueryTest, SaveMode} +import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException -import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION @@ -41,8 +41,8 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with private val catalogName = "testcat" private val format = classOf[CatalogSupportingInMemoryTableProvider].getName - private def catalog(name: String): InMemoryTableSessionCatalog = { - spark.sessionState.catalogManager.catalog(name).asInstanceOf[InMemoryTableSessionCatalog] + private def catalog(name: String): TableCatalog = { + spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog] } private implicit def stringToIdentifier(value: String): Identifier = { @@ -53,13 +53,14 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with spark.conf.set( V2_SESSION_CATALOG_IMPLEMENTATION.key, classOf[InMemoryTableSessionCatalog].getName) spark.conf.set( - s"spark.sql.catalog.$catalogName", classOf[InMemoryTableSessionCatalog].getName) + s"spark.sql.catalog.$catalogName", classOf[InMemoryTableCatalog].getName) } override def afterEach(): Unit = { super.afterEach() - Try(catalog(SESSION_CATALOG_NAME).clearTables()) - Try(catalog(catalogName).clearTables()) + Try(catalog(SESSION_CATALOG_NAME).asInstanceOf[InMemoryTableSessionCatalog].clearTables()) + catalog(catalogName).listTables(Array.empty).foreach( + catalog(catalogName).dropTable(_)) spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key) spark.conf.unset(s"spark.sql.catalog.$catalogName") } @@ -74,15 +75,14 @@ class 
SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with dfw.partitionBy(partitionBy: _*).save() val table = catalog(withCatalogOption.getOrElse(SESSION_CATALOG_NAME)).loadTable("t1") - assert(table.name() === "t1", "Table identifier was wrong") + val namespace = withCatalogOption.getOrElse("default") + assert(table.name() === s"$namespace.t1", "Table identifier was wrong") assert(table.partitioning().length === partitionBy.length, "Partitioning did not match") assert(table.partitioning().map(_.references().head.fieldNames().head) === partitionBy, "Partitioning was incorrect") assert(table.schema() === df.schema.asNullable, "Schema did not match") - val dfr = spark.read.format(format).option("name", "t1") - withCatalogOption.foreach(cName => dfr.option("catalog", cName)) - checkAnswer(dfr.load(), df.toDF()) + checkAnswer(load("t1", withCatalogOption), df.toDF()) } test(s"save works with ErrorIfExists - no table, no partitioning, session catalog") { @@ -94,11 +94,11 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test(s"save works with Ignore - no table, no partitioning, testcat catalog") { - testCreateAndRead(SaveMode.ErrorIfExists, Some(catalogName), Nil) + testCreateAndRead(SaveMode.Ignore, Some(catalogName), Nil) } test(s"save works with Ignore - no table, with partitioning, testcat catalog") { - testCreateAndRead(SaveMode.ErrorIfExists, Some(catalogName), Seq("part")) + testCreateAndRead(SaveMode.Ignore, Some(catalogName), Seq("part")) } test("save fails with ErrorIfExists if table exists - session catalog") { @@ -111,7 +111,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } test("save fails with ErrorIfExists if table exists - testcat catalog") { - sql(s"create table testcat.t1 (id bigint) using $format") + sql(s"create table $catalogName.t1 (id bigint) using $format") val df = spark.range(10) intercept[TableAlreadyExistsException] { val dfw = df.write.format(format).option("name", "t1").option("catalog", catalogName) @@ -122,27 +122,59 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with test("Ignore mode if table exists - session catalog") { sql(s"create table t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) - intercept[TableAlreadyExistsException] { - val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") - dfw.save() - } + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.save() val table = catalog(SESSION_CATALOG_NAME).loadTable("t1") assert(table.partitioning().isEmpty, "Partitioning should be empty") assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + assert(load("t1", None).count() === 0) } test("Ignore mode if table exists - testcat catalog") { - sql(s"create table testcat.t1 (id bigint) using $format") + sql(s"create table $catalogName.t1 (id bigint) using $format") val df = spark.range(10).withColumn("part", 'id % 5) - intercept[TableAlreadyExistsException] { - val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") - dfw.option("catalog", catalogName).save() - } + val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") + dfw.option("catalog", catalogName).save() val table = catalog(catalogName).loadTable("t1") assert(table.partitioning().isEmpty, "Partitioning should be empty") assert(table.schema() === new StructType().add("id", LongType), "Schema did not match") + assert(load("t1", 
Some(catalogName)).count() === 0) + } + + test("append and overwrite modes - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val df = spark.range(10) + df.write.format(format).option("name", "t1").mode(SaveMode.Append).save() + + checkAnswer(load("t1", None), df.toDF()) + + val df2 = spark.range(10, 20) + df2.write.format(format).option("name", "t1").mode(SaveMode.Overwrite).save() + + checkAnswer(load("t1", None), df2.toDF()) + } + + test("append and overwrite modes - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val df = spark.range(10) + df.write.format(format).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Append).save() + + checkAnswer(load("t1", Some(catalogName)), df.toDF()) + + val df2 = spark.range(10, 20) + df2.write.format(format).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Overwrite).save() + + checkAnswer(load("t1", Some(catalogName)), df2.toDF()) + } + + private def load(name: String, catalogOpt: Option[String]): DataFrame = { + val dfr = spark.read.format(format).option("name", "t1") + catalogOpt.foreach(cName => dfr.option("catalog", cName)) + dfr.load() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index d03294cb4067..3f6ac0b7f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -74,6 +74,11 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating t } + override def dropTable(ident: Identifier): Boolean = { + tables.remove(fullIdentifier(ident)) + super.dropTable(ident) + } + def clearTables(): Unit = { assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?") tables.keySet().asScala.foreach(super.dropTable) From 12f4ce48a505b12988292c44dedaf05c7abf4136 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 23 Dec 2019 11:19:50 -0800 Subject: [PATCH 11/12] Address comments --- .../spark/sql/kafka010/KafkaSinkSuite.scala | 15 ++++++++----- .../apache/spark/sql/DataFrameWriter.scala | 22 +++++++++++-------- .../SupportsCatalogOptionsSuite.scala | 10 ++++++++- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index cbba6d38739e..9426c9708a8d 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicInteger import scala.reflect.ClassTag +import scala.util.Try import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.clients.producer.internals.DefaultPartitioner @@ -500,12 +501,12 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { TestUtils.assertExceptionMsg(ex, "null topic present in the data") } - protected def testUnsupportedSaveModes(msg: (SaveMode) => String): Unit = { + protected def testUnsupportedSaveModes(msg: (SaveMode) => Seq[String]): Unit = { val topic = newTopic() testUtils.createTopic(topic) val df = Seq[(String, 
String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") - Seq(SaveMode.Overwrite).foreach { mode => + Seq(SaveMode.Ignore, SaveMode.Overwrite).foreach { mode => val ex = intercept[AnalysisException] { df.write .format("kafka") @@ -513,7 +514,10 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { .mode(mode) .save() } - TestUtils.assertExceptionMsg(ex, msg(mode)) + val errorChecks = msg(mode).map(m => Try(TestUtils.assertExceptionMsg(ex, m))) + if (!errorChecks.exists(_.isSuccess)) { + fail("Error messages not found in exception trace") + } } } @@ -541,7 +545,7 @@ class KafkaSinkBatchSuiteV1 extends KafkaSinkBatchSuiteBase { .set(SQLConf.USE_V1_SOURCE_LIST, "kafka") test("batch - unsupported save modes") { - testUnsupportedSaveModes((mode) => s"Save mode ${mode.name} not allowed for Kafka") + testUnsupportedSaveModes((mode) => s"Save mode ${mode.name} not allowed for Kafka" :: Nil) } } @@ -552,7 +556,8 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase { .set(SQLConf.USE_V1_SOURCE_LIST, "") test("batch - unsupported save modes") { - testUnsupportedSaveModes((mode) => s"cannot be written with ${mode.name} mode") + testUnsupportedSaveModes((mode) => + Seq(s"cannot be written with ${mode.name} mode", "does not support truncate")) } test("generic - write big data with small producer buffer") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 0c55f5ca52bc..3dda3f986355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -268,12 +268,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { supportsExtract, sessionState.catalogManager, dsOptions) catalog.loadTable(ident) - case tableProvider: TableProvider => tableProvider.getTable(dsOptions) - case _ => - // Streaming also uses the data source V2 API. So it may be that the data source - // implements v2, but has no v2 implementation for batch writes. In that case, we fall - // back to saving as though it's a V1 source. - return saveToV1Source() + case tableProvider: TableProvider => + val t = tableProvider.getTable(dsOptions) + if (t.supports(BATCH_WRITE)) { + t + } else { + // Streaming also uses the data source V2 API. So it may be that the data source + // implements v2, but has no v2 implementation for batch writes. In that case, we + // fall back to saving as though it's a V1 source. + return saveToV1Source() + } } val relation = DataSourceV2Relation.create(table, dsOptions) @@ -291,7 +295,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } } - case create => + case createMode => provider match { case supportsExtract: SupportsCatalogOptions => val ident = supportsExtract.extractIdentifier(dsOptions) @@ -309,12 +313,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { df.queryExecution.analyzed, Map(TableCatalog.PROP_PROVIDER -> source) ++ location, extraOptions.toMap, - ignoreIfExists = create == SaveMode.Ignore) + ignoreIfExists = createMode == SaveMode.Ignore) } case tableProvider: TableProvider => if (tableProvider.getTable(dsOptions).supports(BATCH_WRITE)) { throw new AnalysisException(s"TableProvider implementation $source cannot be " + - s"written with $create mode, please use Append or Overwrite " + + s"written with $createMode mode, please use Append or Overwrite " + "modes instead.") } else { // Streaming also uses the data source V2 API. 
So it may be that the data source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 23a3d18fc57e..7fd4cc113aa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME -import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StructType} @@ -78,6 +78,14 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with val namespace = withCatalogOption.getOrElse("default") assert(table.name() === s"$namespace.t1", "Table identifier was wrong") assert(table.partitioning().length === partitionBy.length, "Partitioning did not match") + if (partitionBy.nonEmpty) { + table.partitioning.head match { + case IdentityTransform(FieldReference(field)) => + assert(field === Seq(partitionBy.head), "Partitioning column did not match") + case otherTransform => + fail(s"Unexpected partitioning ${otherTransform.describe()} received") + } + } assert(table.partitioning().map(_.references().head.fieldNames().head) === partitionBy, "Partitioning was incorrect") assert(table.schema() === df.schema.asNullable, "Schema did not match") From 963133edc5004a465347fd7b1c3edfffd6ff5a8d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 9 Jan 2020 07:11:52 -0800 Subject: [PATCH 12/12] address comments --- .../catalog/SupportsCatalogOptions.java | 2 +- .../org/apache/spark/sql/DataFrameReader.scala | 3 +++ .../org/apache/spark/sql/DataFrameWriter.scala | 7 +++---- .../connector/SupportsCatalogOptionsSuite.scala | 17 +++++++++++++++++ 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java index 2c55c7360641..5225b12788c4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java @@ -48,6 +48,6 @@ public interface SupportsCatalogOptions extends TableProvider { * topic name, etc. It's an immutable case-insensitive string-to-string map. 
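 * Implementations will typically read a single entry from this map, for example a
 * {@code catalog} option (that option name is the one used by the test suite in this
 * series, not something mandated by the interface), along the lines of:
 * <pre>{@code
 *   String catalog = options.get("catalog");
 *   return catalog != null ? catalog : CatalogManager.SESSION_CATALOG_NAME();
 * }</pre>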
*/ default String extractCatalog(CaseInsensitiveStringMap options) { - return null; + return CatalogManager.SESSION_CATALOG_NAME(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c67104384be5..30d0c851964d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -216,6 +216,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = provider match { + case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => + throw new IllegalArgumentException( + s"$source does not support user specified schema. Please don't specify the schema.") case hasCatalog: SupportsCatalogOptions => val ident = hasCatalog.extractIdentifier(dsOptions) val catalog = CatalogV2Util.getTableProviderCatalog( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3dda3f986355..998ec9ebdff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, SupportsWrite, Table, TableCatalog, TableProvider, V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LiteralValue, Transform} import org.apache.spark.sql.execution.SQLExecution @@ -661,9 +661,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val partitioning = partitioningColumns.map { colNames => colNames.map(name => IdentityTransform(FieldReference(name))) }.getOrElse(Seq.empty[Transform]) - val bucketing = bucketColumnNames.map { cols => - Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_)))) - }.getOrElse(Seq.empty[Transform]) + val bucketing = + getBucketSpec.map(spec => CatalogV2Implicits.BucketSpecHelper(spec).asTransform).toSeq partitioning ++ bucketing } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 7fd4cc113aa6..0148bb07ee96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -179,6 +179,23 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with checkAnswer(load("t1", Some(catalogName)), df2.toDF()) } + test("fail on user 
specified schema when reading - session catalog") { + sql(s"create table t1 (id bigint) using $format") + val e = intercept[IllegalArgumentException] { + spark.read.format(format).option("name", "t1").schema("id bigint").load() + } + assert(e.getMessage.contains("not support user specified schema")) + } + + test("fail on user specified schema when reading - testcat catalog") { + sql(s"create table $catalogName.t1 (id bigint) using $format") + val e = intercept[IllegalArgumentException] { + spark.read.format(format).option("name", "t1").option("catalog", catalogName) + .schema("id bigint").load() + } + assert(e.getMessage.contains("not support user specified schema")) + } + private def load(name: String, catalogOpt: Option[String]): DataFrame = { val dfr = spark.read.format(format).option("name", "t1") catalogOpt.foreach(cName => dfr.option("catalog", cName))
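As a closing illustration (not part of the patches above), a provider opts into this resolution path by mixing the interface into its TableProvider and mapping options to an identifier and a catalog name. The sketch below is a hypothetical mixin: the "name" and "catalog" option keys and the "default" namespace follow the conventions exercised by the test suite above, and the remaining TableProvider methods are assumed to be supplied by the concrete class.

    import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, SupportsCatalogOptions}
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    // Hypothetical mixin; only the SupportsCatalogOptions methods are shown.
    trait CatalogOptionAwareProvider extends SupportsCatalogOptions {

      // The table name comes from the "name" option; the namespace is fixed to "default",
      // matching the identifiers the suite asserts on.
      override def extractIdentifier(options: CaseInsensitiveStringMap): Identifier =
        Identifier.of(Array("default"), options.get("name"))

      // The catalog comes from the "catalog" option when present; otherwise defer to the
      // session catalog, mirroring the new default of extractCatalog.
      override def extractCatalog(options: CaseInsensitiveStringMap): String =
        Option(options.get("catalog")).getOrElse(CatalogManager.SESSION_CATALOG_NAME)
    }

With such a provider registered as a data source, a write such as df.write.format(...).option("name", "t1").option("catalog", "testcat").mode(SaveMode.Append).save() resolves t1 through testcat before appending, while omitting the catalog option falls back to the session catalog.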