diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d11fa554ca8c..f9f50e3b8837 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -596,8 +596,6 @@ class SparkContext(config: SparkConf) extends Logging { // The metrics system for Driver need to be set spark.app.id to app ID. // So it should start after we get app ID from the task scheduler and set spark.app.id. _env.metricsSystem.start(_conf.get(METRICS_STATIC_SOURCES_ENABLED)) - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. - _env.metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) _eventLogger = if (isEventLogEnabled) { @@ -639,6 +637,11 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() postApplicationStart() + // After the application has started, attach all handlers to the started server and start them. + _ui.foreach(_.attachAllHandler()) + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. 
+ _env.metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + // Post init _taskScheduler.postStartHook() if (isLocal) { diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index dcbb9baa2092..24e5534b83ae 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -34,6 +34,7 @@ import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.io.Source import scala.reflect.{classTag, ClassTag} import scala.sys.process.{Process, ProcessLogger} import scala.util.Try @@ -319,6 +320,18 @@ private[spark] object TestUtils { } } + /** + * Returns the response message from an HTTP(S) URL. + */ + def httpResponseMessage( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): String = { + withHttpConnection(url, method, headers = headers) { connection => + Source.fromInputStream(connection.getInputStream, "utf-8").getLines().mkString("\n") + } + } + def withHttpConnection[T]( url: URL, method: String = "GET", diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b1769a8a9c9e..db1f8bc1a2ff 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -18,6 +18,9 @@ package org.apache.spark.ui import java.util.Date +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging @@ -54,6 +57,25 @@ private[spark] class SparkUI private ( private var streamingJobProgressListener: Option[SparkListener] = None + private val 
initHandler: ServletContextHandler = { + val servlet = new HttpServlet() { + override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.setContentType("text/html;charset=utf-8") + res.getWriter.write("Spark is starting up. Please wait a while until it's ready.") + } + } + createServletHandler("/", servlet, basePath) + } + + /** + * Attach all existing handlers to ServerInfo. + */ + def attachAllHandler(): Unit = { + serverInfo.foreach { server => + server.removeHandler(initHandler) + handlers.foreach(server.addHandler(_, securityManager)) + } + } /** Initialize all components of the server. */ def initialize(): Unit = { val jobsTab = new JobsTab(this, store) @@ -96,6 +118,24 @@ private[spark] class SparkUI private ( appId = id } + /** + * To start SparkUI, Spark first starts the Jetty server to bind the address. + * After the Spark application is fully started, call [[attachAllHandler]] + * to start all existing handlers. + */ + override def bind(): Unit = { + assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") + try { + val server = initServer() + server.addHandler(initHandler, securityManager) + serverInfo = Some(server) + } catch { + case e: Exception => + logError(s"Failed to bind $className", e) + System.exit(1) + } + } + /** Stop the server behind this web interface. Only valid after bind(). 
*/ override def stop(): Unit = { super.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index d826686382f1..80723c34eb4d 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -56,7 +56,7 @@ private[spark] abstract class WebUI( protected var serverInfo: Option[ServerInfo] = None protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( conf.get(DRIVER_HOST_ADDRESS)) - private val className = Utils.getFormattedClassName(this) + protected val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath def getTabs: Seq[WebUITab] = tabs.toSeq @@ -139,15 +139,20 @@ private[spark] abstract class WebUI( /** A hook to initialize components of the UI */ def initialize(): Unit + def initServer(): ServerInfo = { + val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") + val server = startJettyServer(host, port, sslOptions, conf, name, poolSize) + logInfo(s"Bound $className to $host, and started at $webUrl") + server + } + /** Binds to the HTTP server behind this web interface. 
*/ def bind(): Unit = { assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!") try { - val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") - val server = startJettyServer(host, port, sslOptions, conf, name, poolSize) + val server = initServer() handlers.foreach(server.addHandler(_, securityManager)) serverInfo = Some(server) - logInfo(s"Bound $className to $host, and started at $webUrl") } catch { case e: Exception => logError(s"Failed to bind $className", e) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 1a7c8dad5ce7..90136dd06237 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -424,6 +424,28 @@ class UISuite extends SparkFunSuite { } } + test("SPARK-36237: Attach and start handler after application started in UI ") { + def newSparkContextWithoutUI(): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set(UI.UI_ENABLED, false) + new SparkContext(conf) + } + + withSpark(newSparkContextWithoutUI()) { sc => + assert(sc.ui.isEmpty) + val sparkUI = SparkUI.create(Some(sc), sc.statusStore, sc.conf, sc.env.securityManager, + sc.appName, "", sc.startTime) + sparkUI.bind() + assert(TestUtils.httpResponseMessage(new URL(sparkUI.webUrl + "/jobs")) + === "Spark is starting up. Please wait a while until it's ready.") + sparkUI.attachAllHandler() + assert(TestUtils.httpResponseMessage(new URL(sparkUI.webUrl + "/jobs")).contains(sc.appName)) + sparkUI.stop() + } + } + /** * Create a new context handler for the given path, with a single servlet that responds to * requests in `$path/root`.