diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 02b28b72fb0e..f1daf62ad4d1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
     val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    new Thread(s"stderr reader for $command") {
+    val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         val err = proc.getErrorStream
         try {
@@ -128,10 +128,11 @@ private[spark] class PipedRDD[T: ClassTag](
           err.close()
         }
       }
-    }.start()
+    }
+    stderrReaderThread.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    new Thread(s"stdin writer for $command") {
+    val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(new BufferedWriter(
@@ -156,7 +157,28 @@ private[spark] class PipedRDD[T: ClassTag](
           out.close()
         }
       }
-    }.start()
+    }
+    stdinWriterThread.start()
+
+    // Interrupt the stdin writer and stderr reader threads when the corresponding task finishes;
+    // otherwise, these threads could outlive the task's lifetime. For example:
+    //   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
+    //   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
+    // Here the iterator generated by PipedRDD is never consumed. If the parent RDD's iterator
+    // takes a long time to generate (ShuffledRDD's shuffle, for example), the stdin writer thread
+    // may consume significant memory and CPU time even though the task has already finished.
+    context.addTaskCompletionListener[Unit] { _ =>
+      if (proc.isAlive) {
+        proc.destroy()
+      }
+
+      if (stdinWriterThread.isAlive) {
+        stdinWriterThread.interrupt()
+      }
+      if (stderrReaderThread.isAlive) {
+        stderrReaderThread.interrupt()
+      }
+    }
 
     // Return an iterator that read lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -219,4 +241,7 @@ private object PipedRDD {
     }
     buf
   }
+
+  val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
+  val STDERR_READER_THREAD_PREFIX = "stderr reader for"
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 86f7c08eddcb..fbc4db59f8e3 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -141,7 +141,14 @@ final class ShuffleBlockFetcherIterator(
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
-   * longer place fetched blocks into [[results]].
+   * longer place fetched blocks into [[results]], and the iterator is treated as fully consumed.
+   *
+   * [[hasNext]] and [[next]] check this flag because the iterator may still be consumed after it
+   * has become inactive. For example, with ShuffledRDD + PipedRDD, if the piped subprocess fails,
+   * the task is marked as failed and this iterator is cleaned up at task completion, but a
+   * pending [[next]] call (issued from PipedRDD's stdin writer thread, if that thread has not
+   * exited yet) could otherwise hang forever at [[results.take]]. The defensive check in
+   * [[hasNext]] and [[next]] reduces the possibility of such race conditions.
    */
   @GuardedBy("this")
   private[this] var isZombie = false
 
@@ -372,7 +379,7 @@ final class ShuffleBlockFetcherIterator(
     logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
   }
 
-  override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+  override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
 
   /**
    * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -384,7 +391,7 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     if (!hasNext) {
-      throw new NoSuchElementException
+      throw new NoSuchElementException()
     }
 
     numBlocksProcessed += 1
@@ -395,7 +402,7 @@ final class ShuffleBlockFetcherIterator(
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
     // For local shuffle block, throw FailureFetchResult for the first IOException.
-    while (result == null) {
+    while (!isZombie && result == null) {
       val startFetchWait = System.currentTimeMillis()
       result = results.take()
       val stopFetchWait = System.currentTimeMillis()
@@ -489,6 +496,9 @@ final class ShuffleBlockFetcherIterator(
       fetchUpToMaxBytes()
     }
 
+    if (result == null) { // the iterator has already been closed / cleaned up
+      throw new NoSuchElementException()
+    }
     currentResult = result.asInstanceOf[SuccessFetchResult]
     (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 1a0eb250e7cd..69739a2e5848 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.File
 
+import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.io.Codec
 
@@ -83,6 +84,29 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("stdin writer thread should be exited when task is finished") {
+    assume(TestUtils.testCommandAvailable("cat"))
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
+      val obj = new Object()
+      obj.synchronized {
+        obj.wait() // make the thread wait here
+      }
+      x
+    }
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)
+
+    assert(result.collect().length === 0)
+
+    // collect stdin writer threads, if any are still alive
+    val stdinWriterThread = Thread.getAllStackTraces.keySet().asScala
+      .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
+
+    assert(stdinWriterThread.isEmpty)
+  }
+
   test("advanced pipe") {
     assume(TestUtils.testCommandAvailable("cat"))
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 6b83243fe496..98fe9663b621 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -217,6 +217,64 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
   }
 
+  test("iterator is all consumed if task completes early") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
+        override def answer(invocation: InvocationOnMock): Unit = {
+          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+          Future {
+            // Return the first two blocks, and wait until task completion before returning the last
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+            sem.acquire()
+            listener.onBlockFetchSuccess(
+              ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+          }
+        }
+      })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true,
+      taskContext.taskMetrics.createTempShuffleReadMetrics())
+
+    assert(iterator.hasNext)
+    iterator.next()
+
+    taskContext.markTaskCompleted(None)
+    sem.release()
+    assert(iterator.hasNext === false)
+  }
+
   test("fail all blocks if any of the remote request fails") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)