diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 5e9e6ff1a569..cb3ca1b6c4be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -41,17 +41,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString - } - private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { @@ -135,11 +129,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour dataSchema: StructType): OutputWriterFactory = { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new LibSVMOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".libsvm" + TextOutputWriter.getCompressionExtension(context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala new file mode 100644 index 000000000000..322504cdc9a5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.SparkHadoopWriter +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + + +object FileCommitProtocol { + class TaskCommitMessage(obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(Unit) + + /** + * Instantiates a FileCommitProtocol using the given className. 
+   */
+  def instantiate(className: String, outputPath: String, isAppend: Boolean): FileCommitProtocol = {
+    try {
+      val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]
+
+      // First try the one with argument (outputPath: String, isAppend: Boolean).
+      // If that doesn't exist, try the one with (outputPath: String).
+      try {
+        val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[Boolean])
+        ctor.newInstance(outputPath, isAppend.asInstanceOf[java.lang.Boolean])
+      } catch {
+        case _: NoSuchMethodException =>
+          val ctor = clazz.getDeclaredConstructor(classOf[String])
+          ctor.newInstance(outputPath)
+      }
+    } catch {
+      case e: ClassNotFoundException =>
+        throw e
+    }
+  }
+}
+
+
+/**
+ * An interface to define how a Spark job commits its outputs. Implementations must be
+ * serializable.
+ *
+ * The proper call sequence is:
+ *
+ * 1. Driver calls setupJob.
+ * 2. As part of each task's execution, executor calls setupTask and then commitTask
+ *    (or abortTask if the task failed).
+ * 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the
+ *    job failed to execute (e.g. too many failed tasks), the driver should call abortJob.
+ */
+abstract class FileCommitProtocol {
+  import FileCommitProtocol._
+
+  /**
+   * Sets up a job. Must be called on the driver before any other methods can be invoked.
+   */
+  def setupJob(jobContext: JobContext): Unit
+
+  /**
+   * Commits a job after the writes succeed. Must be called on the driver.
+   */
+  def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit
+
+  /**
+   * Aborts a job after the writes fail. Must be called on the driver.
+   *
+   * Calling this function is a best-effort attempt, because it is possible that the driver
+   * just crashes (or is killed) before it can call abort.
+   */
+  def abortJob(jobContext: JobContext): Unit
+
+  /**
+   * Sets up a task within a job.
+   * Must be called before any other task-related methods can be invoked.
+   */
+  def setupTask(taskContext: TaskAttemptContext): Unit
+
+  /**
+   * Notifies the commit protocol to add a new file, and gets back the full path that should be
+   * used. Must be called on the executors when running tasks.
+   *
+   * A full file path consists of the following parts:
+   *  1. the base path
+   *  2. some sub-directory within the base path, used to specify partitioning
+   *  3. file prefix, usually some unique job id with the task id
+   *  4. bucket id
+   *  5. source-specific file extension, e.g. ".snappy.parquet"
+   *
+   * The "dir" parameter specifies 2, the "ext" parameter specifies both 4 and 5, and the rest
+   * are left to the commit protocol implementation to decide.
+   */
+  def addTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
+
+  /**
+   * Commits a task after the writes succeed. Must be called on the executors when running tasks.
+   */
+  def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage
+
+  /**
+   * Aborts a task after the writes have failed. Must be called on the executors when running
+   * tasks.
+   */
+  def abortTask(taskContext: TaskAttemptContext): Unit
+}
+
+
+/**
+ * A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter
+ * (from the newer mapreduce API, not the old mapred API).
+ *
+ * Unlike Hadoop's OutputCommitter, this implementation is serializable.
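Because FileCommitProtocol.instantiate resolves the configured class reflectively, a pluggable implementation only has to be serializable and expose either a (String, Boolean) or a (String) constructor. As a rough sketch of the contract above (not part of this patch, and with a made-up class name), a protocol that skips staging entirely and hands back final paths could look like:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}

import org.apache.spark.sql.execution.datasources.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.FileCommitProtocol.{EmptyTaskCommitMessage, TaskCommitMessage}

// Hypothetical protocol with no staging: addTaskTempFile hands back the final path directly.
class DirectWriteCommitProtocol(path: String, isAppend: Boolean)
  extends FileCommitProtocol with Serializable {

  override def setupJob(jobContext: JobContext): Unit = {}
  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {}
  override def abortJob(jobContext: JobContext): Unit = {}
  override def setupTask(taskContext: TaskAttemptContext): Unit = {}

  override def addTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    val split = taskContext.getTaskAttemptID.getTaskID.getId
    val filename = f"part-$split%05d$ext"
    val parent = dir.map(d => new Path(path, d)).getOrElse(new Path(path))
    new Path(parent, filename).toString
  }

  override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage =
    EmptyTaskCommitMessage

  override def abortTask(taskContext: TaskAttemptContext): Unit = {}
}

Such a protocol would be selected through the new spark.sql.sources.commitProtocolClass key added to SQLConf further down in this patch, and is only sensible when task output is idempotent (no speculative or retried tasks racing on the same file).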
+ */ +class MapReduceFileCommitterProtocol(path: String, isAppend: Boolean) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** UUID used to identify the job in file name. */ + private val uuid: String = UUID.randomUUID().toString + + private def setupCommitter(context: TaskAttemptContext): Unit = { + committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + + if (!isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the appending job fails. + // See SPARK-8578 for more details. + val configuration = context.getConfiguration + val clazz = + configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + if (clazz != null) { + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(new Path(path), context) + } else { + // The specified output committer is just an OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + committer = ctor.newInstance() + } + } + } + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + } + + override def addTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val filename = f"part-$split%05d-$uuid$ext" + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". 
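// Worked example of the name built just above (illustrative values, not part of the patch):
val exampleSplit = 3                                       // task id within the job
val exampleUuid = "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb"   // the per-job UUID
val exampleExt = "_00002.gz.parquet"                       // bucket id string + source extension
val exampleName = f"part-$exampleSplit%05d-$exampleUuid$exampleExt"
// exampleName == "part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00002.gz.parquet"
// With dir = Some("year=2016/month=10"), the file is created under
// <staging dir>/year=2016/month=10/ and only reaches the final output location when the
// committer commits the job.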
+ case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + setupCommitter(taskAttemptContext) + + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + setupCommitter(taskContext) + committer.setupTask(taskContext) + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + EmptyTaskCommitMessage + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index fbf6e96d3f85..a73c8146c1b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -30,28 +30,21 @@ import org.apache.spark.sql.types.StructType * to executor side to create actual [[OutputWriter]]s on the fly. */ abstract class OutputWriterFactory extends Serializable { + + /** Returns the file extension to be used when writing files out. */ + def getFileExtension(context: TaskAttemptContext): String + /** * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side * to instantiate new [[OutputWriter]]s. * - * @param stagingDir Base path (directory) of the file to which this [[OutputWriter]] is supposed - * to write. Note that this may not point to the final output file. For - * example, `FileOutputFormat` writes to temporary directories and then merge - * written files back to the final destination. In this case, `path` points to - * a temporary output file under the temporary directory. - * @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this - * prefix is used in the actual file name. For example, if the prefix is - * "part-1-2-3", then the file name must start with "part_1_2_3" but can - * end in arbitrary extension that is deterministic given the configuration - * (i.e. the suffix extension should not depend on any task id, attempt id, - * or partition id). 
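Since the factory no longer hands writers a staging directory and file-name prefix, a data source only needs to open whatever path it is given and report its extension through getFileExtension. A minimal sketch of the new contract for a hypothetical format (the class name and the ".demo" suffix are invented for illustration):

import java.nio.charset.StandardCharsets

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

// Hypothetical single-file text-like writer following the new contract.
class DemoOutputWriterFactory extends OutputWriterFactory {
  override def getFileExtension(context: TaskAttemptContext): String = ".demo"

  override def newInstance(
      path: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter = new OutputWriter {
    // The path already includes directory, prefix and extension; just open it.
    private val hadoopPath = new Path(path)
    private val out = hadoopPath.getFileSystem(context.getConfiguration).create(hadoopPath, false)

    override def write(row: Row): Unit =
      out.write((row.mkString(",") + "\n").getBytes(StandardCharsets.UTF_8))

    override def close(): Unit = out.close()
  }
}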
+ * @param path Path to write the file. * @param dataSchema Schema of the rows to be written. Partition columns are not included in the * schema if the relation being written is partitioned. * @param context The Hadoop MapReduce task context. */ def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter @@ -77,13 +70,6 @@ abstract class OutputWriterFactory extends Serializable { * executor side. This instance is used to persist rows to this single output file. */ abstract class OutputWriter { - - /** - * The path of the file to be written out. This path should include the staging directory and - * the file name prefix passed into the associated createOutputWriter function. - */ - def path: String - /** * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala index bd56e511d0cc..9ffb20da070e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala @@ -22,12 +22,11 @@ import java.util.{Date, UUID} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ @@ -35,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter @@ -54,8 +53,7 @@ object WriteOutput extends Logging { val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], val isAppend: Boolean, - val path: String, - val outputFormatClass: Class[_ <: OutputFormat[_, _]]) + val path: String) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -111,31 +109,35 @@ object WriteOutput extends Logging { nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, isAppend = isAppend, - path = outputPath.toString, - outputFormatClass = job.getOutputFormatClass) + path = outputPath.toString) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- val committer = setupDriverCommitter(job, outputPath.toString, isAppend) + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + outputPath.toString, + isAppend) + committer.setupJob(job) try { - sparkSession.sparkContext.runJob(queryExecution.toRdd, + val commitMsgs = sparkSession.sparkContext.runJob(queryExecution.toRdd, (taskContext: TaskContext, iter: Iterator[InternalRow]) => { executeTask( description = description, sparkStageId = taskContext.stageId(), sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.attemptNumber(), + committer, iterator = iter) }) - committer.commitJob(job) + committer.commitJob(job, commitMsgs) logInfo(s"Job ${job.getJobID} committed.") refreshFunction() } catch { case cause: Throwable => logError(s"Aborting job ${job.getJobID}.", cause) - committer.abortJob(job, JobStatus.State.FAILED) + committer.abortJob(job) throw new SparkException("Job aborted.", cause) } } @@ -147,7 +149,8 @@ object WriteOutput extends Logging { sparkStageId: Int, sparkPartitionId: Int, sparkAttemptNumber: Int, - iterator: Iterator[InternalRow]): Unit = { + committer: FileCommitProtocol, + iterator: Iterator[InternalRow]): TaskCommitMessage = { val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -166,32 +169,21 @@ object WriteOutput extends Logging { new TaskAttemptContextImpl(hadoopConf, taskAttemptId) } - val committer = newOutputCommitter( - description.outputFormatClass, taskAttemptContext, description.path, description.isAppend) committer.setupTask(taskAttemptContext) - // Figure out where we need to write data to for staging. - // For FileOutputCommitter it has its own staging path called "work path". - val stagingPath = committer match { - case f: FileOutputCommitter => f.getWorkPath.toString - case _ => description.path - } - val writeTask = if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { - new SingleDirectoryWriteTask(description, taskAttemptContext, stagingPath) + new SingleDirectoryWriteTask(description, taskAttemptContext, committer) } else { - new DynamicPartitionWriteTask(description, taskAttemptContext, stagingPath) + new DynamicPartitionWriteTask(description, taskAttemptContext, committer) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { - // Execute the task to write rows out + // Execute the task to write rows out and commit the task. 
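Condensed to its essentials, the executor side of the protocol is: setupTask, ask the committer where to write, write, then commitTask on success or abortTask on failure. A standalone sketch of that shape (the helper object and method below are illustrative, not code from this patch):

import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.FileCommitProtocol
import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage

object WriteTaskSketch {
  // Illustrative helper: ask the committer for a path, run the write, then commit or abort.
  def runTask(
      committer: FileCommitProtocol,
      taskContext: TaskAttemptContext,
      ext: String)(write: String => Unit): TaskCommitMessage = {
    committer.setupTask(taskContext)
    try {
      val path = committer.addTaskTempFile(taskContext, dir = None, ext = ext)
      write(path)
      committer.commitTask(taskContext)
    } catch {
      case t: Throwable =>
        committer.abortTask(taskContext)
        throw t
    }
  }
}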
writeTask.execute(iterator) writeTask.releaseResources() - - // Commit the task - SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId) + committer.commitTask(taskAttemptContext) })(catchBlock = { // If there is an error, release resource and then abort the task try { @@ -218,7 +210,7 @@ object WriteOutput extends Logging { final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = { val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") - f"part-r-$split%05d-$uuid$bucketString" + f"part-$split%05d-$uuid$bucketString" } } @@ -226,14 +218,16 @@ object WriteOutput extends Logging { private class SingleDirectoryWriteTask( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - stagingPath: String) extends ExecuteWriteTask { + committer: FileCommitProtocol) extends ExecuteWriteTask { private[this] var outputWriter: OutputWriter = { - val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + val tmpFilePath = committer.addTaskTempFile( + taskAttemptContext, + None, + description.outputWriterFactory.getFileExtension(taskAttemptContext)) val outputWriter = description.outputWriterFactory.newInstance( - stagingDir = stagingPath, - fileNamePrefix = filePrefix(split, description.uuid, None), + path = tmpFilePath, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) @@ -262,7 +256,7 @@ object WriteOutput extends Logging { private class DynamicPartitionWriteTask( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - stagingPath: String) extends ExecuteWriteTask { + committer: FileCommitProtocol) extends ExecuteWriteTask { // currentWriter is initialized whenever we see a new key private var currentWriter: OutputWriter = _ @@ -302,25 +296,20 @@ object WriteOutput extends Logging { * file extension, e.g. 
part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet */ private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = { - val path = - if (description.partitionColumns.nonEmpty) { - val partitionPath = partString(key).getString(0) - new Path(stagingPath, partitionPath).toString - } else { - stagingPath - } + val partDir = + if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) // If the bucket spec is defined, the bucket column is right after the partition columns val bucketId = if (description.bucketSpec.isDefined) { - Some(key.getInt(description.partitionColumns.length)) + BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length)) } else { - None + "" } + val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId + val path = committer.addTaskTempFile(taskAttemptContext, partDir, ext) val newWriter = description.outputWriterFactory.newInstance( - stagingDir = path, - fileNamePrefix = filePrefix(split, description.uuid, bucketId), + path = path, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) newWriter.initConverter(description.nonPartitionColumns.toStructType) @@ -402,75 +391,4 @@ object WriteOutput extends Logging { } } } - - private def setupDriverCommitter(job: Job, path: String, isAppend: Boolean): OutputCommitter = { - // Setup IDs - val jobId = SparkHadoopWriter.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - job.getConfiguration.set("mapred.job.id", jobId.toString) - job.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - job.getConfiguration.set("mapred.task.id", taskAttemptId.toString) - job.getConfiguration.setBoolean("mapred.task.is.map", true) - job.getConfiguration.setInt("mapred.task.partition", 0) - - val taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId) - val outputCommitter = newOutputCommitter( - job.getOutputFormatClass, taskAttemptContext, path, isAppend) - outputCommitter.setupJob(job) - outputCommitter - } - - private def newOutputCommitter( - outputFormatClass: Class[_ <: OutputFormat[_, _]], - context: TaskAttemptContext, - path: String, - isAppend: Boolean): OutputCommitter = { - val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - - if (isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the appending job fails. - // See SPARK-8578 for more details - logInfo( - s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + - "for appending.") - defaultOutputCommitter - } else { - val configuration = context.getConfiguration - val clazz = - configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - if (clazz != null) { - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. 
To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(path), context) - } else { - // The specified output committer is just an OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() - } - } else { - // If output committer class is not set, we will use the one associated with the - // file output format. - logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") - defaultOutputCommitter - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index a35cfdb2c234..a249b9d9d59b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -171,26 +171,23 @@ object CSVRelation extends Logging { private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new CsvOutputWriter(stagingDir, fileNamePrefix, dataSchema, context, params) + new CsvOutputWriter(path, dataSchema, context, params) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".csv" + TextOutputWriter.getCompressionExtension(context) } } private[csv] class CsvOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".csv" + compressionExtension).toString - } - // create the Generator without separator inserted between 2 records private[this] val text = new Text() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 651fa78a4e92..5a409c04c929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -83,11 +83,14 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(stagingDir, parsedOptions, fileNamePrefix, dataSchema, context) + new JsonOutputWriter(path, parsedOptions, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".json" + 
TextOutputWriter.getCompressionExtension(context) } } } @@ -154,18 +157,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { } private[json] class JsonOutputWriter( - stagingDir: String, + path: String, options: JSONOptions, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter with Logging { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".json" + compressionExtension).toString - } - private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records private[this] val gen = new JacksonGenerator(dataSchema, writer, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 502dd0e8d4cf..77c83ba38efe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -33,6 +33,7 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ +import org.apache.parquet.hadoop.codec.CodecConfig import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler @@ -133,10 +134,13 @@ class ParquetFileFormat new OutputWriterFactory { override def newInstance( path: String, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, fileNamePrefix, context) + new ParquetOutputWriter(path, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 1300069c42b0..92d4f27be3fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -89,7 +89,7 @@ private[parquet] class ParquetOutputWriterFactory( * Returns a [[OutputWriter]] that writes data to the give path without using * [[OutputCommitter]]. */ - override def newWriter(path1: String): OutputWriter = new OutputWriter { + override def newWriter(path: String): OutputWriter = new OutputWriter { // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) @@ -99,8 +99,6 @@ private[parquet] class ParquetOutputWriterFactory( // Instance of ParquetRecordWriter that does not use OutputCommitter private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) - override def path: String = path1 - override def write(row: Row): Unit = { throw new UnsupportedOperationException("call writeInternal") } @@ -127,27 +125,22 @@ private[parquet] class ParquetOutputWriterFactory( /** Disable the use of the older API. 
*/ override def newInstance( path: String, - fileNamePrefix: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { throw new UnsupportedOperationException("this version of newInstance not supported for " + "ParquetOutputWriterFactory") } + + override def getFileExtension(context: TaskAttemptContext): String = { + CodecConfig.from(context).getCodec.getExtension + ".parquet" + } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[parquet] class ParquetOutputWriter( - stagingDir: String, - fileNamePrefix: String, - context: TaskAttemptContext) +private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val filename = fileNamePrefix + CodecConfig.from(context).getCodec.getExtension + ".parquet" - new Path(stagingDir, filename).toString - } - private val recordWriter: RecordWriter[Void, InternalRow] = { new ParquetOutputFormat[InternalRow]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index d40b5725199a..8e043960326d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -75,11 +75,14 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new TextOutputWriter(stagingDir, fileNamePrefix, dataSchema, context) + new TextOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + ".txt" + TextOutputWriter.getCompressionExtension(context) } } } @@ -124,17 +127,11 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } class TextOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension = TextOutputWriter.getCompressionExtension(context) - new Path(stagingDir, fileNamePrefix + ".txt" + compressionExtension).toString - } - private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc31f3bc323f..00de7df78254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.execution.datasources.MapReduceFileCommitterProtocol import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -240,9 +241,8 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.parquet.output.committer.class") .doc("The 
output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + - "option must be set in Hadoop Configuration. 2. This option overrides " + - "\"spark.sql.sources.outputCommitterClass\".") + "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + .internal() .stringConf .createWithDefault(classOf[ParquetOutputCommitter].getName) @@ -375,16 +375,17 @@ object SQLConf { .booleanConf .createWithDefault(true) - // The output committer class used by HadoopFsRelation. The specified class needs to be a + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - // - // NOTE: - // - // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. - // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". val OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.sources.outputCommitterClass").internal().stringConf.createOptional + val FILE_COMMIT_PROTOCOL_CLASS = + SQLConfigBuilder("spark.sql.sources.commitProtocolClass") + .internal() + .stringConf + .createWithDefault(classOf[MapReduceFileCommitterProtocol].getName) + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold") .doc("The maximum number of files allowed for listing files at driver side. If the number " + @@ -518,6 +519,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STREAMING_FILE_COMMIT_PROTOCOL_CLASS = + SQLConfigBuilder("spark.sql.streaming.commitProtocolClass") + .internal() + .stringConf + .createWithDefault(classOf[MapReduceFileCommitterProtocol].getName) + val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") .internal() .doc("Whether to delete the expired log files in file stream sink.") @@ -631,6 +638,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) + def streamingFileCommitProtocolClass: String = getConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS) + def fileSinkLogDeletion: Boolean = getConf(FILE_SINK_LOG_DELETION) def fileSinkLogCompactInterval: Int = getConf(FILE_SINK_LOG_COMPACT_INTERVAL) @@ -741,6 +750,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def partitionColumnTypeInferenceEnabled: Boolean = getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) + def fileCommitProtocolClass: String = getConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS) + def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index eba7aa386ade..7c519a074317 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -83,11 +83,19 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(stagingDir, fileNamePrefix, 
dataSchema, context) + new OrcOutputWriter(path, dataSchema, context) + } + + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) + OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" } } } @@ -210,23 +218,11 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { - override val path: String = { - val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") - } - // It has the `.orc` extension at the end because (de)compression tools - // such as gunzip would not be able to decompress this as the compression - // is not applied on this whole file but on each "stream" in ORC format. - new Path(stagingDir, fileNamePrefix + compressionExtension + ".orc").toString - } - private[this] val serializer = new OrcSerializer(dataSchema, context.getConfiguration) // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala index 731540db17ee..abc7c8cc4db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.sources -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.TaskContext @@ -40,19 +39,16 @@ class CommitFailureTestSource extends SimpleTextSource { dataSchema: StructType): OutputWriterFactory = new OutputWriterFactory { override def newInstance( - stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) { + new SimpleTextOutputWriter(path, context) { var failed = false TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) => failed = true SimpleTextRelation.callbackCalled = true } - override val path: String = new Path(stagingDir, fileNamePrefix).toString - override def write(row: Row): Unit = { if (SimpleTextRelation.failWriter) { sys.error("Intentional task writer failure for testing purpose.") @@ -67,6 +63,8 @@ class CommitFailureTestSource extends SimpleTextSource { } } } + + override def getFileExtension(context: TaskAttemptContext): String = "" } override def shortName(): String = "commit-failure-test" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9896b9bde99c..64d0ecbeefc9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -51,12 +51,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { SimpleTextRelation.lastHadoopConf = Option(job.getConfiguration) new OutputWriterFactory { override def newInstance( - 
stagingDir: String, - fileNamePrefix: String, + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) + new SimpleTextOutputWriter(path, context) } + + override def getFileExtension(context: TaskAttemptContext): String = "" } } @@ -120,14 +121,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter( - stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext) +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { - override val path: String = new Path(stagingDir, fileNamePrefix).toString - private val recordWriter: RecordWriter[NullWritable, Text] = - new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context) + new AppendingTextOutputFormat(path).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map { v => @@ -141,15 +139,14 @@ class SimpleTextOutputWriter( } } -class AppendingTextOutputFormat(stagingDir: Path, fileNamePrefix: String) - extends TextOutputFormat[NullWritable, Text] { +class AppendingTextOutputFormat(path: String) extends TextOutputFormat[NullWritable, Text] { val numberFormat = NumberFormat.getInstance() numberFormat.setMinimumIntegerDigits(5) numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - new Path(stagingDir, fileNamePrefix) + new Path(path) } }
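Finally, both of the new SQLConf keys are internal and default to MapReduceFileCommitterProtocol, so existing jobs keep the Hadoop-committer-backed behaviour. A hedged example of switching them to a custom protocol (the fully-qualified class name is hypothetical, and this assumes the internal keys are reachable through the session conf of a spark-shell style "spark" session):

// Batch writes that go through WriteOutput:
spark.conf.set(
  "spark.sql.sources.commitProtocolClass",
  "org.example.DirectWriteCommitProtocol")

// Streaming file sink writes (the streamingFileCommitProtocolClass getter is added above;
// its call site is outside this diff):
spark.conf.set(
  "spark.sql.streaming.commitProtocolClass",
  "org.example.DirectWriteCommitProtocol")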