diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 7f80d5e66bdd..d3cc5ed10799 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -144,7 +144,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
       handle.shuffleId, _ => new OpenHashSet[Long](16))
-    mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
+    mapTaskIds.synchronized { mapTaskIds.add(mapId) }
     val env = SparkEnv.get
     handle match {
       case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 5ea43e2dac7d..56684d9b0327 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark
 
+import java.io.File
 import java.util.{Locale, Properties}
 import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
 
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.matchers.must.Matchers
 import org.scalatest.matchers.should.Matchers._
 
@@ -29,7 +34,7 @@ import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
 import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.shuffle.ShuffleWriter
 import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
 import org.apache.spark.util.MutablePair
@@ -419,6 +424,31 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     manager.unregisterShuffle(0)
   }
+
+  test("SPARK-34541: shuffle can be removed") {
+    withTempDir { tmpDir =>
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      conf.set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", conf)
+      // Run a job first so that the taskAttemptId starts from 1.
+      sc.parallelize(1 to 10).count()
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      // Create a ShuffledRDD.
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(conf))
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed.
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up.
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      // Check that the cleanup actually removes the files.
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    }
+  }
 }
 
 /**
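
Why the one-line change matters: shuffle data and index files are named with the mapId passed to getWriter, and unregisterShuffle later deletes them by replaying the recorded ids through shuffleBlockResolver.removeDataByMap. The mapId need not equal context.taskAttemptId() (e.g. with spark.shuffle.useOldFetchProtocol=true it is the map partition id), so recording the attempt id could leave files that cleanup never matches. Below is a minimal, self-contained sketch of that mismatch, using hypothetical helper names rather than Spark's actual resolver API:

import java.io.File
import java.nio.file.Files

// Hypothetical stand-in for the real resolver: Spark names shuffle data files
// shuffle_<shuffleId>_<mapId>_<reduceId>.data, so cleanup can only delete a
// file when it is given the same mapId the writer used.
object ShuffleCleanupSketch {
  def dataFile(dir: File, shuffleId: Int, mapId: Long): File =
    new File(dir, s"shuffle_${shuffleId}_${mapId}_0.data")

  def main(args: Array[String]): Unit = {
    val dir = Files.createTempDirectory("shuffle-sketch").toFile
    val shuffleId = 0
    val mapId = 0L           // with the old fetch protocol: the map partition id
    val taskAttemptId = 1L   // differs from mapId once any earlier job has run

    // Write side: the file is created under mapId.
    val file = dataFile(dir, shuffleId, mapId)
    file.createNewFile()

    // Pre-fix cleanup recorded taskAttemptId, so it looks for the wrong name.
    dataFile(dir, shuffleId, taskAttemptId).delete()
    assert(file.exists(), "pre-fix cleanup misses the file, so it leaks")

    // Post-fix cleanup records mapId, which matches the on-disk name.
    dataFile(dir, shuffleId, mapId).delete()
    assert(!file.exists(), "post-fix cleanup removes the file")
    println("sketch ok: cleanup must use the same id the writer used")
  }
}

This is also why the new test first runs an unrelated job: it pushes the taskAttemptId past the map partition id (0), so the pre-fix code would record an id that never matches a file on disk.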