diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java index c44a12b174f4..eccf2892b039 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -89,5 +89,14 @@ public enum TableCapability { /** * Signals that the table accepts input of any schema in a write operation. */ - ACCEPT_ANY_SCHEMA + ACCEPT_ANY_SCHEMA, + + /** + * Signals that the table supports append writes using the V1 InsertableRelation interface. + *

+ * Tables that return this capability must create a V1WriteBuilder and may also support additional + * write modes, like {@link #TRUNCATE}, and {@link #OVERWRITE_BY_FILTER}, but cannot support + * {@link #OVERWRITE_DYNAMIC}. + */ + V1_BATCH_WRITE } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 585fe06ce4ce..5fa7f49a1795 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.collection.JavaConverters._ import scala.collection.mutable @@ -29,8 +31,10 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.TableCapability import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.sources.v2.writer.V1WriteBuilder import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Strategy extends Strategy with PredicateHelper { @@ -169,10 +173,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { catalog match { case staging: StagingTableCatalog => AtomicCreateTableAsSelectExec( - staging, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil + staging, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil case _ => CreateTableAsSelectExec( - catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil + catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil } case ReplaceTable(catalog, ident, schema, parts, props, orCreate) => @@ -191,6 +195,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { staging, ident, parts, + query, planLater(query), props, writeOptions, @@ -200,6 +205,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { catalog, ident, parts, + query, planLater(query), props, writeOptions, @@ -207,7 +213,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { } case AppendData(r: DataSourceV2Relation, query, _) => - AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil + r.table.asWritable match { + case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => + AppendDataExecV1(v1, r.options, query) :: Nil + case v2 => + AppendDataExec(v2, r.options, planLater(query)) :: Nil + } case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. 
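// --- Example (not part of this patch): how a connector opts into the V1 write fallback ---
// A minimal sketch of a table that advertises TableCapability.V1_BATCH_WRITE so that the
// AppendData / OverwriteByExpression cases above plan AppendDataExecV1 /
// OverwriteByExpressionExecV1 instead of the pure V2 exec nodes. The class name
// SimpleV1FallbackTable and its members are hypothetical; only the Spark interfaces come from
// this patch and the existing DataSource V2 API.

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.writer.{V1WriteBuilder, WriteBuilder}
import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class SimpleV1FallbackTable extends Table with SupportsWrite {
  override def name(): String = "simple_v1_fallback"
  override def schema(): StructType = new StructType().add("id", LongType)
  override def partitioning(): Array[Transform] = Array.empty
  override def properties(): util.Map[String, String] = util.Collections.emptyMap[String, String]()

  // Declaring V1_BATCH_WRITE is what makes DataSourceV2Strategy take the V1 fallback branch.
  override def capabilities(): util.Set[TableCapability] =
    Set(TableCapability.BATCH_WRITE, TableCapability.V1_BATCH_WRITE).asJava

  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
    new V1WriteBuilder {
      override def buildForV1Write(): InsertableRelation = new InsertableRelation {
        override def insert(data: DataFrame, overwrite: Boolean): Unit = {
          // Always called with overwrite = false on this path; hand `data` to the
          // connector's legacy V1 writer here.
        }
      }
    }
}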
@@ -215,9 +226,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) }.toArray - - OverwriteByExpressionExec( - r.table.asWritable, filters, r.options, planLater(query)) :: Nil + r.table.asWritable match { + case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => + OverwriteByExpressionExecV1(v1, filters, r.options, query) :: Nil + case v2 => + OverwriteByExpressionExec(v2, filters, r.options, planLater(query)) :: Nil + } case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala new file mode 100644 index 000000000000..2f05ff3a7c2e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.UUID + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkException +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.{AlwaysTrue, CreatableRelationProvider, Filter, InsertableRelation} +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Physical plan node for append into a v2 table using V1 write interfaces. + * + * Rows in the output data set are appended. + */ +case class AppendDataExecV1( + table: SupportsWrite, + writeOptions: CaseInsensitiveStringMap, + plan: LogicalPlan) extends V1FallbackWriters { + + override protected def doExecute(): RDD[InternalRow] = { + writeWithV1(newWriteBuilder().buildForV1Write()) + } +} + +/** + * Physical plan node for overwrite into a v2 table with V1 write interfaces. Note that when this + * interface is used, the atomicity of the operation depends solely on the target data source. + * + * Overwrites data in a table matched by a set of filters. Rows matching all of the filters will be + * deleted and rows in the output data set are appended. + * + * This plan is used to implement SaveMode.Overwrite. 
The behavior of SaveMode.Overwrite is to + * truncate the table -- delete all rows -- and append the output data set. This uses the filter + * AlwaysTrue to delete all rows. + */ +case class OverwriteByExpressionExecV1( + table: SupportsWrite, + deleteWhere: Array[Filter], + writeOptions: CaseInsensitiveStringMap, + plan: LogicalPlan) extends V1FallbackWriters { + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + override protected def doExecute(): RDD[InternalRow] = { + newWriteBuilder() match { + case builder: SupportsTruncate if isTruncate(deleteWhere) => + writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) + + case builder: SupportsOverwrite => + writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) + + case _ => + throw new SparkException(s"Table does not support overwrite by expression: $table") + } + } +} + +/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */ +sealed trait V1FallbackWriters extends SupportsV1Write { + override def output: Seq[Attribute] = Nil + override final def children: Seq[SparkPlan] = Nil + + def table: SupportsWrite + def writeOptions: CaseInsensitiveStringMap + + protected implicit class toV1WriteBuilder(builder: WriteBuilder) { + def asV1Builder: V1WriteBuilder = builder match { + case v1: V1WriteBuilder => v1 + case other => throw new IllegalStateException( + s"The returned writer ${other} was no longer a V1WriteBuilder.") + } + } + + protected def newWriteBuilder(): V1WriteBuilder = { + val writeBuilder = table.newWriteBuilder(writeOptions) + .withInputDataSchema(plan.schema) + .withQueryId(UUID.randomUUID().toString) + writeBuilder.asV1Builder + } +} + +/** + * A trait that allows Tables that use V1 Writer interfaces to append data. + */ +trait SupportsV1Write extends SparkPlan { + // TODO: We should be able to work on SparkPlans at this point. 
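// --- Example (not part of this patch): the contract behind newWriteBuilder()/asV1Builder above ---
// For OverwriteByExpressionExecV1, the builder returned by truncate() or overwrite(...) must still
// be a V1WriteBuilder, otherwise asV1Builder raises the IllegalStateException shown above. The
// usual pattern, also used by the test table added later in this patch, is to record the requested
// filters and return `this`. The class name HypotheticalV1OverwriteBuilder is made up.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.sources.{Filter, InsertableRelation}
import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}

class HypotheticalV1OverwriteBuilder extends V1WriteBuilder
    with SupportsTruncate with SupportsOverwrite {

  private var deleteFilters: Array[Filter] = Array.empty

  // SaveMode.Overwrite reaches this builder as an AlwaysTrue filter, which the exec node
  // turns into a truncate() call.
  override def truncate(): WriteBuilder = {
    deleteFilters = Array.empty
    this  // returning `this` keeps the builder a V1WriteBuilder, so buildForV1Write() stays reachable
  }

  override def overwrite(filters: Array[Filter]): WriteBuilder = {
    deleteFilters = filters
    this
  }

  override def buildForV1Write(): InsertableRelation = new InsertableRelation {
    override def insert(data: DataFrame, overwrite: Boolean): Unit = {
      // Delete the rows matching deleteFilters, then append `data`; insert() is always
      // called with overwrite = false on the V1 fallback path.
    }
  }
}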
+ def plan: LogicalPlan + + protected def writeWithV1(relation: InsertableRelation): RDD[InternalRow] = { + relation.insert(Dataset.ofRows(sqlContext.sparkSession, plan), overwrite = false) + sparkContext.emptyRDD + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 86b64cb8835a..39269a3f4354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.sources.v2.{StagedTable, SupportsWrite} -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -63,10 +63,11 @@ case class CreateTableAsSelectExec( catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], + plan: LogicalPlan, query: SparkPlan, properties: Map[String, String], writeOptions: CaseInsensitiveStringMap, - ifNotExists: Boolean) extends V2TableWriteExec { + ifNotExists: Boolean) extends V2TableWriteExec with SupportsV1Write { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper @@ -83,12 +84,14 @@ case class CreateTableAsSelectExec( catalog.createTable( ident, query.schema, partitioning.toArray, properties.asJava) match { case table: SupportsWrite => - val batchWrite = table.newWriteBuilder(writeOptions) + val writeBuilder = table.newWriteBuilder(writeOptions) .withInputDataSchema(query.schema) .withQueryId(UUID.randomUUID().toString) - .buildForBatch() - doWrite(batchWrite) + writeBuilder match { + case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write()) + case v2 => writeWithV2(v2.buildForBatch()) + } case _ => // table does not support writes @@ -114,6 +117,7 @@ case class AtomicCreateTableAsSelectExec( catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], + plan: LogicalPlan, query: SparkPlan, properties: Map[String, String], writeOptions: CaseInsensitiveStringMap, @@ -147,10 +151,11 @@ case class ReplaceTableAsSelectExec( catalog: TableCatalog, ident: Identifier, partitioning: Seq[Transform], + plan: LogicalPlan, query: SparkPlan, properties: Map[String, String], writeOptions: CaseInsensitiveStringMap, - orCreate: Boolean) extends AtomicTableWriteExec { + orCreate: Boolean) extends V2TableWriteExec with SupportsV1Write { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper @@ -173,12 +178,14 @@ case class ReplaceTableAsSelectExec( Utils.tryWithSafeFinallyAndFailureCallbacks({ createdTable match { case table: SupportsWrite => - val batchWrite = table.newWriteBuilder(writeOptions) + val writeBuilder = table.newWriteBuilder(writeOptions) .withInputDataSchema(query.schema) .withQueryId(UUID.randomUUID().toString) - .buildForBatch() - doWrite(batchWrite) + writeBuilder 
match { + case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write()) + case v2 => writeWithV2(v2.buildForBatch()) + } case _ => // table does not support writes @@ -207,6 +214,7 @@ case class AtomicReplaceTableAsSelectExec( catalog: StagingTableCatalog, ident: Identifier, partitioning: Seq[Transform], + plan: LogicalPlan, query: SparkPlan, properties: Map[String, String], writeOptions: CaseInsensitiveStringMap, @@ -242,8 +250,7 @@ case class AppendDataExec( query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { - val batchWrite = newWriteBuilder().buildForBatch() - doWrite(batchWrite) + writeWithV2(newWriteBuilder().buildForBatch()) } } @@ -268,18 +275,16 @@ case class OverwriteByExpressionExec( } override protected def doExecute(): RDD[InternalRow] = { - val batchWrite = newWriteBuilder() match { + newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => - builder.truncate().buildForBatch() + writeWithV2(builder.truncate().buildForBatch()) case builder: SupportsOverwrite => - builder.overwrite(deleteWhere).buildForBatch() + writeWithV2(builder.overwrite(deleteWhere).buildForBatch()) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } - - doWrite(batchWrite) } } @@ -298,15 +303,13 @@ case class OverwritePartitionsDynamicExec( query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { - val batchWrite = newWriteBuilder() match { + newWriteBuilder() match { case builder: SupportsDynamicOverwrite => - builder.overwriteDynamicPartitions().buildForBatch() + writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } - - doWrite(batchWrite) } } @@ -317,7 +320,7 @@ case class WriteToDataSourceV2Exec( def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty() override protected def doExecute(): RDD[InternalRow] = { - doWrite(batchWrite) + writeWithV2(batchWrite) } } @@ -331,8 +334,8 @@ trait BatchWriteHelper { def newWriteBuilder(): WriteBuilder = { table.newWriteBuilder(writeOptions) - .withInputDataSchema(query.schema) - .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) } } @@ -347,7 +350,7 @@ trait V2TableWriteExec extends UnaryExecNode { override def child: SparkPlan = query override def output: Seq[Attribute] = Nil - protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = { + protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = { val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() @@ -463,7 +466,7 @@ object DataWritingSparkTask extends Logging { } } -private[v2] trait AtomicTableWriteExec extends V2TableWriteExec { +private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1Write { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper protected def writeToStagedTable( @@ -473,14 +476,17 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec { Utils.tryWithSafeFinallyAndFailureCallbacks({ stagedTable match { case table: SupportsWrite => - val batchWrite = table.newWriteBuilder(writeOptions) + val writeBuilder = table.newWriteBuilder(writeOptions) .withInputDataSchema(query.schema) .withQueryId(UUID.randomUUID().toString) - 
.buildForBatch() - val writtenRows = doWrite(batchWrite) + val writtenRows = writeBuilder match { + case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write()) + case v2 => writeWithV2(v2.buildForBatch()) + } stagedTable.commitStagedChanges() writtenRows + case _ => // Table does not support writes - staged changes are also rolled back below. throw new SparkException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/writer/V1WriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/writer/V1WriteBuilder.scala new file mode 100644 index 000000000000..2a88555e2927 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/v2/writer/V1WriteBuilder.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer + +import org.apache.spark.annotation.{Experimental, Unstable} +import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite + +/** + * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource + * V2 write code paths. The InsertableRelation will be used only to append data. Other + * instances of the [[WriteBuilder]] interface such as [[SupportsOverwrite]] and [[SupportsTruncate]] + * should be extended as well to support additional operations other than data appends. + * + * This interface is designed to provide Spark DataSources time to migrate to DataSource V2 and + * will be removed in a future Spark release. + * + * @since 3.0.0 + */ +@Experimental +@Unstable +trait V1WriteBuilder extends WriteBuilder { + + /** + * Creates an InsertableRelation that allows appending a DataFrame to a + * destination (using data source-specific parameters). The insert method will only be + * called with `overwrite=false`. The DataSource should implement the overwrite behavior as + * part of the [[SupportsOverwrite]] and [[SupportsTruncate]] interfaces. + * + * @since 3.0.0 + */ + def buildForV1Write(): InsertableRelation + + // These methods cannot be implemented by a V1WriteBuilder.
The super class will throw + // an Unsupported OperationException + override final def buildForBatch(): BatchWrite = super.buildForBatch() + + override final def buildForStreaming(): StreamingWrite = super.buildForStreaming() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala index a104b8835c61..22ebfeea04d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -31,39 +31,90 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2DataFrameSessionCatalogSuite + extends SessionCatalogTest[InMemoryTable, InMemoryTableSessionCatalog] { + + test("saveAsTable: Append mode should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + val format = spark.sessionState.conf.defaultDataSourceName + sql(s"CREATE TABLE same_name(id LONG) USING $format") + spark.range(10).createTempView("same_name") + spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode Overwrite should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql(s"CREATE TABLE same_name(id LONG) USING $v2Format") + spark.range(10).createTempView("same_name") + spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } +} + +class InMemoryTableProvider extends TableProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { + throw new UnsupportedOperationException("D'oh!") + } +} + +class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable] { + override def newTable( + name: String, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): InMemoryTable = { + new InMemoryTable(name, schema, partitions, properties) + } +} + +private[v2] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] extends QueryTest with SharedSparkSession with BeforeAndAfter { - import testImplicits._ - private def catalog(name: String): CatalogPlugin = { + protected def catalog(name: String): CatalogPlugin = { spark.sessionState.catalogManager.catalog(name) } - private val v2Format = classOf[InMemoryTableProvider].getName + protected val v2Format = classOf[InMemoryTableProvider].getName + + protected val catalogClassName: String = classOf[InMemoryTableSessionCatalog].getName before { - spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, 
classOf[TestV2SessionCatalog].getName) + spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, catalogClassName) } override def afterEach(): Unit = { super.afterEach() - catalog("session").asInstanceOf[TestV2SessionCatalog].clearTables() + catalog("session").asInstanceOf[Catalog].clearTables() spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName) } - private def verifyTable(tableName: String, expected: DataFrame): Unit = { + protected def verifyTable(tableName: String, expected: DataFrame): Unit = { checkAnswer(spark.table(tableName), expected) checkAnswer(sql(s"SELECT * FROM $tableName"), expected) checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected) checkAnswer(sql(s"TABLE $tableName"), expected) } + import testImplicits._ + test("saveAsTable: v2 table - table doesn't exist and default mode (ErrorIfExists)") { val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") @@ -90,20 +141,6 @@ class DataSourceV2DataFrameSessionCatalogSuite } } - test("saveAsTable: Append mode should not fail if the table already exists " + - "and a same-name temp view exist") { - withTable("same_name") { - withTempView("same_name") { - val format = spark.sessionState.conf.defaultDataSourceName - sql(s"CREATE TABLE same_name(id LONG) USING $format") - spark.range(10).createTempView("same_name") - spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name") - checkAnswer(spark.table("same_name"), spark.range(10).toDF()) - checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) - } - } - } - test("saveAsTable: v2 table - table exists") { val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") @@ -147,19 +184,6 @@ class DataSourceV2DataFrameSessionCatalogSuite } } - test("saveAsTable with mode Overwrite should not fail if the table already exists " + - "and a same-name temp view exist") { - withTable("same_name") { - withTempView("same_name") { - sql(s"CREATE TABLE same_name(id LONG) USING $v2Format") - spark.range(10).createTempView("same_name") - spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name") - checkAnswer(spark.table("same_name"), spark.range(10).toDF()) - checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) - } - } - } - test("saveAsTable: v2 table - ignore mode and table doesn't exist") { val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") @@ -175,55 +199,3 @@ class DataSourceV2DataFrameSessionCatalogSuite verifyTable(t1, Seq(("c", "d")).toDF("id", "data")) } } - -class InMemoryTableProvider extends TableProvider { - override def getTable(options: CaseInsensitiveStringMap): Table = { - throw new UnsupportedOperationException("D'oh!") - } -} - -/** A SessionCatalog that always loads an in memory Table, so we can test write code paths. 
*/ -class TestV2SessionCatalog extends V2SessionCatalog { - - protected val tables: util.Map[Identifier, InMemoryTable] = - new ConcurrentHashMap[Identifier, InMemoryTable]() - - private def fullIdentifier(ident: Identifier): Identifier = { - if (ident.namespace().isEmpty) { - Identifier.of(Array("default"), ident.name()) - } else { - ident - } - } - - override def loadTable(ident: Identifier): Table = { - val fullIdent = fullIdentifier(ident) - if (tables.containsKey(fullIdent)) { - tables.get(fullIdent) - } else { - // Table was created through the built-in catalog - val t = super.loadTable(fullIdent) - val table = new InMemoryTable(t.name(), t.schema(), t.partitioning(), t.properties()) - tables.put(fullIdent, table) - table - } - } - - override def createTable( - ident: Identifier, - schema: StructType, - partitions: Array[Transform], - properties: util.Map[String, String]): Table = { - val created = super.createTable(ident, schema, partitions, properties) - val t = new InMemoryTable(created.name(), schema, partitions, properties) - val fullIdent = fullIdentifier(ident) - tables.put(fullIdent, t) - t - } - - def clearTables(): Unit = { - assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?") - tables.keySet().asScala.foreach(super.dropTable) - tables.clear() - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 93889ba1da15..1cb061bddd5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -234,7 +234,8 @@ class InMemoryTable( private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - dataMap --= deletesKeys(filters) + val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } } @@ -247,32 +248,31 @@ class InMemoryTable( } override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { - dataMap --= deletesKeys(filters) - } - - private def splitAnd(filter: Filter): Seq[Filter] = { - filter match { - case And(left, right) => splitAnd(left) ++ splitAnd(right) - case _ => filter :: Nil - } + dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) } +} - private def deletesKeys(filters: Array[Filter]): Iterable[Seq[Any]] = { - dataMap.synchronized { - dataMap.keys.filter { partValues => - filters.flatMap(splitAnd).forall { - case EqualTo(attr, value) => - value == extractValue(attr, partValues) - case IsNotNull(attr) => - null != extractValue(attr, partValues) - case f => - throw new IllegalArgumentException(s"Unsupported filter type: $f") - } +object InMemoryTable { + def filtersToKeys( + keys: Iterable[Seq[Any]], + partitionNames: Seq[String], + filters: Array[Filter]): Iterable[Seq[Any]] = { + keys.filter { partValues => + filters.flatMap(splitAnd).forall { + case EqualTo(attr, value) => + value == extractValue(attr, partitionNames, partValues) + case IsNotNull(attr) => + null != extractValue(attr, partitionNames, partValues) + case f => + throw new IllegalArgumentException(s"Unsupported filter type: $f") } } } - private def extractValue(attr: String, partValues: Seq[Any]): Any = { + private def extractValue( + attr: 
String, + partFieldNames: Seq[String], + partValues: Seq[Any]): Any = { partFieldNames.zipWithIndex.find(_._1 == attr) match { case Some((_, partIndex)) => partValues(partIndex) @@ -280,6 +280,13 @@ class InMemoryTable( throw new IllegalArgumentException(s"Unknown filter attribute: $attr") } } + + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } } object TestInMemoryTableCatalog { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala new file mode 100644 index 000000000000..60e2443d0967 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V1WriteFallbackSuite.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.sources.{DataSourceRegister, Filter, InsertableRelation} +import org.apache.spark.sql.sources.v2.utils.TestV2SessionCatalogBase +import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V1WriteFallbackSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + import testImplicits._ + + private val v2Format = classOf[InMemoryV1Provider].getName + + override def beforeAll(): Unit = { + super.beforeAll() + InMemoryV1Provider.clear() + } + + override def afterEach(): Unit = { + super.afterEach() + InMemoryV1Provider.clear() + } + + test("append fallback") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + df.write.mode("append").option("name", "t1").format(v2Format).save() + checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df) + df.write.mode("append").option("name", "t1").format(v2Format).save() + checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df.union(df)) + } + + test("overwrite by truncate fallback") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + df.write.mode("append").option("name", "t1").format(v2Format).save() + + val df2 = Seq((10, "k"), (20, "l"), (30, "m")).toDF("a", "b") + df2.write.mode("overwrite").option("name", "t1").format(v2Format).save() + 
checkAnswer(InMemoryV1Provider.getTableData(spark, "t1"), df2) + } +} + +class V1WriteFallbackSessionCatalogSuite + extends SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] { + override protected val v2Format = classOf[InMemoryV1Provider].getName + override protected val catalogClassName: String = classOf[V1FallbackTableCatalog].getName + + override protected def verifyTable(tableName: String, expected: DataFrame): Unit = { + checkAnswer(InMemoryV1Provider.getTableData(spark, s"default.$tableName"), expected) + } +} + +class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV1Fallback] { + override def newTable( + name: String, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { + val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) + InMemoryV1Provider.tables.put(name, t) + t + } +} + +private object InMemoryV1Provider { + val tables: mutable.Map[String, InMemoryTableWithV1Fallback] = mutable.Map.empty + + def getTableData(spark: SparkSession, name: String): DataFrame = { + val t = tables.getOrElse(name, throw new IllegalArgumentException(s"Table $name doesn't exist")) + spark.createDataFrame(t.getData.asJava, t.schema) + } + + def clear(): Unit = { + tables.clear() + } +} + +class InMemoryV1Provider extends TableProvider with DataSourceRegister { + override def getTable(options: CaseInsensitiveStringMap): Table = { + InMemoryV1Provider.tables.getOrElseUpdate(options.get("name"), { + new InMemoryTableWithV1Fallback( + "InMemoryTableWithV1Fallback", + new StructType().add("a", IntegerType).add("b", StringType), + Array(IdentityTransform(FieldReference(Seq("a")))), + options.asCaseSensitiveMap() + ) + }) + } + + override def shortName(): String = "in-memory" +} + +class InMemoryTableWithV1Fallback( + override val name: String, + override val schema: StructType, + override val partitioning: Array[Transform], + override val properties: util.Map[String, String]) extends Table with SupportsWrite { + + partitioning.foreach { t => + if (!t.isInstanceOf[IdentityTransform]) { + throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") + } + } + + override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_WRITE, + TableCapability.V1_BATCH_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.TRUNCATE).asJava + + @volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty + private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) + private val partIndexes = partFieldNames.map(schema.fieldIndex(_)) + + def getData: Seq[Row] = dataMap.values.flatten.toSeq + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new FallbackWriteBuilder(options) + } + + private class FallbackWriteBuilder(options: CaseInsensitiveStringMap) + extends WriteBuilder + with V1WriteBuilder + with SupportsTruncate + with SupportsOverwrite { + + private var mode = "append" + + override def truncate(): WriteBuilder = { + dataMap.clear() + mode = "truncate" + this + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + dataMap --= keys + mode = "overwrite" + this + } + + private def getPartitionValues(row: Row): Seq[Any] = { + partIndexes.map(row.get) + } + + override def buildForV1Write(): InsertableRelation = { + new InsertableRelation { + 
override def insert(data: DataFrame, overwrite: Boolean): Unit = { + assert(!overwrite, "V1 write fallbacks cannot be called with overwrite=true") + val rows = data.collect() + rows.groupBy(getPartitionValues).foreach { case (partition, elements) => + if (dataMap.contains(partition) && mode == "append") { + dataMap.put(partition, dataMap(partition) ++ elements) + } else if (dataMap.contains(partition)) { + throw new IllegalStateException("Partition was not removed properly") + } else { + dataMap.put(partition, elements) + } + } + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/utils/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/utils/TestV2SessionCatalogBase.scala new file mode 100644 index 000000000000..cfacb4908250 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/utils/TestV2SessionCatalogBase.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.utils + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog +import org.apache.spark.sql.sources.v2.Table +import org.apache.spark.sql.types.StructType + +/** + * A V2SessionCatalog implementation that can be extended to generate arbitrary `Table` definitions + * for testing DDL as well as write operations (through df.write.saveAsTable, df.write.insertInto + * and SQL). 
+ */ +private[v2] trait TestV2SessionCatalogBase[T <: Table] extends V2SessionCatalog { + + protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() + + protected def newTable( + name: String, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): T + + private def fullIdentifier(ident: Identifier): Identifier = { + if (ident.namespace().isEmpty) { + Identifier.of(Array("default"), ident.name()) + } else { + ident + } + } + + override def loadTable(ident: Identifier): Table = { + val fullIdent = fullIdentifier(ident) + if (tables.containsKey(fullIdent)) { + tables.get(fullIdent) + } else { + // Table was created through the built-in catalog + val t = super.loadTable(fullIdent) + val table = newTable(t.name(), t.schema(), t.partitioning(), t.properties()) + tables.put(fullIdent, table) + table + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + val created = super.createTable(ident, schema, partitions, properties) + val t = newTable(created.name(), schema, partitions, properties) + val fullIdent = fullIdentifier(ident) + tables.put(fullIdent, t) + t + } + + def clearTables(): Unit = { + assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?") + tables.keySet().asScala.foreach(super.dropTable) + tables.clear() + } +}
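// --- Usage sketch (assumptions: a SparkSession named `spark` and the test classes from this patch
// on the classpath). Mirrors what V1WriteFallbackSuite and V1WriteFallbackSessionCatalogSuite
// exercise: both path-based writes and saveAsTable through the session catalog end up in the
// connector's V1 InsertableRelation. The wrapper object V1FallbackUsageSketch is made up.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.{InMemoryV1Provider, V1FallbackTableCatalog}

object V1FallbackUsageSketch {
  def run(spark: SparkSession): Unit = {
    import spark.implicits._
    val v1FallbackFormat = classOf[InMemoryV1Provider].getName

    // Append: planned as AppendDataExecV1 because the table reports V1_BATCH_WRITE.
    Seq((1, "x"), (2, "y")).toDF("a", "b")
      .write.mode("append").option("name", "t1").format(v1FallbackFormat).save()

    // Overwrite: SaveMode.Overwrite becomes OverwriteByExpressionExecV1 with an AlwaysTrue filter,
    // so the builder's truncate() runs before the V1 insert.
    Seq((10, "k")).toDF("a", "b")
      .write.mode("overwrite").option("name", "t1").format(v1FallbackFormat).save()

    // saveAsTable: route the session catalog through a TestV2SessionCatalogBase implementation
    // such as V1FallbackTableCatalog from this patch.
    spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V1FallbackTableCatalog].getName)
    Seq((3, "z")).toDF("a", "b").write.format(v1FallbackFormat).saveAsTable("t2")
  }
}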