Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions core/src/main/scala/org/apache/spark/SSLOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging
*
* @param enabled enables or disables SSL; if it is set to false, the rest of the
* settings are disregarded
* @param port the port where to bind the SSL server; if not defined, it will be
* based on the non-SSL port for the same service.
* @param keyStore a path to the key-store file
* @param keyStorePassword a password to access the key-store file
* @param keyPassword a password to access the private key in the key-store
Expand All @@ -47,6 +49,7 @@ import org.apache.spark.internal.Logging
*/
private[spark] case class SSLOptions(
enabled: Boolean = false,
port: Option[Int] = None,
keyStore: Option[File] = None,
keyStorePassword: Option[String] = None,
keyPassword: Option[String] = None,
Expand Down Expand Up @@ -164,6 +167,11 @@ private[spark] object SSLOptions extends Logging {
def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = {
val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled))

val port = conf.getOption(s"$ns.port").map(_.toInt)
port.foreach { p =>
require(p >= 0, "Port number must be a non-negative value.")
}

val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_))
.orElse(defaults.flatMap(_.keyStore))

Expand Down Expand Up @@ -198,6 +206,7 @@ private[spark] object SSLOptions extends Logging {

new SSLOptions(
enabled,
port,
keyStore,
keyStorePassword,
keyPassword,
Expand Down
187 changes: 102 additions & 85 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.xml.Node

import org.eclipse.jetty.client.api.Response
import org.eclipse.jetty.proxy.ProxyServlet
import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector}
import org.eclipse.jetty.server._
import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.servlet._
import org.eclipse.jetty.servlets.gzip.GzipHandler
Expand Down Expand Up @@ -279,109 +279,125 @@ private[spark] object JettyUtils extends Logging {

addFilters(handlers, conf)

val gzipHandlers = handlers.map { h =>
h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))

val gzipHandler = new GzipHandler
gzipHandler.setHandler(h)
gzipHandler
// Start the server first, with no connectors.
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

// Bind to the given port, or throw a java.net.BindException if the port is occupied
def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
pool.setName(serverName)
}
pool.setDaemon(true)

val server = new Server(pool)
val connectors = new ArrayBuffer[ServerConnector]()
val collection = new ContextHandlerCollection

// Create a connector on port currentPort to listen for HTTP requests
val httpConnector = new ServerConnector(
server,
null,
// Call this full constructor to set this, which forces daemon threads:
new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true),
null,
-1,
-1,
new HttpConnectionFactory())
httpConnector.setPort(currentPort)
connectors += httpConnector

val httpsConnector = sslOptions.createJettySslContextFactory() match {
case Some(factory) =>
// If the new port wraps around, do not try a privileged port.
val securePort =
if (currentPort != 0) {
(currentPort + 400 - 1024) % (65536 - 1024) + 1024
} else {
0
}
val scheme = "https"
// Create a connector on port securePort to listen for HTTPS requests
val connector = new ServerConnector(server, factory)
connector.setPort(securePort)
connector.setName(SPARK_CONNECTOR_NAME)
connectors += connector

// redirect the HTTP requests to HTTPS port
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
collection.addHandler(createRedirectHttpsHandler(securePort, scheme))
Some(connector)
val server = new Server(pool)

case None =>
// No SSL, so the HTTP connector becomes the official one where all contexts bind.
httpConnector.setName(SPARK_CONNECTOR_NAME)
None
}
val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

val collection = new ContextHandlerCollection
server.setHandler(collection)

// Executor used to create daemon threads for the Jetty connectors.
val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true)

try {
server.start()

// As each acceptor and each selector will use one thread, the number of threads should at
// least be the number of acceptors and selectors plus 1. (See SPARK-13776)
var minThreads = 1
connectors.foreach { connector =>

def newConnector(
connectionFactories: Array[ConnectionFactory],
port: Int): (ServerConnector, Int) = {
val connector = new ServerConnector(
server,
null,
serverExecutor,
null,
-1,
-1,
connectionFactories: _*)
connector.setPort(port)
connector.start()

// Currently we only use "SelectChannelConnector"
// Limit the max acceptor number to 8 so that we don't waste a lot of threads
connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8))
connector.setHost(hostName)
// The number of selectors always equals to the number of acceptors
minThreads += connector.getAcceptors * 2

(connector, connector.getLocalPort())
}
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

