diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index ccd8504ba0d7..56435a706bf5 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -61,6 +61,15 @@ jackson-annotations + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-yarn-common + + org.slf4j diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bba..d8697287d285 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -75,6 +75,7 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + @Nullable private String clientUser; private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { @@ -114,6 +115,25 @@ public void setClientId(String id) { this.clientId = id; } + /** + * Returns the user name used by the client to authenticate itself when authentication is enabled. + * + * @return The client User Name, or null if authentication is disabled. + */ + public String getClientUser() { + return clientUser; + } + + /** + * Sets the authenticated client's user name. This is meant to be used by the authentication layer. + * + * Trying to set a different client User Name after it's been set will result in an exception. + */ + public void setClientUser(String user) { + Preconditions.checkState(clientUser == null, "Client User Name has already been set."); + this.clientUser = user; + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. 
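The clientUser field added to TransportClient above is what server-side code reads back once the authentication bootstrap has called setClientUser(). As a rough, hypothetical illustration (UserAwareRpcHandler and allowedUser are invented names, not part of this change), a custom RpcHandler could consult it to authorize requests on an authenticated channel:

import java.nio.ByteBuffer

import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

// Hypothetical handler: once the auth layer has set the client user on the channel,
// later RPCs can be checked against it. getClientUser() is null when auth is disabled.
class UserAwareRpcHandler(allowedUser: String) extends RpcHandler {
  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    val user = client.getClientUser
    if (user != null && user != allowedUser) {
      callback.onFailure(new SecurityException(s"unexpected client user: $user"))
    } else {
      callback.onSuccess(ByteBuffer.allocate(0))
    }
  }

  override def getStreamManager: StreamManager = streamManager
}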
* diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 3c263783a610..e80f84f30135 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -95,8 +95,9 @@ public void doBootstrap(TransportClient client, Channel channel) { private void doSparkAuth(TransportClient client, Channel channel) throws GeneralSecurityException, IOException { + String user = secretKeyHolder.getSaslUser(appId); String secretKey = secretKeyHolder.getSecretKey(appId); - try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) { + try (AuthEngine engine = new AuthEngine(appId, user, secretKey, conf)) { ClientChallenge challenge = engine.challenge(); ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); challenge.encode(challengeData); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index b769ebeba36c..10e8e8570a2f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -39,11 +39,17 @@ import org.apache.commons.crypto.cipher.CryptoCipherFactory; import org.apache.commons.crypto.random.CryptoRandom; import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.util.HadoopSecurityUtils.decodeMasterKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getClientToAMSecretKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getIdentifier; + /** * A helper class for abstracting authentication and key negotiation details. This is used by * both client and server sides, since the operations are basically the same. @@ -54,11 +60,13 @@ class AuthEngine implements Closeable { private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); private final byte[] appId; - private final char[] secret; + private final byte[] user; + private char[] secret; private final TransportConf conf; private final Properties cryptoConf; private final CryptoRandom random; + private String clientUser; private byte[] authNonce; @VisibleForTesting @@ -69,13 +77,25 @@ class AuthEngine implements Closeable { private CryptoCipher decryptor; AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { + this(appId, "",secret, conf); + } + + AuthEngine(String appId, String user, String secret, TransportConf conf) throws GeneralSecurityException { this.appId = appId.getBytes(UTF_8); + this.user = user.getBytes(UTF_8); this.conf = conf; this.cryptoConf = conf.cryptoConf(); this.secret = secret.toCharArray(); this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); } + /** + * Returns the user name of the client. + */ + public String getClientUserName() { + return clientUser; + } + /** * Create the client challenge. 
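AuthEngine now threads the user name through the handshake. The snippet below is only a sketch of how the two engines pair up, in the spirit of the existing AuthEngineSuite; it has to live in the org.apache.spark.network.crypto package because AuthEngine is package-private, and the app id, user, and secret values are made up.

package org.apache.spark.network.crypto

import org.apache.spark.network.util.{MapConfigProvider, TransportConf}

// Sketch: the client engine embeds the user in ClientChallenge, and the server engine
// records it while responding, exposing it afterwards through getClientUserName().
object AuthHandshakeSketch {
  def main(args: Array[String]): Unit = {
    val conf = new TransportConf("rpc", MapConfigProvider.EMPTY)
    val client = new AuthEngine("app-1", "alice", "shared-secret", conf)
    val server = new AuthEngine("app-1", "alice", "shared-secret", conf)
    try {
      val challenge = client.challenge()       // now carries the "alice" user field
      val response = server.respond(challenge) // derives the keys and records the user
      assert(server.getClientUserName == "alice")
    } finally {
      client.close()
      server.close()
    }
  }
}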
* @@ -89,6 +109,7 @@ ClientChallenge challenge() throws GeneralSecurityException, IOException { this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); return new ClientChallenge(new String(appId, UTF_8), + new String(user, UTF_8), conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), conf.cipherTransformation(), @@ -106,9 +127,22 @@ ClientChallenge challenge() throws GeneralSecurityException, IOException { */ ServerResponse respond(ClientChallenge clientChallenge) throws GeneralSecurityException, IOException { + SecretKeySpec authKey; + if (conf.isConnectionUsingTokens()) { + // Create a secret from client's token identifier and AM's master key. + ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(null, + decodeMasterKey(new String(secret))); + ClientToAMTokenIdentifier identifier = getIdentifier(clientChallenge.user); + secret = getClientToAMSecretKey(identifier, secretManager); + + clientUser = identifier.getUser().getShortUserName(); + } else { + clientUser = clientChallenge.user; + } + + authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, clientChallenge.nonce, + clientChallenge.keyLength); - SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, - clientChallenge.nonce, clientChallenge.keyLength); initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); @@ -119,6 +153,7 @@ ServerResponse respond(ClientChallenge clientChallenge) SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, sessionNonce, clientChallenge.keyLength); + this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, inputIv, outputIv); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081b..b50e9ff1a10a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -114,12 +114,14 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // Here we have the client challenge, so perform the new auth protocol and set up the channel. 
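The token branch added to respond() above deserves a spelled-out example: when spark.rpc.connectionUsingTokens is set, the configured secret is really the Base64-encoded ClientToAMTokenMasterKey, and the challenge's user field carries the Base64-encoded ClientToAMTokenIdentifier. Pulled out of context as a sketch (encodedMasterKey and encodedIdentifier are placeholder parameters), the derivation amounts to:

import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager

import org.apache.spark.network.util.HadoopSecurityUtils.{decodeMasterKey, getClientToAMSecretKey, getIdentifier}

// Rebuild the shared secret the same way the YARN AM does: decode the master key,
// reconstruct the client's token identifier, and let the secret manager derive the
// per-client password. The short user name becomes TransportClient.clientUser.
object TokenSecretSketch {
  def deriveTokenSecret(
      encodedMasterKey: String,
      encodedIdentifier: String): (String, Array[Char]) = {
    val secretManager =
      new ClientToAMTokenSecretManager(null, decodeMasterKey(encodedMasterKey))
    val identifier = getIdentifier(encodedIdentifier)
    val sharedSecret = getClientToAMSecretKey(identifier, secretManager)
    (identifier.getUser.getShortUserName, sharedSecret)
  }
}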
AuthEngine engine = null; try { + String user = secretKeyHolder.getSaslUser(challenge.appId); String secret = secretKeyHolder.getSecretKey(challenge.appId); Preconditions.checkState(secret != null, "Trying to authenticate non-registered app %s.", challenge.appId); LOG.debug("Authenticating challenge for app {}.", challenge.appId); - engine = new AuthEngine(challenge.appId, secret, conf); + engine = new AuthEngine(challenge.appId, user, secret, conf); ServerResponse response = engine.respond(challenge); + client.setClientUser(engine.getClientUserName()); ByteBuf responseData = Unpooled.buffer(response.encodedLength()); response.encode(responseData); callback.onSuccess(responseData.nioBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java index 819b8a7efbdb..a6fff5e4b2a7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -35,6 +35,7 @@ public class ClientChallenge implements Encodable { private static final byte TAG_BYTE = (byte) 0xFA; public final String appId; + public final String user; public final String kdf; public final int iterations; public final String cipher; @@ -42,8 +43,19 @@ public class ClientChallenge implements Encodable { public final byte[] nonce; public final byte[] challenge; + public ClientChallenge( + String appId, + String kdf, + int iterations, + String cipher, + int keyLength, + byte[] nonce, + byte[] challenge) { + this(appId, "", kdf, iterations, cipher, keyLength, nonce, challenge); + } public ClientChallenge( String appId, + String user, String kdf, int iterations, String cipher, @@ -51,6 +63,7 @@ public ClientChallenge( byte[] nonce, byte[] challenge) { this.appId = appId; + this.user = user; this.kdf = kdf; this.iterations = iterations; this.cipher = cipher; @@ -63,6 +76,7 @@ public ClientChallenge( public int encodedLength() { return 1 + 4 + 4 + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(user) + Encoders.Strings.encodedLength(kdf) + Encoders.Strings.encodedLength(cipher) + Encoders.ByteArrays.encodedLength(nonce) + @@ -73,6 +87,7 @@ public int encodedLength() { public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, user); Encoders.Strings.encode(buf, kdf); buf.writeInt(iterations); Encoders.Strings.encode(buf, cipher); @@ -89,6 +104,7 @@ public static ClientChallenge decodeMessage(ByteBuffer buffer) { } return new ClientChallenge( + Encoders.Strings.decode(buf), Encoders.Strings.decode(buf), Encoders.Strings.decode(buf), buf.readInt(), diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 647813772294..5b95001b0f07 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -57,7 +57,7 @@ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec */ @Override public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); + SparkSaslClient saslClient = new 
SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption(), conf); try { byte[] payload = saslClient.firstToken(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318ad..721bd3089b51 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -95,7 +95,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // First message in the handshake, setup the necessary state. client.setClientId(saslMessage.appId); saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); + conf.saslServerAlwaysEncrypt(), conf); } byte[] response; @@ -114,6 +114,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { + client.setClientUser(saslServer.getUserName()); if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { logger.debug("SASL authentication successful for channel {}", client); complete(true); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index b6256debb8e3..5d984a77885d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -35,8 +35,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.util.TransportConf; + import static org.apache.spark.network.sasl.SparkSaslServer.*; + /** * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the * initial state to the "authenticated" state. This client initializes the protocol via a @@ -48,12 +51,25 @@ public class SparkSaslClient implements SaslEncryptionBackend { private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; private final String expectedQop; + private TransportConf conf; private SaslClient saslClient; - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { + public SparkSaslClient( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt) { + this(secretKeyId,secretKeyHolder,alwaysEncrypt, null); + } + + public SparkSaslClient( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean encrypt, + TransportConf conf) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; this.expectedQop = encrypt ? 
QOP_AUTH_CONF : QOP_AUTH; + this.conf = conf; Map saslProps = ImmutableMap.builder() .put(Sasl.QOP, expectedQop) @@ -131,11 +147,23 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback if (callback instanceof NameCallback) { logger.trace("SASL client callback: setting username"); NameCallback nc = (NameCallback) callback; - nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + if (conf != null && conf.isConnectionUsingTokens()) { + // Token Identifier is already encoded + nc.setName(secretKeyHolder.getSaslUser(secretKeyId)); + } else { + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } + } else if (callback instanceof PasswordCallback) { logger.trace("SASL client callback: setting password"); PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + if (conf != null && conf.isConnectionUsingTokens()) { + // Token Identifier is already encoded + pc.setPassword(secretKeyHolder.getSecretKey(secretKeyId).toCharArray()); + } else { + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + + } } else if (callback instanceof RealmCallback) { logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 00f3e83dbc8b..f836bfbf9e91 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -40,6 +40,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; + +import static org.apache.spark.network.util.HadoopSecurityUtils.decodeMasterKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getClientToAMSecretKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getIdentifier; + +import org.apache.spark.network.util.TransportConf; + /** * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the * initial state to the "authenticated" state. (It is not a server in the sense of accepting @@ -73,14 +82,25 @@ public class SparkSaslServer implements SaslEncryptionBackend { /** Identifier for a certain secret key within the secretKeyHolder. */ private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; + private TransportConf conf; + private String clientUser; private SaslServer saslServer; public SparkSaslServer( String secretKeyId, SecretKeyHolder secretKeyHolder, boolean alwaysEncrypt) { + this(secretKeyId, secretKeyHolder, alwaysEncrypt, null); + } + + public SparkSaslServer( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt, + TransportConf conf) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + this.conf = conf; // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption // is listed first since it's preferred over the non-encrypted one (if the client also @@ -98,6 +118,13 @@ public SparkSaslServer( } } + /** + * Returns the user name of the client. 
+ */ + public String getUserName() { + return clientUser; + } + /** * Determines whether the authentication exchange has completed successfully. */ @@ -156,15 +183,16 @@ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { private class DigestCallbackHandler implements CallbackHandler { @Override public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + NameCallback nc = null; + PasswordCallback pc = null; for (Callback callback : callbacks) { if (callback instanceof NameCallback) { logger.trace("SASL server callback: setting username"); - NameCallback nc = (NameCallback) callback; + nc = (NameCallback) callback; nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); } else if (callback instanceof PasswordCallback) { logger.trace("SASL server callback: setting password"); - PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + pc = (PasswordCallback) callback; } else if (callback instanceof RealmCallback) { logger.trace("SASL server callback: setting realm"); RealmCallback rc = (RealmCallback) callback; @@ -182,10 +210,21 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); } } + if (pc != null) { + if (conf != null && conf.isConnectionUsingTokens()) { + ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(null, + decodeMasterKey(secretKeyHolder.getSecretKey(secretKeyId))); + ClientToAMTokenIdentifier identifier = getIdentifier(nc.getDefaultName()); + clientUser = identifier.getUser().getShortUserName(); + pc.setPassword(getClientToAMSecretKey(identifier, secretManager)); + } else { + clientUser = nc.getDefaultName(); + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } + } } } - - /* Encode a byte[] identifier as a Base64-encoded string. */ + /** Encode a String identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); return getBase64EncodedString(identifier); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java new file mode 100644 index 000000000000..cffceb53a736 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
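The DIGEST-MD5 path gets the same user plumbing through the new four-argument SparkSaslClient/SparkSaslServer constructors. The sketch below, loosely modeled on SparkSaslSuite, drives one negotiation with a trivial SecretKeyHolder and a null TransportConf, so the non-token branches are exercised; the app id, user, and secret are invented.

import org.apache.spark.network.sasl.{SecretKeyHolder, SparkSaslClient, SparkSaslServer}

// Minimal SASL exchange. With a TransportConf whose isConnectionUsingTokens() is true,
// the client callbacks would instead pass a Base64-encoded ClientToAMTokenIdentifier
// through unmodified and the server would derive the password from the AM master key.
object SaslUserSketch {
  private val keyHolder = new SecretKeyHolder {
    override def getSaslUser(appId: String): String = "alice"
    override def getSecretKey(appId: String): String = "sasl-secret"
  }

  def main(args: Array[String]): Unit = {
    val client = new SparkSaslClient("app-1", keyHolder, false, null)
    val server = new SparkSaslServer("app-1", keyHolder, false, null)
    var token = client.firstToken()
    while (!client.isComplete()) {
      token = client.response(server.response(token))
    }
    // In the non-token branch the reported name is the identifier exactly as the client
    // sent it (Base64-encoded); the token branch decodes it to a short user name.
    println(s"server saw client user: ${server.getUserName}")
  }
}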
+ */ + +package org.apache.spark.network.util; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; + +import org.apache.hadoop.security.token.SecretManager.InvalidToken; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; + +/** + * Utility methods related to the hadoop security + */ +public class HadoopSecurityUtils { + + /** Creates an ClientToAMTokenIdentifier from the encoded Base-64 String */ + public static ClientToAMTokenIdentifier getIdentifier(String id) throws InvalidToken { + byte[] tokenId = byteBufToByte(Base64.decode( + Unpooled.wrappedBuffer(id.getBytes(StandardCharsets.UTF_8)))); + + ClientToAMTokenIdentifier tokenIdentifier = new ClientToAMTokenIdentifier(); + try { + tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(tokenId))); + } catch (IOException e) { + throw (InvalidToken) new InvalidToken( + "Can't de-serialize tokenIdentifier").initCause(e); + } + return tokenIdentifier; + } + + /** Returns an Base64-encoded secretKey created from the Identifier and the secretmanager */ + public static char[] getClientToAMSecretKey(ClientToAMTokenIdentifier tokenid, + ClientToAMTokenSecretManager secretManager) throws InvalidToken { + byte[] password = secretManager.retrievePassword(tokenid); + return Base64.encode(Unpooled.wrappedBuffer(password)).toString(StandardCharsets.UTF_8) + .toCharArray(); + } + + /** Decode a base64-encoded MasterKey as a byte[] array. */ + public static byte[] decodeMasterKey(String masterKey) { + ByteBuf masterKeyByteBuf = Base64.decode(Unpooled.wrappedBuffer(masterKey.getBytes(StandardCharsets.UTF_8))); + return byteBufToByte(masterKeyByteBuf); + } + + /** Convert an ByteBuf to a byte[] array. */ + private static byte[] byteBufToByte(ByteBuf byteBuf) { + byte[] byteArray = new byte[byteBuf.readableBytes()]; + byteBuf.readBytes(byteArray); + return byteArray; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 88256b810bf0..d0b975bfdde6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -108,6 +108,9 @@ public int numConnectionsPerPeer() { /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } + /** If true, the current RPC connection is a Client to AM connection */ + public boolean isConnectionUsingTokens() { return conf.getBoolean("spark.rpc.connectionUsingTokens", false); } + /** * Receive buffer size (SO_RCVBUF). 
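The whole feature hinges on a single switch, spark.rpc.connectionUsingTokens, surfaced through the TransportConf.isConnectionUsingTokens() accessor added above. A small sketch of reading it through a MapConfigProvider (the module name "rpc" is arbitrary here):

import java.util.Collections

import org.apache.spark.network.util.{MapConfigProvider, TransportConf}

// isConnectionUsingTokens() defaults to false; both the AM and the connecting client
// flip it to true before creating their RpcEnv so the auth/SASL code takes the
// token-based branches.
object TokenFlagSketch {
  def main(args: Array[String]): Unit = {
    val provider = new MapConfigProvider(
      Collections.singletonMap("spark.rpc.connectionUsingTokens", "true"))
    val conf = new TransportConf("rpc", provider)
    assert(conf.isConnectionUsingTokens)
  }
}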
* Note: the optimal size for receive buffer and send buffer should be diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2480e56b72cc..9a9518837cb6 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -222,10 +222,11 @@ private[spark] class SecurityManager( setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) - setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); - setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")) + setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")) - private val secretKey = generateSecretKey() + private var identifier = "sparkSaslUser" + private var secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -533,11 +534,23 @@ private[spark] class SecurityManager( /** * Gets the user used for authenticating SASL connections. - * For now use a single hardcoded user. * @return the SASL user as a String */ - def getSaslUser(): String = "sparkSaslUser" + def getSaslUser(): String = identifier + + /** + * This can be a user name or unique identifier + */ + def setSaslUser(ident: String) { + identifier = ident + } + /** + * set the secret key + */ + def setSecretKey(secret: String) { + secretKey = secret + } /** * Gets the secret key. * @return the secret key as a String if authentication is enabled, otherwise returns null diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 117f51c5b8f2..32f84b95593b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -35,7 +35,12 @@ private[spark] trait RpcCallContext { def sendFailure(e: Throwable): Unit /** - * The sender of this message. + * The sender's address of this message. */ def senderAddress: RpcAddress + + /** + * The sender's User Name of this message. + */ + def senderUserName: String } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 904c4d02dd2a..f6c7f7e30b24 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -121,8 +121,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte /** Posts a message sent by a remote endpoint. 
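Once the user name reaches RpcCallContext, endpoint code can act on it. A hypothetical endpoint (the class name and owner check are invented, and senderUserName will be null when authentication is off) might gate requests like this; it sits under an org.apache.spark package because RpcEndpoint is private[spark]:

package org.apache.spark.rpc

// Invented example: only the expected owner gets a reply; everyone else gets a failure.
class AuthorizingEndpoint(override val rpcEnv: RpcEnv, owner: String) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg =>
      if (context.senderUserName == owner) {
        context.reply(s"ok: $msg")
      } else {
        context.sendFailure(new SecurityException(
          s"user ${context.senderUserName} may not call this endpoint"))
      }
  }
}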
*/ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - val rpcCallContext = - new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcCallContext = new RemoteNettyRpcCallContext(nettyEnv, callback, + message.senderAddress, message.senderUserName) val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 7dd7e610a28e..8c2d9bae00a9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -23,7 +23,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) +private[netty] abstract class NettyRpcCallContext( + override val senderAddress: RpcAddress, + override val senderUserName: String = null) extends RpcCallContext with Logging { protected def send(message: Any): Unit @@ -57,8 +59,9 @@ private[netty] class LocalNettyRpcCallContext( private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, callback: RpcResponseCallback, - senderAddress: RpcAddress) - extends NettyRpcCallContext(senderAddress) { + senderAddress: RpcAddress, + senderUserName: String) + extends NettyRpcCallContext(senderAddress, senderUserName) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 1777e7a53975..a298a4cd7029 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -543,7 +543,8 @@ private[netty] class NettyRpcEndpointRef( private[netty] class RequestMessage( val senderAddress: RpcAddress, val receiver: NettyRpcEndpointRef, - val content: Any) { + val content: Any, + val senderUserName: String = null) { /** Manually serialize [[RequestMessage]] to minimize the size. */ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = { @@ -589,7 +590,11 @@ private[netty] object RequestMessage { } } - def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = { + def apply( + nettyEnv: NettyRpcEnv, + client: TransportClient, + bytes: ByteBuffer, + senderUserName: String = null): RequestMessage = { val bis = new ByteBufferInputStream(bytes) val in = new DataInputStream(bis) try { @@ -601,7 +606,8 @@ private[netty] object RequestMessage { senderAddress, ref, // The remaining bytes in `bytes` are the message content. 
- nettyEnv.deserialize(client, bytes)) + nettyEnv.deserialize(client, bytes), + senderUserName) } finally { in.close() } @@ -652,10 +658,12 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) - val requestMessage = RequestMessage(nettyEnv, client, message) + var requestMessage = RequestMessage(nettyEnv, client, message, client.getClientUser) + if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. - new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, + client.getClientUser) } else { // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // the listening address diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4a74556d4f2..4e3d67a07373 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -453,6 +453,15 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.clientToAM.port + 0 + + Port the application master listens on for connections from the client. + This port is specified when registering the AM with YARN so that client can later know which + port to connect to from the application Report. + + # Important notes diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71..7a090d60150c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -21,12 +21,18 @@ import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} import java.util.concurrent.{TimeoutException, TimeUnit} +import javax.crypto.SecretKey +import javax.crypto.spec.SecretKeySpec import scala.collection.mutable.HashMap import scala.concurrent.Promise import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import com.google.common.base.Charsets +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.handler.codec.base64.Base64 import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -41,9 +47,16 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.network.{BlockDataManager, TransportContext} +import org.apache.spark.network.client.TransportClientBootstrap +import org.apache.spark.network.netty.{NettyBlockRpcServer, SparkTransportConf} +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.rpc._ +import org.apache.spark.rpc.netty.NettyRpcCallContext import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import 
org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util._ /** @@ -89,6 +102,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + @volatile private var clientToAMPort: Int = _ // A flag to check whether user has initialized spark context @volatile private var registered = false @@ -247,7 +261,9 @@ private[spark] class ApplicationMaster( if (!unregistered) { // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + if (finalStatus == FinalApplicationStatus.SUCCEEDED || + finalStatus == FinalApplicationStatus.KILLED || + isLastAttempt) { unregister(finalStatus, finalMsg) cleanupStagingDir() } @@ -283,6 +299,7 @@ private[spark] class ApplicationMaster( credentialRenewerThread.start() credentialRenewerThread.join() } + clientToAMPort = sparkConf.getInt("spark.yarn.clientToAM.port", 0) if (isClusterMode) { runDriver(securityMgr) @@ -402,7 +419,8 @@ private[spark] class ApplicationMaster( uiAddress, historyAddress, securityMgr, - localResources) + localResources, + clientToAMPort) // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), @@ -422,6 +440,35 @@ private[spark] class ApplicationMaster( YarnSchedulerBackend.ENDPOINT_NAME) } + /** + * Create an [[RpcEndpoint]] that communicates with the client. + * + * @return A reference to the application master's RPC endpoint. + */ + private def runClientAMEndpoint( + port: Int, + driverRef: RpcEndpointRef, + securityManager: SecurityManager): RpcEndpointRef = { + val serversparkConf = new SparkConf() + serversparkConf.set("spark.rpc.connectionUsingTokens", "true") + + val amRpcEnv = + RpcEnv.create(ApplicationMaster.SYSTEM_NAME, Utils.localHostName(), port, serversparkConf, + securityManager) + clientToAMPort = amRpcEnv.address.port + + val clientAMEndpoint = + amRpcEnv.setupEndpoint(ApplicationMaster.ENDPOINT_NAME, + new ClientToAMEndpoint(amRpcEnv, driverRef, securityManager)) + clientAMEndpoint + } + + /** RpcEndpoint class for ClientToAM */ + private[spark] class ClientToAMEndpoint( + override val rpcEnv: RpcEnv, driverRef: RpcEndpointRef, securityManager: SecurityManager) + extends RpcEndpoint with Logging { + } + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -438,8 +485,12 @@ private[spark] class ApplicationMaster( val driverRef = createSchedulerRef( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port")) + val clientToAMSecurityManager = new SecurityManager(sparkConf) + runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager) registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) registered = true + clientToAMSecurityManager.setSecretKey(Base64.encode( + Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8)); } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. 
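The patch registers a ClientToAMEndpoint but leaves its body empty, and HelloWorld (defined further down in ApplicationMasterMessages) is the only client-to-AM message so far. Purely as an illustration of where handling logic would go, none of which is in the patch, the endpoint could answer that message like this:

package org.apache.spark.deploy.yarn

import org.apache.spark.deploy.yarn.ApplicationMasterMessages.HelloWorld
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}

// Illustrative only: acknowledge HelloWorld and record who sent it. The authenticated
// short user name arrives through the new senderUserName on RpcCallContext.
private[spark] class ClientToAMEndpointSketch(
    override val rpcEnv: RpcEnv,
    driverRef: RpcEndpointRef)
  extends RpcEndpoint with Logging {

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case HelloWorld() =>
      logInfo(s"Client ${context.senderUserName} connected to the AM")
      context.reply(true)
  }
}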
@@ -464,10 +515,13 @@ private[spark] class ApplicationMaster( amCores, true) val driverRef = waitForSparkDriver() addAmIpFilter(Some(driverRef)) + val clientToAMSecurityManager = new SecurityManager(sparkConf) + runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager) registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"), securityMgr) registered = true - + clientToAMSecurityManager.setSecretKey(Base64.encode( + Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8)); // In client mode the actor will stop the reporter thread. reporterThread.join() } @@ -749,8 +803,18 @@ private[spark] class ApplicationMaster( } +sealed trait ApplicationMasterMessage extends Serializable + +private [spark] object ApplicationMasterMessages { + + case class HelloWorld() extends ApplicationMasterMessage +} + object ApplicationMaster extends Logging { + val SYSTEM_NAME = "sparkYarnAM" + val ENDPOINT_NAME = "clientToAM" + // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d408ca90a5d1..799f34d19e73 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -18,19 +18,25 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} -import java.net.{InetAddress, UnknownHostException, URI} +import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.security.PrivilegedExceptionAction import java.util.{Locale, Properties, UUID} +import java.util.concurrent.TimeoutException import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} +import scala.concurrent.ExecutionContext +import scala.util.control.Breaks._ import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files +import io.netty.buffer.Unpooled +import io.netty.handler.codec.base64.Base64 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -45,7 +51,7 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException -import org.apache.hadoop.yarn.util.Records +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil @@ -54,7 +60,8 @@ import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils} -import org.apache.spark.util.{CallerContext, Utils} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.util.{CallerContext, SparkExitCode, ThreadUtils, Utils} 
private[spark] class Client( val args: ClientArguments, @@ -1149,6 +1156,43 @@ private[spark] class Client( } } + private def setupAMConnection( + appId: ApplicationId, + securityManager: SecurityManager): RpcEndpointRef = { + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + if (report.getHost() == null || "".equals(report.getHost())) { + throw new SparkException(s"AM for $appId not assigned or dont have view ACL for it") + } + if ( state != YarnApplicationState.RUNNING) { + throw new SparkException(s"Application $appId needs to be in RUNNING") + } + + if (UserGroupInformation.isSecurityEnabled()) { + val serviceAddr = new InetSocketAddress(report.getHost(), report.getRpcPort()) + + val clientToAMToken = report.getClientToAMToken + val token = ConverterUtils.convertFromYarn(clientToAMToken, serviceAddr) + + // Fetch Identifier, secretkey from the report, encode it and Set it in the Security Manager + val userName = token.getIdentifier + var userstring = Base64.encode(Unpooled.wrappedBuffer(userName)).toString(UTF_8); + securityManager.setSaslUser(userstring) + val secretkey = token.getPassword + var secretkeystring = Base64.encode(Unpooled.wrappedBuffer(secretkey)).toString(UTF_8); + securityManager.setSecretKey(secretkeystring) + } + + sparkConf.set("spark.rpc.connectionUsingTokens", "true") + val rpcEnv = + RpcEnv.create("yarnDriverClient", Utils.localHostName(), 0, sparkConf, securityManager) + val AMHostPort = RpcAddress(report.getHost, report.getRpcPort) + val AMEndpoint = rpcEnv.setupEndpointRef(AMHostPort, + ApplicationMaster.ENDPOINT_NAME) + + AMEndpoint + } + private def findPySparkArchives(): Seq[String] = { sys.env.get("PYSPARK_ARCHIVES_PATH") .map(_.split(",").toSeq) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 72f4d273ab53..68c8134d5527 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.yarn +import java.nio.ByteBuffer + import scala.collection.JavaConverters._ import org.apache.hadoop.yarn.api.records._ @@ -39,6 +41,7 @@ private[spark] class YarnRMClient extends Logging { private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ private var registered: Boolean = false + private var masterkey: ByteBuffer = _ /** * Registers the application master with the RM. 
@@ -58,7 +61,8 @@ private[spark] class YarnRMClient extends Logging { uiAddress: Option[String], uiHistoryAddress: String, securityMgr: SecurityManager, - localResources: Map[String, LocalResource] + localResources: Map[String, LocalResource], + port: Int = 0 ): YarnAllocator = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) @@ -71,8 +75,9 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + var response = amClient.registerApplicationMaster(Utils.localHostName(), port, trackingUrl) registered = true + masterkey = response.getClientToAMTokenMasterKey() } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, localResources, new SparkRackResolver()) @@ -89,6 +94,9 @@ private[spark] class YarnRMClient extends Logging { amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress) } } + /** Obtain the MasterKey reported back from YARN when Registering AM. */ + def getMasterKey(): ByteBuffer = masterkey + /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 415a29fd887e..3df66213b29f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -23,6 +23,8 @@ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} import org.apache.spark.SparkContext
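Putting the client-side pieces together, the flow that the new setupAMConnection implements is roughly the following. This condensed sketch assumes an ApplicationReport already fetched from YARN, skips the RUNNING/host checks and the isSecurityEnabled guard of the real method, and exists only to show the order of operations; the hinted askSync call at the end is hypothetical, since the patch does not yet send any message to the AM.

package org.apache.spark.deploy.yarn

import java.net.InetSocketAddress

import com.google.common.base.Charsets.UTF_8
import io.netty.buffer.Unpooled
import io.netty.handler.codec.base64.Base64
import org.apache.hadoop.yarn.api.records.ApplicationReport
import org.apache.hadoop.yarn.util.ConverterUtils

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
import org.apache.spark.util.Utils

// Condensed client-side flow: turn the ClientToAMToken from the application report into
// Base64 SASL credentials, mark the connection as token-based, and open an RpcEnv
// pointed at the AM's clientToAM endpoint.
object ClientToAMConnectionSketch {
  def connect(report: ApplicationReport, sparkConf: SparkConf): RpcEndpointRef = {
    val securityManager = new SecurityManager(sparkConf)
    val serviceAddr = new InetSocketAddress(report.getHost, report.getRpcPort)
    val token = ConverterUtils.convertFromYarn(report.getClientToAMToken, serviceAddr)

    securityManager.setSaslUser(
      Base64.encode(Unpooled.wrappedBuffer(token.getIdentifier)).toString(UTF_8))
    securityManager.setSecretKey(
      Base64.encode(Unpooled.wrappedBuffer(token.getPassword)).toString(UTF_8))

    sparkConf.set("spark.rpc.connectionUsingTokens", "true")
    val rpcEnv = RpcEnv.create(
      "yarnDriverClient", Utils.localHostName(), 0, sparkConf, securityManager)
    val amRef = rpcEnv.setupEndpointRef(
      RpcAddress(report.getHost, report.getRpcPort), ApplicationMaster.ENDPOINT_NAME)
    // e.g. amRef.askSync[Boolean](ApplicationMasterMessages.HelloWorld())
    amRef
  }
}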