diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
index 9765118a8dbf..5604bccb04ee 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
@@ -89,14 +89,5 @@ public enum TableCapability {
   /**
    * Signals that the table accepts input of any schema in a write operation.
    */
-  ACCEPT_ANY_SCHEMA,
-
-  /**
-   * Signals that the table supports append writes using the V1 InsertableRelation interface.
-   * <p>
-   * Tables that return this capability must create a V1WriteBuilder and may also support additional
-   * write modes, like {@link #TRUNCATE}, and {@link #OVERWRITE_BY_FILTER}, but cannot support
-   * {@link #OVERWRITE_DYNAMIC}.
-   */
-  V1_BATCH_WRITE
+  ACCEPT_ANY_SCHEMA
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 4a7cb7db45de..afc023ac495c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,21 +17,19 @@

 package org.apache.spark.sql.execution.datasources.v2

-import scala.collection.JavaConverters._
 import scala.collection.mutable

 import org.apache.spark.sql.{AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RefreshTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables}
-import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, TableCapability}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, CreateNamespace, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, RefreshTable, Repartition, ReplaceTable, ReplaceTableAsSelect, SetCatalogAndNamespace, ShowNamespaces, ShowTables}
+import org.apache.spark.sql.connector.catalog.StagingTableCatalog
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.util.CaseInsensitiveStringMap

 object DataSourceV2Strategy extends Strategy with PredicateHelper {

@@ -183,14 +181,13 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       CreateTableExec(catalog, ident, schema, parts, props, ifNotExists) :: Nil

     case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) =>
-      val writeOptions = new CaseInsensitiveStringMap(options.asJava)
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicCreateTableAsSelectExec(
-            staging, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
+            staging, ident, parts, query, planLater(query), props, options, ifNotExists) :: Nil
         case _ =>
           CreateTableAsSelectExec(
-            catalog, ident, parts, query, planLater(query), props, writeOptions, ifNotExists) :: Nil
+            catalog, ident, parts, query, planLater(query), props, options, ifNotExists) :: Nil
       }

     case RefreshTable(catalog, ident) =>
@@ -205,7 +202,6 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       }

     case ReplaceTableAsSelect(catalog, ident, parts, query, props, options, orCreate) =>
-      val writeOptions = new CaseInsensitiveStringMap(options.asJava)
       catalog match {
         case staging: StagingTableCatalog =>
           AtomicReplaceTableAsSelectExec(
@@ -215,7 +211,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
             query,
             planLater(query),
             props,
-            writeOptions,
+            options,
             orCreate = orCreate) :: Nil
         case _ =>
           ReplaceTableAsSelectExec(
@@ -225,35 +221,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
             query,
             planLater(query),
             props,
-            writeOptions,
+            options,
             orCreate = orCreate) :: Nil
       }

