diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 81494b167af50..51d9e2e967990 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -89,7 +89,7 @@ private[spark] case class PythonFunction(
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
 /** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Exception)
+private[spark] class PythonException(msg: String, cause: Throwable)
   extends RuntimeException(msg, cause)
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f73e95eac8f79..2e59723648b82 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 
 import org.apache.spark._
 import org.apache.spark.internal.Logging
@@ -165,15 +166,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       context: TaskContext)
     extends Thread(s"stdout writer for $pythonExec") {
 
-    @volatile private var _exception: Exception = null
+    @volatile private var _exception: Throwable = null
 
     private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
     private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
 
-    /** Contains the exception thrown while writing the parent iterator to the Python process. */
-    def exception: Option[Exception] = Option(_exception)
+    /** Contains the throwable thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Throwable] = Option(_exception)
 
     /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
     def shutdownOnTaskCompletion() {
@@ -347,18 +348,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
       } catch {
-        case e: Exception if context.isCompleted || context.isInterrupted =>
-          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-
-        case e: Exception =>
-          // We must avoid throwing exceptions here, because the thread uncaught exception handler
-          // will kill the whole executor (see org.apache.spark.executor.Executor).
-          _exception = e
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
+        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+          if (context.isCompleted || context.isInterrupted) {
+            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
+              "cleanup)", t)
+            if (!worker.isClosed) {
+              Utils.tryLog(worker.shutdownOutput())
+            }
+          } else {
+            // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
+            // exception handler will kill the whole executor (see
+            // org.apache.spark.executor.Executor).
+            _exception = t
+            if (!worker.isClosed) {
+              Utils.tryLog(worker.shutdownOutput())
+            }
           }
       }
     }
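
Note on the guard in the new catch clause (this note is not part of the patch): scala.util.control.NonFatal deliberately treats InterruptedException, ControlThrowable, and JVM-fatal errors such as OutOfMemoryError as fatal, so NonFatal(t) alone does not match every Exception. The combined check NonFatal(t) || t.isInstanceOf[Exception] therefore records all Exceptions plus non-fatal Errors in _exception, while still letting truly fatal Throwables reach the uncaught-exception handler. A minimal, self-contained sketch of the same guard; the object and method names below are hypothetical and used only for illustration:

    import scala.util.control.NonFatal

    object NonFatalGuardSketch {
      // Mirror of the patched guard: record any Exception or non-fatal Error,
      // rethrow fatal Throwables so the JVM-level handler still sees them.
      def classify(t: Throwable): String = t match {
        case nf: Throwable if NonFatal(nf) || nf.isInstanceOf[Exception] =>
          s"recorded: ${nf.getClass.getSimpleName}"
        case fatal =>
          throw fatal
      }

      def main(args: Array[String]): Unit = {
        println(classify(new RuntimeException("boom")))     // recorded: non-fatal Exception
        println(classify(new InterruptedException("stop"))) // recorded: an Exception, even though NonFatal treats it as fatal
        println(classify(new AssertionError("bad")))        // recorded: an Error, but NonFatal
        // classify(new OutOfMemoryError("oom"))             // would rethrow: fatal by both checks
      }
    }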