33 changes: 29 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
     val childThreadException = new AtomicReference[Throwable](null)
 
     // Start a thread to print the process's stderr to ours
-    new Thread(s"stderr reader for $command") {
+    val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         val err = proc.getErrorStream
         try {
@@ -128,10 +128,11 @@ private[spark] class PipedRDD[T: ClassTag](
           err.close()
         }
       }
-    }.start()
+    }
+    stderrReaderThread.start()
 
     // Start a thread to feed the process input from our parent's iterator
-    new Thread(s"stdin writer for $command") {
+    val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
       override def run(): Unit = {
         TaskContext.setTaskContext(context)
         val out = new PrintWriter(new BufferedWriter(
@@ -156,7 +157,28 @@
           out.close()
         }
       }
-    }.start()
+    }
+    stdinWriterThread.start()
 
+    // Interrupt the stdin writer and stderr reader threads when the corresponding task finishes.
+    // Otherwise, these threads could outlive the task's lifetime. For example:
+    //   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
+    //   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
+    // Here the iterator generated by PipedRDD is never consumed. If the parent RDD's iterator
+    // takes a long time to generate (e.g., ShuffledRDD's shuffle), the stdin writer thread may
+    // consume significant memory and CPU time even though the task has already finished.
+    context.addTaskCompletionListener[Unit] { _ =>
+      if (proc.isAlive) {
+        proc.destroy()
+      }
+
+      if (stdinWriterThread.isAlive) {
+        stdinWriterThread.interrupt()
+      }
+      if (stderrReaderThread.isAlive) {
+        stderrReaderThread.interrupt()
+      }
+    }
+
     // Return an iterator that reads lines from the process's stdout
     val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
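
The cleanup above hinges on `TaskContext.addTaskCompletionListener`, which runs its callback once the task ends, whether it succeeded, failed, or was killed. Here is a minimal standalone sketch of that behavior; the object name and local-mode setup are illustrative, only the listener call itself is the Spark API used by this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Hypothetical demo program, not part of this patch.
object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("listener-demo"))
    try {
      sc.parallelize(1 to 4, 2).mapPartitions { iter =>
        // Register cleanup while the task is running; the callback fires when the task
        // finishes, regardless of whether the iterator below is fully consumed.
        TaskContext.get().addTaskCompletionListener[Unit] { ctx =>
          println(s"partition ${ctx.partitionId()} finished, cleaning up")
        }
        iter
      }.count()
    } finally {
      sc.stop()
    }
  }
}
```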
@@ -219,4 +241,7 @@ private object PipedRDD {
     }
     buf
   }
+
+  val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
+  val STDERR_READER_THREAD_PREFIX = "stderr reader for"
 }
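
To make the scenario in the new comment concrete, the following is a hedged, self-contained sketch of the abandoned-iterator pattern the listener guards against. The driver program and names are hypothetical; the `sc.range(...).pipe(...)` lines are taken from the comment itself:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical reproduction of the leak, assuming a local Spark build without this patch.
object PipeLeakSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("pipe-leak"))
    try {
      val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
      // mapPartitions drops the piped iterator on the floor: PipedRDD.compute still runs
      // (the external process and both helper threads are started), but nothing reads it.
      val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
      // Before this change, the "stdin writer" thread kept draining the parent iterator
      // after the task finished; with the completion listener it is interrupted here.
      abnormalRDD.count()
    } finally {
      sc.stop()
    }
  }
}
```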
24 changes: 24 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import java.io.File
 
+import scala.collection.JavaConverters._
 import scala.collection.Map
 import scala.io.Codec
 
@@ -83,6 +84,29 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("stdin writer thread should exit when the task finishes") {
+    assume(TestUtils.testCommandAvailable("cat"))
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
+      val obj = new Object()
+      obj.synchronized {
+        obj.wait() // block the stdin writer thread here
+      }
+      x
+    }
+
+    val piped = nums.pipe(Seq("cat"))
+
+    val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)
+
+    assert(result.collect().length === 0)
+
+    // make sure the stdin writer thread did not outlive the finished task
+    val stdinWriterThread = Thread.getAllStackTraces.keySet().asScala
+      .find { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }
+
+    assert(stdinWriterThread.isEmpty)
+  }
+
   test("advanced pipe") {
     assume(TestUtils.testCommandAvailable("cat"))
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
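
The test locates stray threads by scanning every live JVM thread for the name prefix that PipedRDD now assigns. A small standalone sketch of that idiom, with illustrative object and method names:

```scala
import scala.collection.JavaConverters._

// Hypothetical helper; only Thread.getAllStackTraces is standard JVM API.
object ThreadPrefixProbe {
  /** Returns the names of live threads whose name starts with `prefix`. */
  def liveThreadsWithPrefix(prefix: String): Seq[String] =
    Thread.getAllStackTraces.keySet().asScala
      .filter(_.getName.startsWith(prefix))
      .map(_.getName)
      .toSeq

  def main(args: Array[String]): Unit = {
    // After a job over a PipedRDD completes, this should print an empty list.
    println(liveThreadsWithPrefix("stdin writer for"))
  }
}
```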