-    case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
-      r.table.asWritable match {
-        case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
-          AppendDataExecV1(v1, writeOptions.asOptions, query) :: Nil
-        case v2 =>
-          AppendDataExec(v2, writeOptions.asOptions, planLater(query)) :: Nil
-      }
-
-    case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
-      // fail if any filter cannot be converted. correctness depends on removing all matching data.
-      val filters = splitConjunctivePredicates(deleteExpr).map {
-        filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse(
-          throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
-      }.toArray
-      r.table.asWritable match {
-        case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) =>
-          OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, query) :: Nil
-        case v2 =>
-          OverwriteByExpressionExec(v2, filters, writeOptions.asOptions, planLater(query)) :: Nil
-      }
-
-    case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
-      OverwritePartitionsDynamicExec(
-        r.table.asWritable, writeOptions.asOptions, planLater(query)) :: Nil
-
     case DeleteFromTable(r: DataSourceV2Relation, condition) =>
       if (condition.exists(SubqueryExpression.hasSubquery)) {
         throw new AnalysisException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
index 509a5f7139cc..a6a1a809796d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
 import org.apache.spark.sql.types.BooleanType
@@ -33,10 +32,6 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {

   private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)

-  private def supportsBatchWrite(table: Table): Boolean = {
-    table.supportsAny(BATCH_WRITE, V1_BATCH_WRITE)
-  }
-
   override def apply(plan: LogicalPlan): Unit = {
     plan foreach {
       case r: DataSourceV2Relation if !r.table.supports(BATCH_READ) =>
@@ -48,7 +43,7 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
       // TODO: check STREAMING_WRITE capability. It's not doable now because we don't have
       // a logical plan for streaming write.
-      case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) =>
+      case AppendData(r: DataSourceV2Relation, _, _, _) if !r.table.supports(BATCH_WRITE) =>
         failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.")

       case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _)
@@ -58,13 +53,13 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) {
       case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) =>
         expr match {
           case Literal(true, BooleanType) =>
-            if (!supportsBatchWrite(r.table) ||
+            if (!r.table.supports(BATCH_WRITE) ||
                 !r.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
               failAnalysis(
                 s"Table ${r.table.name()} does not support truncate in batch mode.")
             }
           case _ =>
-            if (!supportsBatchWrite(r.table) || !r.table.supports(OVERWRITE_BY_FILTER)) {
+            if (!r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_BY_FILTER)) {
               failAnalysis(s"Table ${r.table.name()} does not support " +
                 "overwrite by filter in batch mode.")
             }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
index bf67e972976b..ecdbedcae0dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala
@@ -17,19 +17,13 @@

 package org.apache.spark.sql.execution.datasources.v2

-import java.util.UUID
-
-import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.connector.catalog.SupportsWrite
-import org.apache.spark.sql.connector.write.{SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.sources.{AlwaysTrue, Filter, InsertableRelation}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.sources.{Filter, InsertableRelation}

 /**
  * Physical plan node for append into a v2 table using V1 write interfaces.
@@ -37,12 +31,13 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * Rows in the output data set are appended.
  */
 case class AppendDataExecV1(
-    table: SupportsWrite,
-    writeOptions: CaseInsensitiveStringMap,
-    plan: LogicalPlan) extends V1FallbackWriters {
+    v1Relation: InsertableRelation,
+    plan: LogicalPlan) extends LeafExecNode with SupportsV1Write {
+
+  override def output: Seq[Attribute] = Nil

   override protected def doExecute(): RDD[InternalRow] = {
-    writeWithV1(newWriteBuilder().buildForV1Write())
+    writeWithV1(v1Relation)
   }
 }

@@ -58,50 +53,14 @@ case class AppendDataExecV1(
  * AlwaysTrue to delete all rows.
  */
 case class OverwriteByExpressionExecV1(
-    table: SupportsWrite,
-    deleteWhere: Array[Filter],
-    writeOptions: CaseInsensitiveStringMap,
-    plan: LogicalPlan) extends V1FallbackWriters {
-
-  private def isTruncate(filters: Array[Filter]): Boolean = {
-    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
-  }
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    newWriteBuilder() match {
-      case builder: SupportsTruncate if isTruncate(deleteWhere) =>
-        writeWithV1(builder.truncate().asV1Builder.buildForV1Write())
-
-      case builder: SupportsOverwrite =>
-        writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write())
+    v1Relation: InsertableRelation,
+    deleteWhere: Seq[Filter],
+    plan: LogicalPlan) extends LeafExecNode with SupportsV1Write {

-      case _ =>
-        throw new SparkException(s"Table does not support overwrite by expression: $table")
-    }
-  }
-}
-
-/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
-sealed trait V1FallbackWriters extends SupportsV1Write {
   override def output: Seq[Attribute] = Nil
-  override final def children: Seq[SparkPlan] = Nil
-
-  def table: SupportsWrite
-  def writeOptions: CaseInsensitiveStringMap

-  protected implicit class toV1WriteBuilder(builder: WriteBuilder) {
-    def asV1Builder: V1WriteBuilder = builder match {
-      case v1: V1WriteBuilder => v1
-      case other => throw new IllegalStateException(
-        s"The returned writer ${other} was no longer a V1WriteBuilder.")
-    }
-  }
-
-  protected def newWriteBuilder(): V1WriteBuilder = {
-    val writeBuilder = table.newWriteBuilder(writeOptions)
-      .withInputDataSchema(plan.schema)
-      .withQueryId(UUID.randomUUID().toString)
-    writeBuilder.asV1Builder
+  override protected def doExecute(): RDD[InternalRow] = {
+    writeWithV1(v1Relation)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteStrategy.scala
new file mode 100644
index 000000000000..a860c686da1c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteStrategy.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util.UUID
+
+import org.apache.spark.sql.{AnalysisException, Strategy}
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.write.{SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
+import org.apache.spark.sql.types.StructType
+
+object V2WriteStrategy extends Strategy with PredicateHelper {
+  import DataSourceV2Implicits._
+
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case WriteToDataSourceV2(writer, query) =>
+      WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
+
+    case AppendData(r: DataSourceV2Relation, query, writeOptions, _) =>
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
+      writeBuilder match {
+        case v1: V1WriteBuilder =>
+          AppendDataExecV1(v1.buildForV1Write(), query) :: Nil
+        case _ =>
+          AppendDataExec(writeBuilder.buildForBatch(), planLater(query)) :: Nil
+      }
+
+    case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) =>
+      // fail if any filter cannot be converted. correctness depends on removing all matching data.
+      val filters = splitConjunctivePredicates(deleteExpr).map {
+        filter => DataSourceStrategy.translateFilter(filter).getOrElse(
+          throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
+      }.toArray
+
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
+      val configured = writeBuilder match {
+        case builder: SupportsTruncate if isTruncate(filters) => builder.truncate()
+        case builder: SupportsOverwrite => builder.overwrite(filters)
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Table does not support overwrite by expression: ${r.table.name}")
+      }
+
+      configured match {
+        case v1: V1WriteBuilder =>
+          OverwriteByExpressionExecV1(v1.buildForV1Write(), filters, query) :: Nil
+        case _ =>
+          OverwriteByExpressionExec(configured.buildForBatch(), filters, planLater(query)) :: Nil
+      }
+
+    case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) =>
+      val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema)
+      val configured = writeBuilder match {
+        case builder: SupportsDynamicOverwrite =>
+          builder.overwriteDynamicPartitions()
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Table does not support dynamic partition overwrite: ${r.table.name}")
+      }
+      OverwritePartitionsDynamicExec(configured.buildForBatch(), planLater(query)) :: Nil
+  }
+
+  def newWriteBuilder(
+      table: Table,
+      options: Map[String, String],
+      inputSchema: StructType): WriteBuilder = {
+    table.asWritable.newWriteBuilder(options.asOptions)
+      .withInputDataSchema(inputSchema)
+      .withQueryId(UUID.randomUUID().toString)
+  }
+
+  private def isTruncate(filters: Array[Filter]): Boolean = {
+    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 9f4392da6ab4..c2328d99006c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -32,9 +32,9 @@
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, TableCatalog}
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, V1WriteBuilder, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -65,10 +65,8 @@ case class CreateTableAsSelectExec(
     plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
-    writeOptions: CaseInsensitiveStringMap,
-    ifNotExists: Boolean) extends V2TableWriteExec with SupportsV1Write {
-
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+    writeOptions: Map[String, String],
+    ifNotExists: Boolean) extends V2TableWriteWithV1FallBack {

   override protected def doExecute(): RDD[InternalRow] = {
     if (catalog.tableExists(ident)) {
@@ -81,23 +79,9 @@ case class CreateTableAsSelectExec(

     Utils.tryWithSafeFinallyAndFailureCallbacks({
       val schema = query.schema.asNullable
-      catalog.createTable(
-        ident, schema, partitioning.toArray, properties.asJava) match {
-        case table: SupportsWrite =>
-          val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(schema)
-            .withQueryId(UUID.randomUUID().toString)
-
-          writeBuilder match {
-            case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
-            case v2 => writeWithV2(v2.buildForBatch())
-          }
-
-        case _ =>
-          // table does not support writes
-          throw new SparkException(
-            s"Table implementation does not support writes: ${ident.quoted}")
-      }
+      val createdTable = catalog.createTable(ident, schema, partitioning.toArray, properties.asJava)
+      val writeBuilder = V2WriteStrategy.newWriteBuilder(createdTable, writeOptions, schema)
+      writeWithV1Fallback(writeBuilder)
     })(catchBlock = {
       catalog.dropTable(ident)
     })
@@ -120,7 +104,7 @@ case class AtomicCreateTableAsSelectExec(
     plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
-    writeOptions: CaseInsensitiveStringMap,
+    writeOptions: Map[String, String],
     ifNotExists: Boolean) extends AtomicTableWriteExec {

   override protected def doExecute(): RDD[InternalRow] = {
@@ -154,10 +138,8 @@ case class ReplaceTableAsSelectExec(
     plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
-    writeOptions: CaseInsensitiveStringMap,
-    orCreate: Boolean) extends V2TableWriteExec with SupportsV1Write {
-
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+    writeOptions: Map[String, String],
+    orCreate: Boolean) extends V2TableWriteWithV1FallBack {

   override protected def doExecute(): RDD[InternalRow] = {
     // Note that this operation is potentially unsafe, but these are the strict semantics of
@@ -177,22 +159,8 @@ case class ReplaceTableAsSelectExec(
     val createdTable = catalog.createTable(
       ident, schema, partitioning.toArray, properties.asJava)
     Utils.tryWithSafeFinallyAndFailureCallbacks({
-      createdTable match {
-        case table: SupportsWrite =>
-          val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(schema)
-            .withQueryId(UUID.randomUUID().toString)
-
-          writeBuilder match {
-            case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
-            case v2 => writeWithV2(v2.buildForBatch())
-          }
-
-        case _ =>
-          // table does not support writes
-          throw new SparkException(
-            s"Table implementation does not support writes: ${ident.quoted}")
-      }
+      val writeBuilder = V2WriteStrategy.newWriteBuilder(createdTable, writeOptions, schema)
+      writeWithV1Fallback(writeBuilder)
     })(catchBlock = {
       catalog.dropTable(ident)
     })
@@ -218,7 +186,7 @@ case class AtomicReplaceTableAsSelectExec(
     plan: LogicalPlan,
     query: SparkPlan,
     properties: Map[String, String],
-    writeOptions: CaseInsensitiveStringMap,
+    writeOptions: Map[String, String],
     orCreate: Boolean) extends AtomicTableWriteExec {

   override protected def doExecute(): RDD[InternalRow] = {
@@ -246,13 +214,10 @@ case class AtomicReplaceTableAsSelectExec(
  *
  * Rows in the output data set are appended.
  */
-case class AppendDataExec(
-    table: SupportsWrite,
-    writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+case class AppendDataExec(write: BatchWrite, query: SparkPlan) extends V2TableWriteExec {

   override protected def doExecute(): RDD[InternalRow] = {
-    writeWithV2(newWriteBuilder().buildForBatch())
+    writeWithV2(write)
   }
 }
@@ -267,26 +232,12 @@ case class AppendDataExec(
  * AlwaysTrue to delete all rows.
  */
 case class OverwriteByExpressionExec(
-    table: SupportsWrite,
-    deleteWhere: Array[Filter],
-    writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
-
-  private def isTruncate(filters: Array[Filter]): Boolean = {
-    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
-  }
+    write: BatchWrite,
+    deleteWhere: Seq[Filter],
+    query: SparkPlan) extends V2TableWriteExec {

   override protected def doExecute(): RDD[InternalRow] = {
-    newWriteBuilder() match {
-      case builder: SupportsTruncate if isTruncate(deleteWhere) =>
-        writeWithV2(builder.truncate().buildForBatch())
-
-      case builder: SupportsOverwrite =>
-        writeWithV2(builder.overwrite(deleteWhere).buildForBatch())
-
-      case _ =>
-        throw new SparkException(s"Table does not support overwrite by expression: $table")
-    }
+    writeWithV2(write)
   }
 }
@@ -300,18 +251,11 @@ case class OverwriteByExpressionExec(
  * are not modified.
  */
 case class OverwritePartitionsDynamicExec(
-    table: SupportsWrite,
-    writeOptions: CaseInsensitiveStringMap,
-    query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+    write: BatchWrite,
+    query: SparkPlan) extends V2TableWriteExec {

   override protected def doExecute(): RDD[InternalRow] = {
-    newWriteBuilder() match {
-      case builder: SupportsDynamicOverwrite =>
-        writeWithV2(builder.overwriteDynamicPartitions().buildForBatch())
-
-      case _ =>
-        throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
-    }
+    writeWithV2(write)
   }
 }
@@ -319,8 +263,6 @@ case class WriteToDataSourceV2Exec(
     batchWrite: BatchWrite,
     query: SparkPlan) extends V2TableWriteExec {

-  def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()
-
   override protected def doExecute(): RDD[InternalRow] = {
     writeWithV2(batchWrite)
   }
@@ -410,6 +352,16 @@ trait V2TableWriteExec extends UnaryExecNode {
   }
 }

+trait V2TableWriteWithV1FallBack extends V2TableWriteExec with SupportsV1Write {
+
+  protected def writeWithV1Fallback(builder: WriteBuilder): RDD[InternalRow] = {
+    builder match {
+      case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
+      case v2 => writeWithV2(v2.buildForBatch())
+    }
+  }
+}
+
 object DataWritingSparkTask extends Logging {
   def run(
       writerFactory: DataWriterFactory,
@@ -468,32 +420,16 @@ object DataWritingSparkTask extends Logging {
   }
 }

-private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1Write {
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
-
+private[v2] trait AtomicTableWriteExec extends V2TableWriteWithV1FallBack {
   protected def writeToStagedTable(
       stagedTable: StagedTable,
-      writeOptions: CaseInsensitiveStringMap,
+      writeOptions: Map[String, String],
       ident: Identifier): RDD[InternalRow] = {
     Utils.tryWithSafeFinallyAndFailureCallbacks({
-      stagedTable match {
-        case table: SupportsWrite =>
-          val writeBuilder = table.newWriteBuilder(writeOptions)
-            .withInputDataSchema(query.schema)
-            .withQueryId(UUID.randomUUID().toString)
-
-          val writtenRows = writeBuilder match {
-            case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write())
-            case v2 => writeWithV2(v2.buildForBatch())
-          }
-          stagedTable.commitStagedChanges()
-          writtenRows
-
-        case _ =>
-          // Table does not support writes - staged changes are also rolled back below.
-          throw new SparkException(
-            s"Table implementation does not support writes: ${ident.quoted}")
-      }
+      val writeBuilder = V2WriteStrategy.newWriteBuilder(stagedTable, writeOptions, query.schema)
+      val res = writeWithV1Fallback(writeBuilder)
+      stagedTable.commitStagedChanges()
+      res
     })(catchBlock = {
       // Failure rolls back the staged writes and metadata changes.
       stagedTable.abortStagedChanges()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index ce6d56cf84df..39f4085a9baf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -98,19 +98,16 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }

   test("AppendData: check correct capabilities") {
-    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
-      val plan = AppendData.byName(
-        DataSourceV2Relation.create(CapabilityTable(write), CaseInsensitiveStringMap.empty),
-        TestRelation)
+    val plan = AppendData.byName(
+      DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty),
+      TestRelation)

-      TableCapabilityCheck.apply(plan)
-    }
+    TableCapabilityCheck.apply(plan)
   }

   test("Truncate: check missing capabilities") {
     Seq(CapabilityTable(),
       CapabilityTable(BATCH_WRITE),
-      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(TRUNCATE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
@@ -128,9 +125,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

   test("Truncate: check correct capabilities") {
     Seq(CapabilityTable(BATCH_WRITE, TRUNCATE),
-      CapabilityTable(V1_BATCH_WRITE, TRUNCATE),
-      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER),
-      CapabilityTable(V1_BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
+      CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
       val plan = OverwriteByExpression.byName(
         DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty),
         TestRelation,
@@ -142,7 +137,6 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {

   test("OverwriteByExpression: check missing capabilities") {
     Seq(CapabilityTable(),
-      CapabilityTable(V1_BATCH_WRITE),
       CapabilityTable(BATCH_WRITE),
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
@@ -159,14 +153,12 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   }

   test("OverwriteByExpression: check correct capabilities") {
-    Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
-      val table = CapabilityTable(write, OVERWRITE_BY_FILTER)
-      val plan = OverwriteByExpression.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
-        EqualTo(AttributeReference("x", LongType)(), Literal(5)))
+    val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)
+    val plan = OverwriteByExpression.byName(
+      DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+      EqualTo(AttributeReference("x", LongType)(), Literal(5)))

-      TableCapabilityCheck.apply(plan)
-    }
+    TableCapabilityCheck.apply(plan)
   }

   test("OverwritePartitionsDynamic: check missing capabilities") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index de843ba4375d..93ef0df3e325 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -257,7 +257,6 @@ class InMemoryTableWithV1Fallback(
   }

   override def capabilities: util.Set[TableCapability] = Set(
-    TableCapability.V1_BATCH_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.TRUNCATE).asJava