From 983f490181f2ab3d67a62ecc58bf6fdba86e3e00 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Jan 2017 14:30:32 -0800 Subject: [PATCH 1/4] [SPARK-19220][UI] Make redirection to HTTPS apply to all URIs. The redirect handler was installed only for the root of the server; any other context ended up being served directly through the HTTP port. Since every sub page (e.g. application UIs in the history server) is a separate servlet context, this meant that everything but the root was accessible via HTTP still. The change adds separate names to each connector, and binds contexts to specific connectors so that content is only served through the HTTPS connector when it's enabled. In that case, the only thing that binds to the HTTP connector is the redirect handler. Tested with new unit tests and by checking a live history server. --- .../scala/org/apache/spark/TestUtils.scala | 38 +++++++++- .../org/apache/spark/ui/JettyUtils.scala | 75 +++++++++++++------ .../scala/org/apache/spark/ui/WebUI.scala | 14 +--- .../org/apache/spark/ui/UISeleniumSuite.scala | 19 +---- .../scala/org/apache/spark/ui/UISuite.scala | 56 +++++++++++++- 5 files changed, 147 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fd0477541ef0..109104f0a537 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -18,11 +18,15 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.net.{URI, URL} +import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets +import java.security.SecureRandom +import java.security.cert.X509Certificate import java.util.Arrays import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} +import javax.net.ssl._ +import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,7 +35,6 @@ import scala.sys.process.{Process, ProcessLogger} import scala.util.Try import com.google.common.io.{ByteStreams, Files} -import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -194,6 +197,37 @@ private[spark] object TestUtils { attempt.isSuccess && attempt.get == 0 } + /** + * Returns the response code from an HTTP(S) URL. + */ + def httpResponseCode(url: URL, method: String = "GET"): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + + // Disable cert and host name validation for HTTPS tests. + if (connection.isInstanceOf[HttpsURLConnection]) { + val sslCtx = SSLContext.getInstance("SSL") + val trustManager = new X509TrustManager { + override def getAcceptedIssuers(): Array[X509Certificate] = null + override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} + override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} + } + val verifier = new HostnameVerifier() { + override def verify(hostname: String, session: SSLSession): Boolean = true + } + sslCtx.init(null, Array(trustManager), new SecureRandom()) + connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory()) + connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier) + } + + try { + connection.connect() + connection.getResponseCode() + } finally { + connection.disconnect() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 35c3c8d00f99..638ddf83dc0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils */ private[spark] object JettyUtils extends Logging { + val SPARK_CONNECTOR_NAME = "Spark" + val REDIRECT_CONNECTOR_NAME = "HttpsRedirect" + // Base type for a function that returns something based on an HTTP request. Allows for // implicit conversion from many types of functions to jetty Handlers. type Responder[T] = HttpServletRequest => T @@ -278,13 +281,15 @@ private[spark] object JettyUtils extends Logging { addFilters(handlers, conf) val gzipHandlers = handlers.map { h => + // h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) + val gzipHandler = new GzipHandler gzipHandler.setHandler(h) gzipHandler } // Bind to the given port, or throw a java.net.BindException if the port is occupied - def connect(currentPort: Int): (Server, Int) = { + def connect(currentPort: Int): ((Server, Option[Int]), Int) = { val pool = new QueuedThreadPool if (serverName.nonEmpty) { pool.setName(serverName) @@ -306,23 +311,31 @@ private[spark] object JettyUtils extends Logging { httpConnector.setPort(currentPort) connectors += httpConnector - sslOptions.createJettySslContextFactory().foreach { factory => - // If the new port wraps around, do not try a privileged port. - val securePort = - if (currentPort != 0) { - (currentPort + 400 - 1024) % (65536 - 1024) + 1024 - } else { - 0 - } - val scheme = "https" - // Create a connector on port securePort to listen for HTTPS requests - val connector = new ServerConnector(server, factory) - connector.setPort(securePort) - - connectors += connector - - // redirect the HTTP requests to HTTPS port - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + val httpsConnector = sslOptions.createJettySslContextFactory() match { + case Some(factory) => + // If the new port wraps around, do not try a privileged port. + val securePort = + if (currentPort != 0) { + (currentPort + 400 - 1024) % (65536 - 1024) + 1024 + } else { + 0 + } + val scheme = "https" + // Create a connector on port securePort to listen for HTTPS requests + val connector = new ServerConnector(server, factory) + connector.setPort(securePort) + connector.setName(SPARK_CONNECTOR_NAME) + connectors += connector + + // redirect the HTTP requests to HTTPS port + httpConnector.setName(REDIRECT_CONNECTOR_NAME) + collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + Some(connector) + + case None => + // No SSL, so the HTTP connector becomes the official one where all contexts bind. + httpConnector.setName(SPARK_CONNECTOR_NAME) + None } gzipHandlers.foreach(collection.addHandler) @@ -347,7 +360,7 @@ private[spark] object JettyUtils extends Logging { server.setHandler(collection) try { server.start() - (server, httpConnector.getLocalPort) + ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort) } catch { case e: Exception => server.stop() @@ -356,13 +369,15 @@ private[spark] object JettyUtils extends Logging { } } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) - ServerInfo(server, boundPort, collection) + val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf, + serverName) + ServerInfo(server, boundPort, securePort, collection) } private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") + redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME)) redirectHandler.setHandler(new AbstractHandler { override def handle( target: String, @@ -442,7 +457,23 @@ private[spark] object JettyUtils extends Logging { private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) { + securePort: Option[Int], + private val rootHandler: ContextHandlerCollection) { + + def addHandler(handler: ContextHandler): Unit = { + handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME)) + rootHandler.addHandler(handler) + if (!handler.isStarted()) { + handler.start() + } + } + + def removeHandler(handler: ContextHandler): Unit = { + rootHandler.removeHandler(handler) + if (handler.isStarted) { + handler.stop() + } + } def stop(): Unit = { server.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index b8604c52e6b0..a9480cc220c8 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -91,23 +91,13 @@ private[spark] abstract class WebUI( /** Attach a handler to this UI. */ def attachHandler(handler: ServletContextHandler) { handlers += handler - serverInfo.foreach { info => - info.rootHandler.addHandler(handler) - if (!handler.isStarted) { - handler.start() - } - } + serverInfo.foreach(_.addHandler(handler)) } /** Detach a handler from this UI. */ def detachHandler(handler: ServletContextHandler) { handlers -= handler - serverInfo.foreach { info => - info.rootHandler.removeHandler(handler) - if (handler.isStarted) { - handler.stop() - } - } + serverInfo.foreach(_.removeHandler(handler)) } /** diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index f4786e3931c9..422837303642 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -475,8 +475,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val url = new URL( sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST - getResponseCode(url, "GET") should be (200) - getResponseCode(url, "POST") should be (200) + TestUtils.httpResponseCode(url, "GET") should be (200) + TestUtils.httpResponseCode(url, "POST") should be (200) } } } @@ -488,8 +488,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val url = new URL( sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST - getResponseCode(url, "GET") should be (200) - getResponseCode(url, "POST") should be (200) + TestUtils.httpResponseCode(url, "GET") should be (200) + TestUtils.httpResponseCode(url, "POST") should be (200) } } } @@ -671,17 +671,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - def getResponseCode(url: URL, method: String): Int = { - val connection = url.openConnection().asInstanceOf[HttpURLConnection] - connection.setRequestMethod(method) - try { - connection.connect() - connection.getResponseCode() - } finally { - connection.disconnect() - } - } - def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 68c7657cb315..5b07161cfe3f 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} -import java.net.URI -import javax.servlet.http.HttpServletRequest +import java.net.{URI, URL} +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source -import org.eclipse.jetty.servlet.ServletContextHandler +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.mockito.Mockito.{mock, when} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ @@ -167,6 +167,7 @@ class UISuite extends SparkFunSuite { val boundPort = serverInfo.boundPort assert(server.getState === "STARTED") assert(boundPort != 0) + assert(serverInfo.securePort.isDefined) intercept[BindException] { socket = new ServerSocket(boundPort) } @@ -227,8 +228,55 @@ class UISuite extends SparkFunSuite { assert(newHeader === null) } + test("http -> https redirect applies to all URIs") { + var serverInfo: ServerInfo = null + try { + val servlet = new HttpServlet() { + override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_OK) + } + } + + def newContext(path: String): ServletContextHandler = { + val ctx = new ServletContextHandler() + ctx.setContextPath(path) + ctx.addServlet(new ServletHolder(servlet), "/*") + ctx + } + + val (conf, sslOptions) = sslEnabledConf() + serverInfo = JettyUtils.startJettyServer( + "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](newContext("/")), conf) + assert(serverInfo.server.getState === "STARTED") + + val testContext = newContext("/test") + serverInfo.addHandler(testContext) + testContext.start() + + val httpPort = serverInfo.boundPort + + val tests = Seq( + ("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND), + ("https", serverInfo.securePort.get, HttpServletResponse.SC_OK)) + + tests.foreach { case (scheme, port, expected) => + val urls = Seq( + s"$scheme://localhost:$port", + s"$scheme://localhost:$port/", + s"$scheme://localhost:$port/test", + s"$scheme://localhost:$port/test/foo") + urls.foreach { url => + val rc = TestUtils.httpResponseCode(new URL(url)) + assert(rc === expected, s"Unexpected status $rc for $url") + } + } + } finally { + stopServer(serverInfo) + } + } + def stopServer(info: ServerInfo): Unit = { - if (info != null && info.server != null) info.server.stop + if (info != null) info.stop() } def closeSocket(socket: ServerSocket): Unit = { From 67df755bc5d1f20e1406eddf0655c4abebb98003 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Jan 2017 16:54:23 -0800 Subject: [PATCH 2/4] Uncomment line I disabled for debugging... --- core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 638ddf83dc0a..5f84267472c3 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -281,7 +281,7 @@ private[spark] object JettyUtils extends Logging { addFilters(handlers, conf) val gzipHandlers = handlers.map { h => - // h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) + h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) val gzipHandler = new GzipHandler gzipHandler.setHandler(h) From eb0fcb792b8130e9cbdf68eb18b15f3f49148d9b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 20 Jan 2017 15:56:12 -0800 Subject: [PATCH 3/4] Make sure the handler collection only contains handler from a successful bind. --- .../scala/org/apache/spark/ui/JettyUtils.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 5f84267472c3..f713619cd7ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -277,7 +277,6 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection addFilters(handlers, conf) val gzipHandlers = handlers.map { h => @@ -297,7 +296,9 @@ private[spark] object JettyUtils extends Logging { pool.setDaemon(true) val server = new Server(pool) - val connectors = new ArrayBuffer[ServerConnector] + val connectors = new ArrayBuffer[ServerConnector]() + val collection = new ContextHandlerCollection + // Create a connector on port currentPort to listen for HTTP requests val httpConnector = new ServerConnector( server, @@ -338,7 +339,6 @@ private[spark] object JettyUtils extends Logging { None } - gzipHandlers.foreach(collection.addHandler) // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. (See SPARK-13776) var minThreads = 1 @@ -350,14 +350,17 @@ private[spark] object JettyUtils extends Logging { // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 } - server.setConnectors(connectors.toArray) pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) val errorHandler = new ErrorHandler() errorHandler.setShowStacks(true) errorHandler.setServer(server) server.addBean(errorHandler) + + gzipHandlers.foreach(collection.addHandler) server.setHandler(collection) + + server.setConnectors(connectors.toArray) try { server.start() ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort) @@ -371,7 +374,8 @@ private[spark] object JettyUtils extends Logging { val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf, serverName) - ServerInfo(server, boundPort, securePort, collection) + ServerInfo(server, boundPort, securePort, + server.getHandler().asInstanceOf[ContextHandlerCollection]) } private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { From 5b65c697ed5f7b0ce137e8b92799bf4b24dad31b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 25 Jan 2017 12:31:26 -0800 Subject: [PATCH 4/4] Tweak test to fail without setVirtualHosts. --- .../test/scala/org/apache/spark/ui/UISuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 5b07161cfe3f..aa67f49185e7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -240,16 +240,17 @@ class UISuite extends SparkFunSuite { def newContext(path: String): ServletContextHandler = { val ctx = new ServletContextHandler() ctx.setContextPath(path) - ctx.addServlet(new ServletHolder(servlet), "/*") + ctx.addServlet(new ServletHolder(servlet), "/root") ctx } val (conf, sslOptions) = sslEnabledConf() - serverInfo = JettyUtils.startJettyServer( - "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](newContext("/")), conf) + serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, + Seq[ServletContextHandler](newContext("/"), newContext("/test1")), + conf) assert(serverInfo.server.getState === "STARTED") - val testContext = newContext("/test") + val testContext = newContext("/test2") serverInfo.addHandler(testContext) testContext.start() @@ -261,10 +262,9 @@ class UISuite extends SparkFunSuite { tests.foreach { case (scheme, port, expected) => val urls = Seq( - s"$scheme://localhost:$port", - s"$scheme://localhost:$port/", - s"$scheme://localhost:$port/test", - s"$scheme://localhost:$port/test/foo") + s"$scheme://localhost:$port/root", + s"$scheme://localhost:$port/test1/root", + s"$scheme://localhost:$port/test2/root") urls.foreach { url => val rc = TestUtils.httpResponseCode(new URL(url)) assert(rc === expected, s"Unexpected status $rc for $url")