diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e2c6a912bc270..d2c7067e59612 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2068,7 +2068,20 @@ class SparkContext(config: SparkConf) extends Logging { /** * Shut down the SparkContext. */ - def stop(): Unit = { + def stop(): Unit = stop(0) + + /** + * Shut down the SparkContext with an exit code that will be passed to the scheduler backend. + * In client mode, the client side may call `SparkContext.stop()` to clean up but exit with a + * code not equal to 0. This behavior causes resource schedulers such as `ApplicationMaster` to + * exit with success status while the client side exits with failed status. Spark can call + * this method to stop the SparkContext and pass the client side's exit code to the scheduler backend. + * The scheduler backend should then send the exit code to the corresponding resource scheduler + * to keep them consistent. + * + * @param exitCode Specified exit code that will be passed to the scheduler backend in client mode. 
+ */ + def stop(exitCode: Int): Unit = { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException(s"Cannot stop SparkContext within listener bus thread.") } @@ -2101,7 +2114,7 @@ class SparkContext(config: SparkConf) extends Logging { } if (_dagScheduler != null) { Utils.tryLogNonFatalError { - _dagScheduler.stop() + _dagScheduler.stop(exitCode) } _dagScheduler = null } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4efce34b18c29..fb3512619d87f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -2889,7 +2889,7 @@ private[spark] class DAGScheduler( listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) } - def stop(): Unit = { + def stop(exitCode: Int = 0): Unit = { Utils.tryLogNonFatalError { messageScheduler.shutdownNow() } @@ -2900,7 +2900,7 @@ private[spark] class DAGScheduler( eventProcessLoop.stop() } Utils.tryLogNonFatalError { - taskScheduler.stop() + taskScheduler.stop(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index b2acdb3e12a6d..56666dcaccf03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,6 +30,7 @@ private[spark] trait SchedulerBackend { def start(): Unit def stop(): Unit + def stop(exitCode: Int): Unit = stop() /** * Update the current offers and schedule tasks */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 0fa80bbafdedd..5613966e8f5e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -49,7 +49,7 @@ private[spark] trait TaskScheduler { def postStartHook(): Unit = { } // Disconnect from the cluster. - def stop(): Unit + def stop(exitCode: Int = 0): Unit // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1eb588124a7c0..80b66c4f675c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -971,13 +971,13 @@ private[spark] class TaskSchedulerImpl( } } - override def stop(): Unit = { + override def stop(exitCode: Int = 0): Unit = { Utils.tryLogNonFatalError { speculationScheduler.shutdown() } if (backend != null) { Utils.tryLogNonFatalError { - backend.stop() + backend.stop(exitCode) } } if (taskResultGetter != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 61ee865c0fcb4..6d2befec155e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -151,7 +151,7 @@ private[spark] object CoarseGrainedClusterMessages { case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage // Used internally by executors to shut themselves down. - case object Shutdown extends CoarseGrainedClusterMessage + case class Shutdown(exitCode: Int = 0) extends CoarseGrainedClusterMessage // The message to check if `CoarseGrainedSchedulerBackend` thinks the executor is alive or not. 
case class IsExecutorAlive(executorId: String) extends CoarseGrainedClusterMessage diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3b57535ea3e9c..04b335987d21c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -176,7 +176,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti override def schedulingMode: SchedulingMode = SchedulingMode.FIFO override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start() = {} - override def stop() = {} + override def stop(exitCode: Int) = {} override def executorHeartbeatReceived( execId: String, accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], @@ -846,7 +846,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti override def schedulingMode: SchedulingMode = SchedulingMode.FIFO override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} - override def stop(): Unit = {} + override def stop(exitCode: Int): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = { taskSets += taskSet } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 08191d09a9f2d..a30cb521bf484 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -80,7 +80,7 @@ private class DummyTaskScheduler extends TaskScheduler { override def schedulingMode: SchedulingMode = SchedulingMode.FIFO override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} - override def stop(): Unit = {} + override def stop(exitCode: Int): Unit = {} 
override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} override def killTaskAttempt( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 27f1b9e3e37df..ee2e630bdcee8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -100,7 +100,15 @@ object MimaExcludes { // [SPARK-37935][SQL] Eliminate separate error sub-classes fields ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkException.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.AnalysisException.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.AnalysisException.this"), + + // [SPARK-38270][SQL] Spark SQL CLI's AM should keep same exit code with client side + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.canEqual"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#Shutdown.toString") ) // Defulat exclude rules diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8f8c08fbe74ba..a7676fe24f64c 100644 --- 
a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -790,6 +790,8 @@ private[spark] class ApplicationMaster( private class AMEndpoint(override val rpcEnv: RpcEnv, driver: RpcEndpointRef) extends RpcEndpoint with Logging { @volatile private var shutdown = false + @volatile private var exitCode = 0 + private val clientModeTreatDisconnectAsFailed = sparkConf.get(AM_CLIENT_MODE_TREAT_DISCONNECT_AS_FAILED) @@ -810,7 +812,9 @@ private[spark] class ApplicationMaster( case UpdateDelegationTokens(tokens) => SparkHadoopUtil.get.addDelegationTokens(tokens, sparkConf) - case Shutdown => shutdown = true + case Shutdown(code) => + exitCode = code + shutdown = true } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -854,8 +858,13 @@ private[spark] class ApplicationMaster( // This avoids potentially reporting incorrect exit codes if the driver fails if (!(isClusterMode || sparkConf.get(YARN_UNMANAGED_AM))) { if (shutdown || !clientModeTreatDisconnectAsFailed) { - logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + if (exitCode == 0) { + logInfo(s"Driver terminated or disconnected! Shutting down. $remoteAddress") + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } else { + logError(s"Driver terminated with exit code ${exitCode}! Shutting down. $remoteAddress") + finish(FinalApplicationStatus.FAILED, exitCode) + } } else { logError(s"Application Master lost connection with driver! Shutting down. 
$remoteAddress") finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_DISCONNECTED) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 9383f21481fe4..6e6d8406049c9 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -160,9 +160,9 @@ private[spark] class YarnClientSchedulerBackend( /** * Stop the scheduler. This assumes `start()` has already been called. */ - override def stop(): Unit = { + override def stop(exitCode: Int): Unit = { assert(client != null, "Attempted to stop this scheduler before starting it!") - yarnSchedulerEndpoint.handleClientModeDriverStop() + yarnSchedulerEndpoint.handleClientModeDriverStop(exitCode) if (monitorThread != null) { monitorThread.stopMonitor() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index b9751ebd47ae6..572c16d9e9b33 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -319,10 +319,10 @@ private[spark] abstract class YarnSchedulerBackend( removeExecutorMessage.foreach { message => driverEndpoint.send(message) } } - private[cluster] def handleClientModeDriverStop(): Unit = { + private[cluster] def handleClientModeDriverStop(exitCode: Int): Unit = { amEndpoint match { case Some(am) => - am.send(Shutdown) + am.send(Shutdown(exitCode)) case None => logWarning("Attempted to send shutdown message before the AM has registered!") } 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index fcd5c56e0fac6..a29b39518c506 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -60,6 +60,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { private val continuedPrompt = "".padTo(prompt.length, ' ') private var transport: TSocket = _ private final val SPARK_HADOOP_PROP_PREFIX = "spark.hadoop." + private var exitCode = 0 initializeLogIfNecessary(true) installSignalHandler() @@ -83,6 +84,11 @@ private[hive] object SparkSQLCLIDriver extends Logging { }) } + def exit(code: Int): Unit = { + exitCode = code + System.exit(exitCode) + } + def main(args: Array[String]): Unit = { val oproc = new OptionsProcessor() if (!oproc.process_stage1(args)) { @@ -105,12 +111,12 @@ private[hive] object SparkSQLCLIDriver extends Logging { } catch { case e: UnsupportedEncodingException => sessionState.close() - System.exit(ERROR_PATH_NOT_FOUND) + exit(ERROR_PATH_NOT_FOUND) } if (!oproc.process_stage2(sessionState)) { sessionState.close() - System.exit(ERROR_MISUSE_SHELL_BUILTIN) + exit(ERROR_MISUSE_SHELL_BUILTIN) } // Set all properties specified via command line. 
@@ -145,7 +151,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Clean up after we exit ShutdownHookManager.addShutdownHook { () => sessionState.close() - SparkSQLEnv.stop() + SparkSQLEnv.stop(exitCode) } if (isRemoteMode(sessionState)) { @@ -190,7 +196,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { sessionState.info = new PrintStream(System.err, true, UTF_8.name()) sessionState.err = new PrintStream(System.err, true, UTF_8.name()) } catch { - case e: UnsupportedEncodingException => System.exit(ERROR_PATH_NOT_FOUND) + case e: UnsupportedEncodingException => exit(ERROR_PATH_NOT_FOUND) } if (sessionState.database != null) { @@ -211,17 +217,17 @@ private[hive] object SparkSQLCLIDriver extends Logging { cli.printMasterAndAppId if (sessionState.execString != null) { - System.exit(cli.processLine(sessionState.execString)) + exit(cli.processLine(sessionState.execString)) } try { if (sessionState.fileName != null) { - System.exit(cli.processFile(sessionState.fileName)) + exit(cli.processFile(sessionState.fileName)) } } catch { case e: FileNotFoundException => logError(s"Could not open input file for reading. 
(${e.getMessage})") - System.exit(ERROR_PATH_NOT_FOUND) + exit(ERROR_PATH_NOT_FOUND) } val reader = new ConsoleReader() @@ -303,7 +309,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { sessionState.close() - System.exit(ret) + exit(ret) } @@ -358,7 +364,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (cmd_lower.equals("quit") || cmd_lower.equals("exit")) { sessionState.close() - System.exit(EXIT_SUCCESS) + SparkSQLCLIDriver.exit(EXIT_SUCCESS) } if (tokens(0).toLowerCase(Locale.ROOT).equals("source") || cmd_trimmed.startsWith("!") || isRemoteMode) { @@ -481,7 +487,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // Kill the VM on second ctrl+c if (!initialRequest) { console.printInfo("Exiting the JVM") - System.exit(ERROR_COMMAND_NOT_FOUND) + SparkSQLCLIDriver.exit(ERROR_COMMAND_NOT_FOUND) } // Interrupt the CLI thread to stop the current statement and return diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index fb7d32d7a2998..88a5c87eab5d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -81,11 +81,11 @@ private[hive] object SparkSQLEnv extends Logging { } /** Cleans up and shuts down the Spark SQL environments. */ - def stop(): Unit = { + def stop(exitCode: Int = 0): Unit = { logDebug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { - sparkContext.stop() + sparkContext.stop(exitCode) sparkContext = null sqlContext = null }