@@ -41,17 +41,11 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration

private[libsvm] class LibSVMOutputWriter(
stagingDir: String,
fileNamePrefix: String,
path: String,
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter {

override val path: String = {
val compressionExtension = TextOutputWriter.getCompressionExtension(context)
new Path(stagingDir, fileNamePrefix + ".libsvm" + compressionExtension).toString
}

private[this] val buffer = new Text()

private val recordWriter: RecordWriter[NullWritable, Text] = {
@@ -135,11 +129,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
dataSchema: StructType): OutputWriterFactory = {
new OutputWriterFactory {
override def newInstance(
stagingDir: String,
fileNamePrefix: String,
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context)
new LibSVMOutputWriter(path, dataSchema, context)
}

override def getFileExtension(context: TaskAttemptContext): String = {
".libsvm" + TextOutputWriter.getCompressionExtension(context)
}
}
}
@@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import java.util.{Date, UUID}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl

import org.apache.spark.SparkHadoopWriter
import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils


object FileCommitProtocol {
class TaskCommitMessage(obj: Any) extends Serializable

object EmptyTaskCommitMessage extends TaskCommitMessage(Unit)

/**
* Instantiates a FileCommitProtocol using the given className.
*/
def instantiate(className: String, outputPath: String, isAppend: Boolean): FileCommitProtocol = {
try {
val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]

// First try the constructor with arguments (outputPath: String, isAppend: Boolean).
// If that doesn't exist, fall back to the one with (outputPath: String).
try {
val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[Boolean])
ctor.newInstance(outputPath, isAppend.asInstanceOf[java.lang.Boolean])
} catch {
case _: NoSuchMethodException =>
val ctor = clazz.getDeclaredConstructor(classOf[String])
ctor.newInstance(outputPath)
}
} catch {
case e: ClassNotFoundException =>
throw e
}
}
}
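
// Illustrative sketch, not part of the protocol API: how a caller might resolve an
// implementation through the reflective lookup in instantiate() above. The output path
// is a placeholder; HadoopCommitProtocolWrapper (defined later in this file) exposes the
// (outputPath, isAppend) constructor that the first lookup matches.
private object FileCommitProtocolInstantiationExample {
  def resolve(): FileCommitProtocol = {
    FileCommitProtocol.instantiate(
      className = classOf[HadoopCommitProtocolWrapper].getName,
      outputPath = "/tmp/example-output",
      isAppend = false)
  }
}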


/**
* An interface to define how a Spark job commits its outputs. Implementations must be serializable,
* as the committer instance instantiated on the driver will be used for tasks on executors.
*
* The proper call sequence is:
*
* 1. Driver calls setupJob.
* 2. As part of each task's execution, executor calls setupTask and then commitTask
* (or abortTask if task failed).
* 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the job
* failed to execute (e.g. too many failed tasks), the driver should call abortJob.
*/
abstract class FileCommitProtocol {
import FileCommitProtocol._

/**
* Sets up a job. Must be called on the driver before any other methods can be invoked.
*/
def setupJob(jobContext: JobContext): Unit

/**
* Commits a job after the writes succeed. Must be called on the driver.
*/
def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit

/**
* Aborts a job after the writes fail. Must be called on the driver.
*
* Calling this function is a best-effort attempt, because it is possible that the driver
* just crashes (or is killed) before it can call abort.
*/
def abortJob(jobContext: JobContext): Unit

/**
* Sets up a task within a job.
* Must be called before any other task-related methods can be invoked.
*/
def setupTask(taskContext: TaskAttemptContext): Unit

/**
* Notifies the commit protocol to add a new file, and gets back the full path that should be
* used. Must be called on the executors when running tasks.
*
* Note that the returned temp file may have an arbitrary path. The commit protocol only
* promises that the file will be at the location specified by the arguments after job commit.
*
* A full file path consists of the following parts:
* 1. the base path
* 2. some sub-directory within the base path, used to specify partitioning
* 3. file prefix, usually some unique job id with the task id
* 4. bucket id
* 5. source specific file extension, e.g. ".snappy.parquet"
*
* The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
* are left to the commit protocol implementation to decide.
*/
def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String

/**
* Commits a task after the writes succeed. Must be called on the executors when running tasks.
*/
def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage

/**
* Aborts a task after the writes have failed. Must be called on the executors when running tasks.
*
* Calling this function is a best-effort attempt, because it is possible that the executor
* just crashes (or is killed) before it can call abort.
*/
def abortTask(taskContext: TaskAttemptContext): Unit
}
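
// Illustrative sketch of the call sequence documented above, written as if the driver and
// its tasks ran in a single process; in a real job, setupTask/newTaskTempFile/commitTask
// run on executors. The partition directory ("year=2016") and extension are placeholders.
private object FileCommitProtocolSequenceExample {
  import FileCommitProtocol.TaskCommitMessage

  def run(
      committer: FileCommitProtocol,
      jobContext: JobContext,
      taskContexts: Seq[TaskAttemptContext]): Unit = {
    committer.setupJob(jobContext)
    try {
      val commits: Seq[TaskCommitMessage] = taskContexts.map { taskContext =>
        committer.setupTask(taskContext)
        try {
          val tempFile =
            committer.newTaskTempFile(taskContext, Some("year=2016"), ".snappy.parquet")
          // ... the task writes its rows to tempFile here ...
          committer.commitTask(taskContext)
        } catch {
          case t: Throwable =>
            committer.abortTask(taskContext)
            throw t
        }
      }
      committer.commitJob(jobContext, commits)
    } catch {
      case t: Throwable =>
        committer.abortJob(jobContext)
        throw t
    }
  }
}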


/**
* A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter
* (from the newer mapreduce API, not the old mapred API).
*
* Unlike Hadoop's OutputCommitter, this implementation is serializable.
*/
class HadoopCommitProtocolWrapper(path: String, isAppend: Boolean)
extends FileCommitProtocol with Serializable with Logging {

import FileCommitProtocol._

/** Hadoop's OutputCommitter is not serializable, so it is marked transient. */
@transient private var committer: OutputCommitter = _

/** UUID used to identify the job in file names. */
private val uuid: String = UUID.randomUUID().toString

private def setupCommitter(context: TaskAttemptContext): Unit = {
committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context)

if (!isAppend) {
// A user-defined output committer is only considered when not appending. When appending
// data to an existing dir, we always use the output committer associated with the file
// output format, since it is not safe to use a custom committer for appends. For example,
// in S3, a direct parquet output committer may leave partial data in the destination dir
// when the appending job fails. See SPARK-8578 for more details.
val configuration = context.getConfiguration
val clazz =
configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])

if (clazz != null) {
logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")

// Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
// has an associated output committer. To override this output committer,
// we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
// If a data source needs to override the output committer, it needs to set the
// output committer in the prepareWrite method.
if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) {
// The specified output committer is a FileOutputCommitter.
// So, we will use the FileOutputCommitter-specified constructor.
val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
committer = ctor.newInstance(new Path(path), context)
} else {
// The specified output committer is just an OutputCommitter.
// So, we will use the no-argument constructor.
val ctor = clazz.getDeclaredConstructor()
committer = ctor.newInstance()
}
}
}
logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}")
}

override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
// The file name looks like part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.gz.parquet
// Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
// the file name is fine and won't overflow.
val split = taskContext.getTaskAttemptID.getTaskID.getId
val filename = f"part-$split%05d-$uuid$ext"

val stagingDir: String = committer match {
// A FileOutputCommitter has its own staging path, called its "work path".
case f: FileOutputCommitter => Option(f.getWorkPath).map(_.toString).getOrElse(path)
case _ => path
}

dir.map { d =>
new Path(new Path(stagingDir, d), filename).toString
}.getOrElse {
new Path(stagingDir, filename).toString
}
}
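// Illustrative example with placeholder values: for path = "/data/out",
// dir = Some("year=2016"), ext = ".snappy.parquet" and task id 3, this returns
// "<stagingDir>/year=2016/part-00003-<uuid>.snappy.parquet"; the committer later
// relocates the file under "/data/out" as part of commit.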

override def setupJob(jobContext: JobContext): Unit = {
// Setup IDs
val jobId = SparkHadoopWriter.createJobID(new Date, 0)
val taskId = new TaskID(jobId, TaskType.MAP, 0)
val taskAttemptId = new TaskAttemptID(taskId, 0)

// Set up the configuration object
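// (These are the legacy mapred-era property names; they are set here because some
// OutputFormats and committers still read them from the Hadoop configuration.)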
jobContext.getConfiguration.set("mapred.job.id", jobId.toString)
jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString)
jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString)
jobContext.getConfiguration.setBoolean("mapred.task.is.map", true)
jobContext.getConfiguration.setInt("mapred.task.partition", 0)

val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId)
setupCommitter(taskAttemptContext)

committer.setupJob(jobContext)
}

override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
committer.commitJob(jobContext)
}

override def abortJob(jobContext: JobContext): Unit = {
committer.abortJob(jobContext, JobStatus.State.FAILED)
}

override def setupTask(taskContext: TaskAttemptContext): Unit = {
setupCommitter(taskContext)
committer.setupTask(taskContext)
}

override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
val attemptId = taskContext.getTaskAttemptID
SparkHadoopMapRedUtil.commitTask(
committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
EmptyTaskCommitMessage
}

override def abortTask(taskContext: TaskAttemptContext): Unit = {
committer.abortTask(taskContext)
}
}
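
// Illustrative sketch with placeholder names, not part of this file: wiring a user-defined
// committer into setupCommitter above through the SQL conf it checks for non-append writes.
// "org.example.MyOutputCommitter" is hypothetical; if it extends FileOutputCommitter it is
// constructed with (new Path(path), context), otherwise via its no-argument constructor.
private object CustomOutputCommitterConfExample {
  def configure(spark: org.apache.spark.sql.SparkSession): Unit = {
    spark.conf.set(
      SQLConf.OUTPUT_COMMITTER_CLASS.key,
      "org.example.MyOutputCommitter")  // hypothetical committer class name
  }
}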
@@ -30,28 +30,21 @@ import org.apache.spark.sql.types.StructType
* to executor side to create actual [[OutputWriter]]s on the fly.
*/
abstract class OutputWriterFactory extends Serializable {

/** Returns the file extension to be used when writing files out. */
def getFileExtension(context: TaskAttemptContext): String

/**
* When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side
* to instantiate new [[OutputWriter]]s.
*
* @param stagingDir Base path (directory) of the file to which this [[OutputWriter]] is supposed
* to write. Note that this may not point to the final output file. For
* example, `FileOutputFormat` writes to temporary directories and then merge
* written files back to the final destination. In this case, `path` points to
* a temporary output file under the temporary directory.
* @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this
* prefix is used in the actual file name. For example, if the prefix is
* "part-1-2-3", then the file name must start with "part_1_2_3" but can
* end in arbitrary extension that is deterministic given the configuration
* (i.e. the suffix extension should not depend on any task id, attempt id,
* or partition id).
* @param path Path to write the file.
* @param dataSchema Schema of the rows to be written. Partition columns are not included in the
* schema if the relation being written is partitioned.
* @param context The Hadoop MapReduce task context.
*/
def newInstance(
stagingDir: String,
fileNamePrefix: String,
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter
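// Presumably (the writer-container changes are not shown in this excerpt) the caller
// composes the pieces above roughly as follows; the names factory, committer,
// partitionDir and taskAttemptContext are placeholders, and the extension passed to the
// commit protocol also carries any bucket id in addition to getFileExtension's result:
//   val ext = factory.getFileExtension(taskAttemptContext)
//   val path = committer.newTaskTempFile(taskAttemptContext, partitionDir, ext)
//   val writer = factory.newInstance(path, dataSchema, taskAttemptContext)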

@@ -77,13 +70,6 @@ abstract class OutputWriterFactory extends Serializable {
* executor side. This instance is used to persist rows to this single output file.
*/
abstract class OutputWriter {

/**
* The path of the file to be written out. This path should include the staging directory and
* the file name prefix passed into the associated createOutputWriter function.
*/
def path: String

/**
* Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.