val errorHandler = new ErrorHandler()
errorHandler.setShowStacks(true)
errorHandler.setServer(server)
server.addBean(errorHandler)

gzipHandlers.foreach(collection.addHandler)
server.setHandler(collection)

server.setConnectors(connectors.toArray)
try {
server.start()
((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
} catch {
case e: Exception =>
server.stop()
pool.stop()
throw e
// If SSL is configured, create the secure connector first.
val securePort = sslOptions.createJettySslContextFactory().map { factory =>
val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0)
val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName
val connectionFactories = AbstractConnectionFactory.getFactories(factory,
new HttpConnectionFactory())

def sslConnect(currentPort: Int): (ServerConnector, Int) = {
newConnector(connectionFactories, currentPort)
}

val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort,
sslConnect, conf, secureServerName)
connector.setName(SPARK_CONNECTOR_NAME)
server.addConnector(connector)
boundPort
}
}

val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
serverName)
ServerInfo(server, boundPort, securePort,
server.getHandler().asInstanceOf[ContextHandlerCollection])
// Bind the HTTP port.
def httpConnect(currentPort: Int): (ServerConnector, Int) = {
newConnector(Array(new HttpConnectionFactory()), currentPort)
}

val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect,
conf, serverName)

// If SSL is configured, then configure redirection in the HTTP connector.
securePort match {
case Some(p) =>
httpConnector.setName(REDIRECT_CONNECTOR_NAME)
val redirector = createRedirectHttpsHandler(p, "https")
collection.addHandler(redirector)
redirector.start()

case None =>
httpConnector.setName(SPARK_CONNECTOR_NAME)
}

server.addConnector(httpConnector)

// Add all the known handlers now that connectors are configured.
handlers.foreach { h =>
h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME))
val gzipHandler = new GzipHandler()
gzipHandler.setHandler(h)
collection.addHandler(gzipHandler)
gzipHandler.start()
}

pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
ServerInfo(server, httpPort, securePort, collection)
} catch {
case e: Exception =>
server.stop()
if (serverExecutor.isStarted()) {
serverExecutor.stop()
}
if (pool.isStarted()) {
pool.stop()
}
throw e
}
}

private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch.

val redirectHandler: ContextHandler = new ContextHandler
redirectHandler.setContextPath("/")
redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME))
redirectHandler.setHandler(new AbstractHandler {
override def handle(
target: String,
Expand All @@ -394,8 +410,7 @@ private[spark] object JettyUtils extends Logging {
val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort,
baseRequest.getRequestURI, baseRequest.getQueryString)
response.setContentLength(0)
response.encodeRedirectURL(httpsURI)
response.sendRedirect(httpsURI)
response.sendRedirect(response.encodeRedirectURL(httpsURI))
baseRequest.setHandled(true)
}
})
Expand Down Expand Up @@ -456,6 +471,8 @@ private[spark] object JettyUtils extends Logging {
new URI(scheme, authority, path, query, null).toString
}

def toVirtualHosts(connectors: String*): Array[String] = connectors.map("@" + _).toArray

}

private[spark] case class ServerInfo(
Expand All @@ -465,7 +482,7 @@ private[spark] case class ServerInfo(
private val rootHandler: ContextHandlerCollection) {

def addHandler(handler: ContextHandler): Unit = {
handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME))
handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME))
rootHandler.addHandler(handler)
if (!handler.isStarted()) {
handler.start()
Expand Down
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,14 @@ private[spark] object Utils extends Logging {
}
}

/**
* Returns the user port to try when trying to bind a service. Handles wrapping and skipping
* privileged ports.
*/
def userPort(base: Int, offset: Int): Int = {
(base + offset - 1024) % (65536 - 1024) + 1024
}

