+ spark.yarn.clientToAM.port |
+ 0 |
+
+ Port the application master listens on for connections from the client.
+ This port is specified when registering the AM with YARN so that the client can later find out
+ which port to connect to from the application report.
+ |
+
# Important notes
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index e227bff88f71..7a090d60150c 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -21,12 +21,18 @@ import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
import java.net.{Socket, URI, URL}
import java.util.concurrent.{TimeoutException, TimeUnit}
+import javax.crypto.SecretKey
+import javax.crypto.spec.SecretKeySpec
import scala.collection.mutable.HashMap
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal
+import com.google.common.base.Charsets
+import io.netty.buffer.ByteBuf
+import io.netty.buffer.Unpooled
+import io.netty.handler.codec.base64.Base64
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
@@ -41,9 +47,16 @@ import org.apache.spark.deploy.yarn.config._
import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
+import org.apache.spark.network.{BlockDataManager, TransportContext}
+import org.apache.spark.network.client.TransportClientBootstrap
+import org.apache.spark.network.netty.{NettyBlockRpcServer, SparkTransportConf}
+import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap}
import org.apache.spark.rpc._
+import org.apache.spark.rpc.netty.NettyRpcCallContext
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util._
/**
@@ -89,6 +102,7 @@ private[spark] class ApplicationMaster(
@volatile private var reporterThread: Thread = _
@volatile private var allocator: YarnAllocator = _
+ @volatile private var clientToAMPort: Int = _
// A flag to check whether user has initialized spark context
@volatile private var registered = false
@@ -247,7 +261,9 @@ private[spark] class ApplicationMaster(
if (!unregistered) {
// we only want to unregister if we don't want the RM to retry
- if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) {
+ if (finalStatus == FinalApplicationStatus.SUCCEEDED ||
+ finalStatus == FinalApplicationStatus.KILLED ||
+ isLastAttempt) {
unregister(finalStatus, finalMsg)
cleanupStagingDir()
}
@@ -283,6 +299,7 @@ private[spark] class ApplicationMaster(
credentialRenewerThread.start()
credentialRenewerThread.join()
}
+ clientToAMPort = sparkConf.getInt("spark.yarn.clientToAM.port", 0)
if (isClusterMode) {
runDriver(securityMgr)
@@ -402,7 +419,8 @@ private[spark] class ApplicationMaster(
uiAddress,
historyAddress,
securityMgr,
- localResources)
+ localResources,
+ clientToAMPort)
// Initialize the AM endpoint *after* the allocator has been initialized. This ensures
// that when the driver sends an initial executor request (e.g. after an AM restart),
@@ -422,6 +440,35 @@ private[spark] class ApplicationMaster(
YarnSchedulerBackend.ENDPOINT_NAME)
}
+  /**
+   * Start a separate [[RpcEnv]] for client connections and register the client-to-AM
+   * [[RpcEndpoint]] on it.
+   *
+   * As a side effect, `clientToAMPort` is overwritten with the port reported by the address
+   * of the newly created environment.
+   *
+   * @param port port passed to `RpcEnv.create` for the new environment
+   * @param driverRef driver endpoint reference handed to the new [[ClientToAMEndpoint]]
+   * @param securityManager security manager used by the new environment and endpoint
+   * @return a reference to the endpoint registered as `ApplicationMaster.ENDPOINT_NAME`
+   */
+  private def runClientAMEndpoint(
+      port: Int,
+      driverRef: RpcEndpointRef,
+      securityManager: SecurityManager): RpcEndpointRef = {
+    // A fresh conf (not the AM's own sparkConf) carries the token-connection flag.
+    val tokenAuthConf = new SparkConf().set("spark.rpc.connectionUsingTokens", "true")
+    val clientFacingEnv = RpcEnv.create(
+      ApplicationMaster.SYSTEM_NAME, Utils.localHostName(), port, tokenAuthConf, securityManager)
+    // Remember the port the environment reports, for use when registering with YARN.
+    clientToAMPort = clientFacingEnv.address.port
+
+    clientFacingEnv.setupEndpoint(
+      ApplicationMaster.ENDPOINT_NAME,
+      new ClientToAMEndpoint(clientFacingEnv, driverRef, securityManager))
+  }
+
+  /**
+   * [[RpcEndpoint]] that `runClientAMEndpoint` registers under
+   * `ApplicationMaster.ENDPOINT_NAME`. It currently declares no message handlers of its own.
+   *
+   * @param rpcEnv the environment this endpoint is registered on
+   * @param driverRef reference to the driver endpoint (not used in this body)
+   * @param securityManager security manager of the hosting environment (not used in this body)
+   */
+  private[spark] class ClientToAMEndpoint(
+      override val rpcEnv: RpcEnv,
+      driverRef: RpcEndpointRef,
+      securityManager: SecurityManager)
+    extends RpcEndpoint with Logging {
+  }
+
private def runDriver(securityMgr: SecurityManager): Unit = {
addAmIpFilter(None)
userClassThread = startUserApplication()
@@ -438,8 +485,12 @@ private[spark] class ApplicationMaster(
val driverRef = createSchedulerRef(
sc.getConf.get("spark.driver.host"),
sc.getConf.get("spark.driver.port"))
+ val clientToAMSecurityManager = new SecurityManager(sparkConf)
+ runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager)
registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr)
registered = true
+ clientToAMSecurityManager.setSecretKey(Base64.encode(
+ Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8));
} else {
// Sanity check; should never happen in normal operation, since sc should only be null
// if the user app did not create a SparkContext.
@@ -464,10 +515,13 @@ private[spark] class ApplicationMaster(
amCores, true)
val driverRef = waitForSparkDriver()
addAmIpFilter(Some(driverRef))
+ val clientToAMSecurityManager = new SecurityManager(sparkConf)
+ runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager)
registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"),
securityMgr)
registered = true
-
+ clientToAMSecurityManager.setSecretKey(Base64.encode(
+ Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8));
// In client mode the actor will stop the reporter thread.
reporterThread.join()
}
@@ -749,8 +803,18 @@ private[spark] class ApplicationMaster(
}
+/**
+ * Marker trait for ApplicationMaster messages. It is sealed, so every message type must be
+ * declared in this file.
+ */
+sealed trait ApplicationMasterMessage extends Serializable
+
+/** Holder for the concrete [[ApplicationMasterMessage]] types. */
+private[spark] object ApplicationMasterMessages {
+
+  // NOTE(review): carries no fields and no handler for it is visible in this change;
+  // presumably a placeholder -- confirm before relying on it.
+  case class HelloWorld() extends ApplicationMasterMessage
+}
+
object ApplicationMaster extends Logging {
+ val SYSTEM_NAME = "sparkYarnAM"
+ val ENDPOINT_NAME = "clientToAM"
+
// exit codes for different causes, no reason behind the values
private val EXIT_SUCCESS = 0
private val EXIT_UNCAUGHT_EXCEPTION = 10
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index d408ca90a5d1..799f34d19e73 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -18,19 +18,25 @@
package org.apache.spark.deploy.yarn
import java.io.{File, FileOutputStream, IOException, OutputStreamWriter}
-import java.net.{InetAddress, UnknownHostException, URI}
+import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.security.PrivilegedExceptionAction
import java.util.{Locale, Properties, UUID}
+import java.util.concurrent.TimeoutException
import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
+import scala.concurrent.ExecutionContext
+import scala.util.control.Breaks._
import scala.util.control.NonFatal
+import com.google.common.base.Charsets.UTF_8
import com.google.common.base.Objects
import com.google.common.io.Files
+import io.netty.buffer.Unpooled
+import io.netty.handler.codec.base64.Base64
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
import org.apache.hadoop.fs.permission.FsPermission
@@ -45,7 +51,7 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication}
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
-import org.apache.hadoop.yarn.util.Records
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
@@ -54,7 +60,8 @@ import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
-import org.apache.spark.util.{CallerContext, Utils}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.util.{CallerContext, SparkExitCode, ThreadUtils, Utils}
private[spark] class Client(
val args: ClientArguments,
@@ -1149,6 +1156,43 @@ private[spark] class Client(
}
}
+  /**
+   * Open an RPC connection from this client to the ApplicationMaster of a running application.
+   *
+   * The AM host and RPC port are taken from the YARN application report. When Hadoop security
+   * is enabled, the client-to-AM token in the report supplies the SASL user (the token
+   * identifier) and the secret key (the token password); both are Base64-encoded and installed
+   * on `securityManager` before the client-side [[RpcEnv]] is created.
+   *
+   * @param appId the application whose AM should be contacted
+   * @param securityManager receives the token credentials and secures the new [[RpcEnv]]
+   * @return a reference to the AM endpoint named `ApplicationMaster.ENDPOINT_NAME`
+   * @throws SparkException if the report has no AM host, the application is not RUNNING, or
+   *                        security is enabled but the report carries no client-to-AM token
+   */
+  private def setupAMConnection(
+      appId: ApplicationId,
+      securityManager: SecurityManager): RpcEndpointRef = {
+    // Same encoding for both credentials: Base64 over the raw bytes, read back as UTF-8 text.
+    def encodeBase64(bytes: Array[Byte]): String =
+      Base64.encode(Unpooled.wrappedBuffer(bytes)).toString(UTF_8)
+
+    val report = getApplicationReport(appId)
+    // A null or empty host means there is nothing to connect to.
+    val amHost = Option(report.getHost).filter(_.nonEmpty).getOrElse {
+      throw new SparkException(s"AM for $appId not assigned or dont have view ACL for it")
+    }
+    if (report.getYarnApplicationState != YarnApplicationState.RUNNING) {
+      throw new SparkException(s"Application $appId needs to be in RUNNING")
+    }
+    val amPort = report.getRpcPort
+
+    if (UserGroupInformation.isSecurityEnabled()) {
+      val serviceAddr = new InetSocketAddress(amHost, amPort)
+
+      // Fail with a descriptive error instead of a NullPointerException inside the converter
+      // when the report does not carry a token.
+      val clientToAMToken = Option(report.getClientToAMToken).getOrElse {
+        throw new SparkException(s"Client to AM token is not available for application $appId")
+      }
+      val token = ConverterUtils.convertFromYarn(clientToAMToken, serviceAddr)
+
+      // Token identifier becomes the SASL user, token password becomes the secret key.
+      securityManager.setSaslUser(encodeBase64(token.getIdentifier))
+      securityManager.setSecretKey(encodeBase64(token.getPassword))
+    }
+
+    // NOTE(review): this sets the flag on the client's shared `sparkConf`, and the RpcEnv
+    // created below is not returned, so callers cannot shut it down -- confirm both are intended.
+    sparkConf.set("spark.rpc.connectionUsingTokens", "true")
+    val rpcEnv =
+      RpcEnv.create("yarnDriverClient", Utils.localHostName(), 0, sparkConf, securityManager)
+    rpcEnv.setupEndpointRef(
+      RpcAddress(amHost, amPort),
+      ApplicationMaster.ENDPOINT_NAME)
+  }
+
private def findPySparkArchives(): Seq[String] = {
sys.env.get("PYSPARK_ARCHIVES_PATH")
.map(_.split(",").toSeq)
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 72f4d273ab53..68c8134d5527 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -17,6 +17,8 @@
package org.apache.spark.deploy.yarn
+import java.nio.ByteBuffer
+
import scala.collection.JavaConverters._
import org.apache.hadoop.yarn.api.records._
@@ -39,6 +41,7 @@ private[spark] class YarnRMClient extends Logging {
private var amClient: AMRMClient[ContainerRequest] = _
private var uiHistoryAddress: String = _
private var registered: Boolean = false
+ private var masterkey: ByteBuffer = _
/**
* Registers the application master with the RM.
@@ -58,7 +61,8 @@ private[spark] class YarnRMClient extends Logging {
uiAddress: Option[String],
uiHistoryAddress: String,
securityMgr: SecurityManager,
- localResources: Map[String, LocalResource]
+ localResources: Map[String, LocalResource],
+ port: Int = 0
): YarnAllocator = {
amClient = AMRMClient.createAMRMClient()
amClient.init(conf)
@@ -71,8 +75,9 @@ private[spark] class YarnRMClient extends Logging {
logInfo("Registering the ApplicationMaster")
synchronized {
- amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl)
+ var response = amClient.registerApplicationMaster(Utils.localHostName(), port, trackingUrl)
registered = true
+ masterkey = response.getClientToAMTokenMasterKey()
}
new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
localResources, new SparkRackResolver())
@@ -89,6 +94,9 @@ private[spark] class YarnRMClient extends Logging {
amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
}
}
+  /**
+   * Returns the client-to-AM token master key that YARN handed back when the AM registered.
+   * The value is `null` until `register` has stored the registration response's key.
+   */
+  def getMasterKey(): ByteBuffer = {
+    masterkey
+  }
+
/** Returns the attempt ID. */
def getAttemptId(): ApplicationAttemptId = {
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e..3df66213b29f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -23,6 +23,8 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
import org.apache.spark.SparkContext