diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index f245d2d4e4074..4395128e85a4b 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -276,8 +276,14 @@ class HadoopMapReduceCommitProtocol(
 
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     val attemptId = taskContext.getTaskAttemptID
     logTrace(s"Commit task ${attemptId}")
+    val disableCommitCoordination =
+      taskContext.getConfiguration.get("spark.test.disableCommitCoordination") == "true"
     SparkHadoopMapRedUtil.commitTask(
-      committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
+      committer,
+      taskContext,
+      attemptId.getJobID.getId,
+      attemptId.getTaskID.getId,
+      disableCommitCoordination)
     new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index a5395aa01a9ba..a2373cc98e71f 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -42,7 +42,8 @@ object SparkHadoopMapRedUtil extends Logging {
       committer: MapReduceOutputCommitter,
       mrTaskContext: MapReduceTaskAttemptContext,
       jobId: Int,
-      splitId: Int): Unit = {
+      splitId: Int,
+      disableCommitCoordination: Boolean): Unit = {
 
     val mrTaskAttemptID = mrTaskContext.getTaskAttemptID
 
@@ -71,7 +72,7 @@ object SparkHadoopMapRedUtil extends Logging {
       sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)
     }
 
-    if (shouldCoordinateWithDriver) {
+    if (shouldCoordinateWithDriver && !disableCommitCoordination) {
       val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
       val ctx = TaskContext.get()
       val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(),
@@ -96,4 +97,12 @@ object SparkHadoopMapRedUtil extends Logging {
       logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
     }
   }
+
+  def commitTask(
+      committer: MapReduceOutputCommitter,
+      mrTaskContext: MapReduceTaskAttemptContext,
+      jobId: Int,
+      splitId: Int): Unit = {
+    commitTask(committer, mrTaskContext, jobId, splitId, disableCommitCoordination = false)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 9fb490dd823ad..0a33a3a1b5c26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1207,8 +1207,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
   }
 
   test("SPARK-7837 Do not close output writer twice when commitTask() fails") {
-    withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
-      classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName) {
+    withSQLConf(
+      SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+        classOf[SQLHadoopMapReduceCommitProtocol].getCanonicalName,
+      "spark.test.disableCommitCoordination" -> "true") {
       // Using a output committer that always fail when committing a task, so that both
       // `commitTask()` and `abortTask()` are invoked.
       val extraOptions = Map[String, String](
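
For context, the new four-argument `commitTask` overload keeps existing call sites source-compatible while test code can opt out of driver-side coordination explicitly. A minimal usage sketch based on the patch above; the `committer`, `mrTaskContext`, `jobId`, and `splitId` values are assumed to come from an enclosing task scope:

    // Pre-existing callers compile unchanged: the overload added above forwards
    // with disableCommitCoordination = false, so behavior is identical to before.
    SparkHadoopMapRedUtil.commitTask(committer, mrTaskContext, jobId, splitId)

    // Test-only opt-out: skip the OutputCommitCoordinator handshake so that a
    // deliberately failing committer raises its error from commitTask() itself.
    SparkHadoopMapRedUtil.commitTask(
      committer, mrTaskContext, jobId, splitId, disableCommitCoordination = true)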