/**
* Attempt to start a service on the given port, or fail after a number of attempts.
* Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0).
Expand Down Expand Up @@ -2229,8 +2237,7 @@ private[spark] object Utils extends Logging {
val tryPort = if (startPort == 0) {
startPort
} else {
// If the new port wraps around, do not try a privilege port
((startPort + offset - 1024) % (65536 - 1024)) + 1024
userPort(startPort, offset)
}
try {
val (service, port) = startService(tryPort)
Expand Down
2 changes: 2 additions & 0 deletions core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val conf = new SparkConf
conf.set("spark.ssl.enabled", "true")
conf.set("spark.ssl.ui.enabled", "false")
conf.set("spark.ssl.ui.port", "4242")
conf.set("spark.ssl.keyStore", keyStorePath)
conf.set("spark.ssl.keyStorePassword", "password")
conf.set("spark.ssl.ui.keyStorePassword", "12345")
Expand All @@ -118,6 +119,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll {
val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts))

assert(opts.enabled === false)
assert(opts.port === Some(4242))
assert(opts.trustStore.isDefined === true)
assert(opts.trustStore.get.getName === "truststore")
assert(opts.trustStore.get.getAbsolutePath === trustStorePath)
Expand Down
28 changes: 27 additions & 1 deletion core/src/test/scala/org/apache/spark/ui/UISuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.LocalSparkContext._
import org.apache.spark.util.Utils

class UISuite extends SparkFunSuite {

Expand All @@ -52,13 +53,16 @@ class UISuite extends SparkFunSuite {
(conf, new SecurityManager(conf).getSSLOptions("ui"))
}

private def sslEnabledConf(): (SparkConf, SSLOptions) = {
private def sslEnabledConf(sslPort: Option[Int] = None): (SparkConf, SSLOptions) = {
val keyStoreFilePath = getTestResourcePath("spark.keystore")
val conf = new SparkConf()
.set("spark.ssl.ui.enabled", "true")
.set("spark.ssl.ui.keyStore", keyStoreFilePath)
.set("spark.ssl.ui.keyStorePassword", "123456")
.set("spark.ssl.ui.keyPassword", "123456")
sslPort.foreach { p =>
conf.set("spark.ssl.ui.port", p.toString)
}
(conf, new SecurityManager(conf).getSSLOptions("ui"))
}

Expand Down Expand Up @@ -275,6 +279,28 @@ class UISuite extends SparkFunSuite {
}
}

test("specify both http and https ports separately") {
var socket: ServerSocket = null
var serverInfo: ServerInfo = null
try {
socket = new ServerSocket(0)

// Make sure the SSL port lies way outside the "http + 400" range used as the default.
val baseSslPort = Utils.userPort(socket.getLocalPort(), 10000)
val (conf, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort))

serverInfo = JettyUtils.startJettyServer("0.0.0.0", socket.getLocalPort() + 1,
sslOptions, Seq[ServletContextHandler](), conf, "server1")

val notAllowed = Utils.userPort(serverInfo.boundPort, 400)
assert(serverInfo.securePort.isDefined)
assert(serverInfo.securePort.get != Utils.userPort(serverInfo.boundPort, 400))
} finally {
stopServer(serverInfo)
closeSocket(socket)
}
}

def stopServer(info: ServerInfo): Unit = {
if (info != null) info.stop()
}
Expand Down
14 changes: 14 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,20 @@ Apart from these, the following properties are also available, and may be useful
Configuration</a> for details on hierarchical SSL configuration for services.
</td>
</tr>
<tr>
<td><code>spark.ssl.[namespace].port</code></td>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about saying spark.ssl.port instead to be consistent with any other property related to SSL?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intentional. "spark.ssl.port" doesn't make that much sense if you think about it; you want things like the master UI and history server UI to have different, well known ports, so having this shared config key here doesn't make a lot of sense. For the other configs, such as algorithms and keystore locations, sharing configuration is ok.

<td>None</td>
<td>
The port where the SSL service will listen on.

<br />The port must be defined within a namespace configuration; see
<a href="security.html#ssl-configuration">SSL Configuration</a> for the available
namespaces.

<br />When not set, the SSL port will be derived from the non-SSL port for the
same service. A value of "0" will make the service bind to an ephemeral port.
</td>
</tr>
<tr>
<td><code>spark.ssl.enabledAlgorithms</code></td>
<td>Empty</td>
Expand Down
Loading