diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 19ab5ada2b5c4..6deaaceb5ca8a 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -3674,7 +3674,7 @@ }, "_LEGACY_ERROR_TEMP_2054" : { "message" : [ - "Task failed while writing rows." + "Task failed while writing rows. <message>" ] }, "_LEGACY_ERROR_TEMP_2055" : { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 15dfa581c5976..cef4acafe07c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -785,7 +785,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { def taskFailedWhileWritingRowsError(cause: Throwable): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2054", - messageParameters = Map.empty, + messageParameters = Map("message" -> cause.getMessage), cause = cause) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/internal/WriteSpec.java b/sql/core/src/main/java/org/apache/spark/sql/internal/WriteSpec.java new file mode 100644 index 0000000000000..c51a3ed7dc6b2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/internal/WriteSpec.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal; + +import java.io.Serializable; + +/** + * Write spec is an input parameter of + * {@link org.apache.spark.sql.execution.SparkPlan#executeWrite}. + * + * <p>

+ * This is an empty interface; the concrete class that implements + * {@link org.apache.spark.sql.execution.SparkPlan#doExecuteWrite} + * should define its own spec class and use it. + * + * @since 3.4.0 + */ +public interface WriteSpec extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4aca67a17cdeb..401302e5bdea2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -34,9 +34,10 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike} +import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SQLConf, WriteSpec} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -223,6 +224,19 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ doExecuteColumnar() } + /** + * Returns the result of the writes as an RDD[WriterCommitMessage] by delegating to + * `doExecuteWrite` after preparations. + * + * Concrete implementations of SparkPlan should override `doExecuteWrite`. + */ + def executeWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = executeQuery { + if (isCanonicalizedPlan) { + throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") + } + doExecuteWrite(writeSpec) + } + /** * Executes a query after preparing the query and adding query plan information to created RDDs * for visualization. @@ -324,6 +338,16 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ s" mismatch:\n${this}") } + /** + * Produces the result of the writes as an `RDD[WriterCommitMessage]`. + * + * Overridden by concrete implementations of SparkPlan. + */ + protected def doExecuteWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = { + throw SparkException.internalError(s"Internal Error ${this.getClass} has write support" + + s" mismatch:\n${this}") + } + /** * Converts the output of this plan to row-based if it is columnar plan.
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b96e47846fc93..51a0c837c3e94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.aggregate.AggUtils import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.{WriteFiles, WriteFilesExec} import org.apache.spark.sql.execution.exchange.{REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ @@ -894,6 +895,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE") case logical.CollectMetrics(name, metrics, child) => execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil + case WriteFiles(child) => + WriteFilesExec(planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 95e1a159ef84f..9bf9f43829e29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -145,6 +145,12 @@ case class CreateDataSourceTableAsSelectCommand( outputColumnNames: Seq[String]) extends V1WriteCommand { + override def fileFormatProvider: Boolean = { + table.provider.forall { provider => + classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, conf)) + } + } + override lazy val partitionColumns: Seq[Attribute] = { val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted) DataSource.resolvePartitionColumns( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b2bc7301ade87..3d8eb9bc8a8bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -97,19 +97,8 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) - lazy val providingClass: Class[_] = { - val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf) - // `providingClass` is used for resolving data source relation for catalog tables. - // As now catalog for data source V2 is under development, here we fall back all the - // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works. - // [[FileDataSourceV2]] will still be used if we call the load()/save() method in - // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` - // instead of `providingClass`. 
- cls.newInstance() match { - case f: FileDataSourceV2 => f.fallbackFileFormat - case _ => cls - } - } + lazy val providingClass: Class[_] = + DataSource.providingClass(className, sparkSession.sessionState.conf) private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance() @@ -843,4 +832,18 @@ object DataSource extends Logging { } } } + + def providingClass(className: String, conf: SQLConf): Class[_] = { + val cls = DataSource.lookupDataSource(className, conf) + // `providingClass` is used for resolving data source relation for catalog tables. + // As now catalog for data source V2 is under development, here we fall back all the + // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works. + // [[FileDataSourceV2]] will still be used if we call the load()/save() method in + // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` + // instead of `providingClass`. + cls.newInstance() match { + case f: FileDataSourceV2 => f.fallbackFileFormat + case _ => cls + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 43b18a3b2d10c..2e082f3febf7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.connector.write.WriterCommitMessage import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf @@ -103,14 +104,6 @@ object FileFormatWriter extends Logging { .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation)) val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains) - val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions)) - val empty2NullPlan = if (hasEmpty2Null) { - plan - } else { - val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns) - if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan - } - val writerBucketSpec = V1WritesUtils.getWriterBucketSpec(bucketSpec, dataColumns, options) val sortColumns = V1WritesUtils.getBucketSortColumns(bucketSpec, dataColumns) @@ -144,9 +137,10 @@ object FileFormatWriter extends Logging { // columns. val requiredOrdering = partitionColumns.drop(numStaticPartitionCols) ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns + val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan) // the sort order doesn't matter // Use the output ordering from the original plan before adding the empty2null projection. - val actualOrdering = plan.outputOrdering.map(_.child) + val actualOrdering = writeFilesOpt.map(_.child).getOrElse(plan).outputOrdering.map(_.child) val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering) SQLExecution.checkSQLExecutionId(sparkSession) @@ -155,10 +149,6 @@ object FileFormatWriter extends Logging { // get an ID guaranteed to be unique. 
job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - committer.setupJob(job) - // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort // operator based on the required ordering of the V1 write command. So the output // ordering of the physical plan should always match the required ordering. Here @@ -169,27 +159,55 @@ object FileFormatWriter extends Logging { // V1 write command will be empty). if (Utils.isTesting) outputOrderingMatched = orderingMatched - try { + if (writeFilesOpt.isDefined) { + // build `WriteFilesSpec` for `WriteFiles` + val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => { + val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec) + createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns) + } + val writeSpec = WriteFilesSpec( + description = description, + committer = committer, + concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc + ) + executeWrite(sparkSession, plan, writeSpec, job) + } else { + executeWrite(sparkSession, plan, job, description, committer, outputSpec, + requiredOrdering, partitionColumns, sortColumns, orderingMatched) + } + } + // scalastyle:on argcount + + private def executeWrite( + sparkSession: SparkSession, + plan: SparkPlan, + job: Job, + description: WriteJobDescription, + committer: FileCommitProtocol, + outputSpec: OutputSpec, + requiredOrdering: Seq[Expression], + partitionColumns: Seq[Attribute], + sortColumns: Seq[Attribute], + orderingMatched: Boolean): Set[String] = { + val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions)) + val empty2NullPlan = if (hasEmpty2Null) { + plan + } else { + val projectList = V1WritesUtils.convertEmptyToNull(plan.output, partitionColumns) + if (projectList.nonEmpty) ProjectExec(projectList, plan) else plan + } + + writeAndCommit(job, description, committer) { val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { (empty2NullPlan.execute(), None) } else { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
- val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) - val sortPlan = SortExec( - orderingExpr, - global = false, - child = empty2NullPlan) - - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters - val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty - if (concurrentWritersEnabled) { - (empty2NullPlan.execute(), - Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) + val sortPlan = createSortPlan(empty2NullPlan, requiredOrdering, outputSpec) + val concurrentOutputWriterSpec = createConcurrentOutputWriterSpec( + sparkSession, sortPlan, sortColumns) + if (concurrentOutputWriterSpec.isDefined) { + (empty2NullPlan.execute(), concurrentOutputWriterSpec) } else { - (sortPlan.execute(), None) + (sortPlan.execute(), concurrentOutputWriterSpec) } } @@ -221,7 +239,19 @@ object FileFormatWriter extends Logging { committer.onTaskCommit(res.commitMsg) ret(index) = res }) + ret + } + } + private def writeAndCommit( + job: Job, + description: WriteJobDescription, + committer: FileCommitProtocol)(f: => Array[WriteTaskResult]): Set[String] = { + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + committer.setupJob(job) + try { + val ret = f val commitMsgs = ret.map(_.commitMsg) logInfo(s"Start to commit write Job ${description.uuid}.") @@ -239,10 +269,70 @@ object FileFormatWriter extends Logging { throw cause } } - // scalastyle:on argcount + + /** + * Write files using [[SparkPlan.executeWrite]] + */ + private def executeWrite( + session: SparkSession, + planForWrites: SparkPlan, + writeFilesSpec: WriteFilesSpec, + job: Job): Set[String] = { + val committer = writeFilesSpec.committer + val description = writeFilesSpec.description + + writeAndCommit(job, description, committer) { + val rdd = planForWrites.executeWrite(writeFilesSpec) + val ret = new Array[WriteTaskResult](rdd.partitions.length) + session.sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[WriterCommitMessage]) => { + assert(iter.hasNext) + val commitMessage = iter.next() + assert(!iter.hasNext) + commitMessage + }, + rdd.partitions.indices, + (index, res: WriterCommitMessage) => { + assert(res.isInstanceOf[WriteTaskResult]) + val writeTaskResult = res.asInstanceOf[WriteTaskResult] + committer.onTaskCommit(writeTaskResult.commitMsg) + ret(index) = writeTaskResult + }) + ret + } + } + + private def createSortPlan( + plan: SparkPlan, + requiredOrdering: Seq[Expression], + outputSpec: OutputSpec): SortExec = { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
+ val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) + SortExec( + orderingExpr, + global = false, + child = plan) + } + + private def createConcurrentOutputWriterSpec( + sparkSession: SparkSession, + sortPlan: SortExec, + sortColumns: Seq[Attribute]): Option[ConcurrentOutputWriterSpec] = { + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())) + } else { + None + } + } /** Writes data out in a single Spark task. */ - private def executeTask( + private[spark] def executeTask( description: WriteJobDescription, jobIdInstant: Long, sparkStageId: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala index d082b95739cba..e9f6e3df7853a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala @@ -24,12 +24,17 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String trait V1WriteCommand extends DataWritingCommand { + /** + * Returns whether the provider is a [[FileFormat]]. + */ + def fileFormatProvider: Boolean = true /** * Specify the partition columns of the V1 write command. @@ -44,7 +49,7 @@ trait V1WriteCommand extends DataWritingCommand { } /** - * A rule that adds logical sorts to V1 data writing commands. + * A rule that plans the V1 write for [[V1WriteCommand]].
*/ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper { @@ -52,11 +57,13 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper { override def apply(plan: LogicalPlan): LogicalPlan = { if (conf.plannedWriteEnabled) { - plan.transformDown { - case write: V1WriteCommand => + plan.transformUp { + case write: V1WriteCommand if write.fileFormatProvider && + !write.child.isInstanceOf[WriteFiles] => val newQuery = prepareQuery(write, write.query) val attrMap = AttributeMap(write.query.output.zip(newQuery.output)) - val newWrite = write.withNewChildren(newQuery :: Nil).transformExpressions { + val newChild = WriteFiles(newQuery) + val newWrite = write.withNewChildren(newChild :: Nil).transformExpressions { case a: Attribute if attrMap.contains(a) => a.withExprId(attrMap(a).exprId) } @@ -212,4 +219,10 @@ object V1WritesUtils { } } } + + def getWriteFilesOpt(child: SparkPlan): Option[WriteFilesExec] = { + child.collectFirst { + case w: WriteFilesExec => w + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala new file mode 100644 index 0000000000000..39b7b252f6ea8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteFiles.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.Date + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.internal.WriteSpec + +/** + * The write files spec holds all the write information of a [[V1WriteCommand]] whose provider is + * [[FileFormat]]. + */ +case class WriteFilesSpec( + description: WriteJobDescription, + committer: FileCommitProtocol, + concurrentOutputWriterSpecFunc: SparkPlan => Option[ConcurrentOutputWriterSpec]) + extends WriteSpec + +/** + * During optimization, [[V1Writes]] injects a [[WriteFiles]] node between the [[V1WriteCommand]] + * and its query. [[WriteFiles]] must be the root of the plan that is the child of the + * [[V1WriteCommand]].
+ */ +case class WriteFiles(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: LogicalPlan): WriteFiles = + copy(child = newChild) +} + +/** + * Responsible for writing files. + */ +case class WriteFilesExec(child: SparkPlan) extends UnaryExecNode { + override def output: Seq[Attribute] = Seq.empty + + override protected def doExecuteWrite(writeSpec: WriteSpec): RDD[WriterCommitMessage] = { + assert(writeSpec.isInstanceOf[WriteFilesSpec]) + val writeFilesSpec: WriteFilesSpec = writeSpec.asInstanceOf[WriteFilesSpec] + + val rdd = child.execute() + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. + val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + session.sparkContext.parallelize(Array.empty[InternalRow], 1) + } else { + rdd + } + + val concurrentOutputWriterSpec = writeFilesSpec.concurrentOutputWriterSpecFunc(child) + val description = writeFilesSpec.description + val committer = writeFilesSpec.committer + val jobIdInstant = new Date().getTime + rddWithNonEmptyPartitions.mapPartitionsInternal { iterator => + val sparkStageId = TaskContext.get().stageId() + val sparkPartitionId = TaskContext.get().partitionId() + val sparkAttemptNumber = TaskContext.get().taskAttemptId().toInt & Int.MaxValue + + val ret = FileFormatWriter.executeTask( + description, + jobIdInstant, + sparkStageId, + sparkPartitionId, + sparkAttemptNumber, + committer, + iterator, + concurrentOutputWriterSpec + ) + + Iterator(ret) + } + } + + override protected def doExecute(): RDD[InternalRow] = { + throw SparkException.internalError(s"$nodeName does not support doExecute") + } + + override protected def stringArgs: Iterator[Any] = Iterator(child) + + override protected def withNewChildInternal(newChild: SparkPlan): WriteFilesExec = + copy(child = newChild) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala index d66f2bd0cc423..eb2aa09e075bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala @@ -65,7 +65,12 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils { override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { qe.optimizedPlan match { case w: V1WriteCommand => - optimizedPlan = w.query + if (hasLogicalSort) { + assert(w.query.isInstanceOf[WriteFiles]) + optimizedPlan = w.query.asInstanceOf[WriteFiles].child + } else { + optimizedPlan = w.query + } case _ => } }
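Note on the new execution hook: SparkPlan.executeWrite above follows the same contract as execute/executeColumnar — a public entry point that runs preparations and rejects canonicalized plans, then delegates to a protected doExecuteWrite that only write-capable nodes override, with an opaque WriteSpec carrying the operator-specific parameters (WriteFilesSpec for WriteFilesExec). The stand-alone Scala sketch below mirrors that pattern outside Spark; every name in it (SimplePlanNode, LocalPartitions, FileWriteSpec, ...) is invented for illustration and is not part of the patch.

import java.io.Serializable

// Marker type mirroring org.apache.spark.sql.internal.WriteSpec: it carries no members;
// each write-capable operator defines and expects its own concrete spec.
trait SimpleWriteSpec extends Serializable

// Stand-ins for WriterCommitMessage and RDD so the sketch compiles without Spark.
trait CommitMessage extends Serializable
final case class LocalPartitions[T](partitions: Seq[Seq[T]])

abstract class SimplePlanNode {
  def isCanonicalizedPlan: Boolean = false

  // Public entry point: fail fast on canonicalized plans, then delegate,
  // mirroring the executeWrite/doExecuteWrite split in SparkPlan.
  final def executeWrite(spec: SimpleWriteSpec): LocalPartitions[CommitMessage] = {
    if (isCanonicalizedPlan) {
      throw new IllegalStateException("A canonicalized plan is not supposed to be executed.")
    }
    doExecuteWrite(spec)
  }

  // Default implementation fails; only write-capable nodes override it.
  protected def doExecuteWrite(spec: SimpleWriteSpec): LocalPartitions[CommitMessage] =
    throw new UnsupportedOperationException(s"${getClass.getSimpleName} does not support writes")
}

final case class FileCommit(path: String) extends CommitMessage
final case class FileWriteSpec(outputPath: String) extends SimpleWriteSpec

// A write-capable node: it checks the spec type (as WriteFilesExec does with WriteFilesSpec)
// and emits exactly one commit message per input partition.
final class WriteFilesNode(input: LocalPartitions[String]) extends SimplePlanNode {
  override protected def doExecuteWrite(spec: SimpleWriteSpec): LocalPartitions[CommitMessage] = {
    val fileSpec = spec match {
      case s: FileWriteSpec => s
      case other => throw new IllegalArgumentException(s"unexpected write spec: $other")
    }
    LocalPartitions(input.partitions.indices.map { i =>
      Seq(FileCommit(s"${fileSpec.outputPath}/part-$i"))
    })
  }
}

object WriteSpecSketch {
  def main(args: Array[String]): Unit = {
    val rows = LocalPartitions(Seq(Seq("a", "b"), Seq("c")))
    val commits = new WriteFilesNode(rows).executeWrite(FileWriteSpec("/tmp/out"))
    // Prints one FileCommit per input partition.
    println(commits.partitions.flatten.mkString(", "))
  }
}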
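Note on the FileFormatWriter refactor: the patch factors the driver-side setupJob/commitJob/abortJob handling into a shared writeAndCommit helper so that both the legacy path (over plan.execute()) and the new planned path (over plan.executeWrite(writeFilesSpec)) reuse the same commit and abort logic, with committer.setupJob(job) kept outside the try block so a setup failure never triggers abortJob. A simplified, Spark-free sketch of that control flow, using invented names (SimpleCommitter, TaskResult), is:

trait TaskResult
final case class SimpleTaskResult(file: String) extends TaskResult

// Invented stand-in for FileCommitProtocol; only the calls used below exist here.
class SimpleCommitter {
  def setupJob(): Unit = println("setup job")
  def commitJob(results: Seq[TaskResult]): Unit = println(s"commit job with ${results.size} task results")
  def abortJob(): Unit = println("abort job")
}

object WriteAndCommitSketch {
  // Mirrors the shape of the new writeAndCommit helper: job setup happens before the try
  // block, the task work is a by-name block supplied by the caller, and commit/abort bracket it.
  def writeAndCommit(committer: SimpleCommitter)(runTasks: => Seq[TaskResult]): Seq[TaskResult] = {
    committer.setupJob()
    try {
      val results = runTasks
      committer.commitJob(results)
      results
    } catch {
      case t: Throwable =>
        committer.abortJob()
        throw t
    }
  }

  def main(args: Array[String]): Unit = {
    val results = writeAndCommit(new SimpleCommitter()) {
      // In the patch this block is the Spark job that produces one WriteTaskResult per
      // partition, over either plan.execute() or plan.executeWrite(writeFilesSpec).
      Seq(SimpleTaskResult("part-0"), SimpleTaskResult("part-1"))
    }
    println(results.mkString(", "))
  }
}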