Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package org.apache.spark.deploy.rest

import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{HttpURLConnection, SocketException, URL}
import javax.servlet.http.HttpServletResponse

import scala.io.Source

import com.fasterxml.jackson.databind.JsonMappingException
import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.base.Charsets

import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
Expand Down Expand Up @@ -155,10 +156,21 @@ private[spark] class StandaloneRestClient extends Logging {
/**
* Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
* If the response represents an error, report the embedded message to the user.
* Exposed for testing.
*/
private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
try {
val responseJson = Source.fromInputStream(connection.getInputStream).mkString
val dataStream =
if (connection.getResponseCode == HttpServletResponse.SC_OK) {
connection.getInputStream
} else {
connection.getErrorStream
}
// If the server threw an exception while writing a response, it will not have a body
if (dataStream == null) {
throw new SubmitRestProtocolException("Server returned empty body")
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes (2) on SPARK-5760

val responseJson = Source.fromInputStream(dataStream).mkString
logDebug(s"Response from the server:\n$responseJson")
val response = SubmitRestProtocolMessage.fromJson(responseJson)
response.validate()
Expand All @@ -177,7 +189,7 @@ private[spark] class StandaloneRestClient extends Logging {
case unreachable @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException(
s"Unable to connect to server ${connection.getURL}", unreachable)
case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) =>
case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
throw new SubmitRestProtocolException(
"Malformed response received from server", malformed)
}
Expand Down Expand Up @@ -284,7 +296,27 @@ private[spark] object StandaloneRestClient {
val REPORT_DRIVER_STATUS_MAX_TRIES = 10
val PROTOCOL_VERSION = "v1"

/** Submit an application, assuming Spark parameters are specified through system properties. */
/**
* Submit an application, assuming Spark parameters are specified through the given config.
* This is abstracted to its own method for testing purposes.
*/
private[rest] def run(
appResource: String,
mainClass: String,
appArgs: Array[String],
conf: SparkConf,
env: Map[String, String] = sys.env): SubmitRestProtocolResponse = {
val master = conf.getOption("spark.master").getOrElse {
throw new IllegalArgumentException("'spark.master' must be set.")
}
val sparkProperties = conf.getAll.toMap
val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") }
val client = new StandaloneRestClient
val submitRequest = client.constructSubmitRequest(
appResource, mainClass, appArgs, sparkProperties, environmentVariables)
client.createSubmission(master, submitRequest)
}

def main(args: Array[String]): Unit = {
if (args.size < 2) {
sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]")
Expand All @@ -294,14 +326,6 @@ private[spark] object StandaloneRestClient {
val mainClass = args(1)
val appArgs = args.slice(2, args.size)
val conf = new SparkConf
val master = conf.getOption("spark.master").getOrElse {
throw new IllegalArgumentException("'spark.master' must be set.")
}
val sparkProperties = conf.getAll.toMap
val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }
val client = new StandaloneRestClient
val submitRequest = client.constructSubmitRequest(
appResource, mainClass, appArgs, sparkProperties, environmentVariables)
client.createSubmission(master, submitRequest)
run(appResource, mainClass, appArgs, conf)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

package org.apache.spark.deploy.rest

import java.io.{DataOutputStream, File}
import java.io.File
import java.net.InetSocketAddress
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.io.Source

import akka.actor.ActorRef
import com.fasterxml.jackson.databind.JsonMappingException
import com.google.common.base.Charsets
import com.fasterxml.jackson.core.JsonProcessingException
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
Expand Down Expand Up @@ -70,14 +69,14 @@ private[spark] class StandaloneRestServer(
import StandaloneRestServer._

private var _server: Option[Server] = None
private val baseContext = s"/$PROTOCOL_VERSION/submissions"

// A mapping from servlets to the URL prefixes they are responsible for
private val servletToContext = Map[StandaloneRestServlet, String](
new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
new ErrorServlet -> "/*" // default handler

// A mapping from URL prefixes to servlets that serve them. Exposed for testing.
protected val baseContext = s"/$PROTOCOL_VERSION/submissions"
protected val contextToServlet = Map[String, StandaloneRestServlet](
s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf),
s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf),
s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf),
"/*" -> new ErrorServlet // default handler
)

/** Start the server and return the bound port. */
Expand All @@ -99,7 +98,7 @@ private[spark] class StandaloneRestServer(
server.setThreadPool(threadPool)
val mainHandler = new ServletContextHandler
mainHandler.setContextPath("/")
servletToContext.foreach { case (servlet, prefix) =>
contextToServlet.foreach { case (prefix, servlet) =>
mainHandler.addServlet(new ServletHolder(servlet), prefix)
}
server.setHandler(mainHandler)
Expand All @@ -113,28 +112,15 @@ private[spark] class StandaloneRestServer(
}
}

private object StandaloneRestServer {
private[rest] object StandaloneRestServer {
val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
val SC_UNKNOWN_PROTOCOL_VERSION = 468
}

/**
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
*/
private abstract class StandaloneRestServlet extends HttpServlet with Logging {

/** Service a request. If an exception is thrown in the process, indicate server error. */
protected override def service(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
super.service(request, response)
} catch {
case e: Exception =>
logError("Exception while handling request", e)
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
}
}
private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging {

/**
* Serialize the given response message to JSON and send it through the response servlet.
Expand All @@ -146,11 +132,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
val message = validateResponse(responseMessage, responseServlet)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
val content = message.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(responseServlet.getOutputStream)
out.write(content)
out.close()
responseServlet.getWriter.write(message.toJson)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes (1) on SPARK-5760

}

/**
Expand Down Expand Up @@ -186,6 +168,19 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
e
}

/**
* Parse a submission ID from the relative path, assuming it is the first part of the path.
* For instance, we expect the path to take the form /[submission ID]/maybe/something/else.
* The returned submission ID cannot be empty. If the path is unexpected, return None.
*/
protected def parseSubmissionId(path: String): Option[String] = {
if (path == null || path.isEmpty) {
None
} else {
path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty)
}
}

/**
* Validate the response to ensure that it is correctly constructed.
*
Expand All @@ -209,7 +204,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {

/**
Expand All @@ -219,18 +214,15 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doPost(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleKill(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in kill request.")
}
val submissionId = parseSubmissionId(request.getPathInfo)
val responseMessage = submissionId.map(handleKill).getOrElse {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in kill request.")
}
sendResponse(responseMessage, response)
}

private def handleKill(submissionId: String): KillSubmissionResponse = {
protected def handleKill(submissionId: String): KillSubmissionResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
Expand All @@ -246,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {

/**
Expand All @@ -256,18 +248,15 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doGet(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleStatus(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in status request.")
}
val submissionId = parseSubmissionId(request.getPathInfo)
val responseMessage = submissionId.map(handleStatus).getOrElse {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in status request.")
}
sendResponse(responseMessage, response)
}

private def handleStatus(submissionId: String): SubmissionStatusResponse = {
protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
Expand All @@ -287,7 +276,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
private class SubmitRequestServlet(
private[rest] class SubmitRequestServlet(
masterActor: ActorRef,
masterUrl: String,
conf: SparkConf)
Expand All @@ -313,7 +302,7 @@ private class SubmitRequestServlet(
handleSubmit(requestMessageJson, requestMessage, responseServlet)
} catch {
// The client failed to provide a valid JSON, so this is not our fault
case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Malformed request: " + formatException(e))
}
Expand Down Expand Up @@ -413,7 +402,7 @@ private class ErrorServlet extends StandaloneRestServlet {
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val path = request.getPathInfo
val parts = path.stripPrefix("/").split("/").toSeq
val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes (3) on SPARK-5760

var versionMismatch = false
var msg =
parts match {
Expand All @@ -423,10 +412,10 @@ private class ErrorServlet extends StandaloneRestServlet {
case `serverVersion` :: Nil =>
// http://host:port/correct-version
"Missing the /submissions prefix."
case `serverVersion` :: "submissions" :: Nil =>
// http://host:port/correct-version/submissions
case `serverVersion` :: "submissions" :: tail =>
// http://host:port/correct-version/submissions/*
"Missing an action: please specify one of /create, /kill, or /status."
case unknownVersion :: _ =>
case unknownVersion :: tail =>
// http://host:port/unknown-version/*
versionMismatch = true
s"Unknown protocol version '$unknownVersion'."
Expand Down
Loading