diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index bc3b0fe73f6e..e170b9b7e30e 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -42,6 +42,22 @@
 <groupId>io.netty</groupId>
 <artifactId>netty-all</artifactId>
 </dependency>
+ <dependency>
+ <groupId>org.fusesource.leveldbjni</groupId>
+ <artifactId>leveldbjni-all</artifactId>
+ <version>1.8</version>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+
 <dependency>
 <groupId>org.slf4j</groupId>
 <artifactId>slf4j-api</artifactId>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java
new file mode 100644
index 000000000000..ec900a7b3ca6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.fusesource.leveldbjni.JniDBFactory;
+import org.fusesource.leveldbjni.internal.NativeDB;
+import org.iq80.leveldb.DB;
+import org.iq80.leveldb.Options;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * LevelDB utility class, shared by the network modules for persisting and
+ * recovering state across service restarts.
+ */
+public class LevelDBProvider {
+ private static final Logger logger = LoggerFactory.getLogger(LevelDBProvider.class);
+
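+ /**
+ * Opens the LevelDB at dbFile, creating it if necessary, and verifies that its stored
+ * version is compatible with the given version. Returns null when dbFile is null,
+ * i.e. when recovery is disabled.
+ */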
+ public static DB initLevelDB(File dbFile, StoreVersion version, ObjectMapper mapper) throws
+ IOException {
+ DB tmpDb = null;
+ if (dbFile != null) {
+ Options options = new Options();
+ options.createIfMissing(false);
+ options.logger(new LevelDBLogger());
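+ // Try to open an existing db first, so that a db that simply does not exist yet
+ // can be told apart from a corrupt one and handled separately below.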
+ try {
+ tmpDb = JniDBFactory.factory.open(dbFile, options);
+ } catch (NativeDB.DBException e) {
+ if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
+ logger.info("Creating state database at " + dbFile);
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(dbFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+ } else {
+ // the leveldb file seems to be corrupt somehow. Let's just blow it away and create
+ // a new one, so we can keep processing new apps
+ logger.error("error opening leveldb file {}. Creating new file, will not be able to " +
+ "recover state for existing applications", dbFile, e);
+ if (dbFile.isDirectory()) {
+ for (File f : dbFile.listFiles()) {
+ if (!f.delete()) {
+ logger.warn("error deleting {}", f.getPath());
+ }
+ }
+ }
+ if (!dbFile.delete()) {
+ logger.warn("error deleting {}", dbFile.getPath());
+ }
+ options.createIfMissing(true);
+ try {
+ tmpDb = JniDBFactory.factory.open(dbFile, options);
+ } catch (NativeDB.DBException dbExc) {
+ throw new IOException("Unable to create state store", dbExc);
+ }
+ }
+ }
+ // if there is a version mismatch, we throw an exception, which means the service is unusable
+ checkVersion(tmpDb, version, mapper);
+ }
+ return tmpDb;
+ }
+
+ private static class LevelDBLogger implements org.iq80.leveldb.Logger {
+ private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class);
+
+ @Override
+ public void log(String message) {
+ LOG.info(message);
+ }
+ }
+
+ /**
+ * Simple major.minor versioning scheme. Any incompatible changes should be across major
+ * versions. Minor version differences are allowed -- meaning we should be able to read
+ * dbs that are either earlier *or* later on the minor version.
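+ * For example, a 1.0 reader can open a db written by 1.1, but not one written by 2.0.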
+ */
+ public static void checkVersion(DB db, StoreVersion newversion, ObjectMapper mapper) throws
+ IOException {
+ byte[] bytes = db.get(StoreVersion.KEY);
+ if (bytes == null) {
+ storeVersion(db, newversion, mapper);
+ } else {
+ StoreVersion version = mapper.readValue(bytes, StoreVersion.class);
+ if (version.major != newversion.major) {
+ throw new IOException("cannot read state DB with version " + version + ", incompatible " +
+ "with current version " + newversion);
+ }
+ storeVersion(db, newversion, mapper);
+ }
+ }
+
+ public static void storeVersion(DB db, StoreVersion version, ObjectMapper mapper)
+ throws IOException {
+ db.put(StoreVersion.KEY, mapper.writeValueAsBytes(version));
+ }
+
+ public static class StoreVersion {
+
+ static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8);
+
+ public final int major;
+ public final int minor;
+
+ @JsonCreator
+ public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) {
+ this.major = major;
+ this.minor = minor;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ StoreVersion that = (StoreVersion) o;
+
+ return major == that.major && minor == that.minor;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = major;
+ result = 31 * result + minor;
+ return result;
+ }
+ }
+}
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 2fb5835305a2..8b832cf37612 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -43,22 +43,6 @@
 <artifactId>spark-network-common_${scala.binary.version}</artifactId>
 <version>${project.version}</version>
 </dependency>
- <dependency>
- <groupId>org.fusesource.leveldbjni</groupId>
- <artifactId>leveldbjni-all</artifactId>
- <version>1.8</version>
- </dependency>
-
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-databind</artifactId>
- </dependency>
-
- <dependency>
- <groupId>com.fasterxml.jackson.core</groupId>
- <artifactId>jackson-annotations</artifactId>
- </dependency>
-
 <dependency>
 <groupId>org.slf4j</groupId>
 <artifactId>slf4j-api</artifactId>
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 000ec13f796a..e34dc1f5b1de 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -30,17 +30,16 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Objects;
import com.google.common.collect.Maps;
-import org.fusesource.leveldbjni.JniDBFactory;
-import org.fusesource.leveldbjni.internal.NativeDB;
import org.iq80.leveldb.DB;
import org.iq80.leveldb.DBIterator;
-import org.iq80.leveldb.Options;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.util.LevelDBProvider;
+import org.apache.spark.network.util.LevelDBProvider.StoreVersion;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@@ -95,52 +94,10 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF
Executor directoryCleaner) throws IOException {
this.conf = conf;
this.registeredExecutorFile = registeredExecutorFile;
- if (registeredExecutorFile != null) {
- Options options = new Options();
- options.createIfMissing(false);
- options.logger(new LevelDBLogger());
- DB tmpDb;
- try {
- tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
- } catch (NativeDB.DBException e) {
- if (e.isNotFound() || e.getMessage().contains(" does not exist ")) {
- logger.info("Creating state database at " + registeredExecutorFile);
- options.createIfMissing(true);
- try {
- tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
- } catch (NativeDB.DBException dbExc) {
- throw new IOException("Unable to create state store", dbExc);
- }
- } else {
- // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new
- // one, so we can keep processing new apps
- logger.error("error opening leveldb file {}. Creating new file, will not be able to " +
- "recover state for existing applications", registeredExecutorFile, e);
- if (registeredExecutorFile.isDirectory()) {
- for (File f : registeredExecutorFile.listFiles()) {
- if (!f.delete()) {
- logger.warn("error deleting {}", f.getPath());
- }
- }
- }
- if (!registeredExecutorFile.delete()) {
- logger.warn("error deleting {}", registeredExecutorFile.getPath());
- }
- options.createIfMissing(true);
- try {
- tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options);
- } catch (NativeDB.DBException dbExc) {
- throw new IOException("Unable to create state store", dbExc);
- }
-
- }
- }
- // if there is a version mismatch, we throw an exception, which means the service is unusable
- checkVersion(tmpDb);
- executors = reloadRegisteredExecutors(tmpDb);
- db = tmpDb;
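+ // initLevelDB returns null when registeredExecutorFile is null, i.e. recovery is disabled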
+ db = LevelDBProvider.initLevelDB(this.registeredExecutorFile, CURRENT_VERSION, mapper);
+ if (db != null) {
+ executors = reloadRegisteredExecutors(db);
} else {
- db = null;
executors = Maps.newConcurrentMap();
}
this.directoryCleaner = directoryCleaner;
@@ -368,76 +325,11 @@ static ConcurrentMap<AppExecId, ExecutorShuffleInfo> reloadRegisteredExecutors(D
break;
}
AppExecId id = parseDbAppExecKey(key);
+ logger.info("Reloading registered executors: " + id.toString());
ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class);
registeredExecutors.put(id, shuffleInfo);
}
}
return registeredExecutors;
}
-
- private static class LevelDBLogger implements org.iq80.leveldb.Logger {
- private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class);
-
- @Override
- public void log(String message) {
- LOG.info(message);
- }
- }
-
- /**
- * Simple major.minor versioning scheme. Any incompatible changes should be across major
- * versions. Minor version differences are allowed -- meaning we should be able to read
- * dbs that are either earlier *or* later on the minor version.
- */
- private static void checkVersion(DB db) throws IOException {
- byte[] bytes = db.get(StoreVersion.KEY);
- if (bytes == null) {
- storeVersion(db);
- } else {
- StoreVersion version = mapper.readValue(bytes, StoreVersion.class);
- if (version.major != CURRENT_VERSION.major) {
- throw new IOException("cannot read state DB with version " + version + ", incompatible " +
- "with current version " + CURRENT_VERSION);
- }
- storeVersion(db);
- }
- }
-
- private static void storeVersion(DB db) throws IOException {
- db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION));
- }
-
-
- public static class StoreVersion {
-
- static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8);
-
- public final int major;
- public final int minor;
-
- @JsonCreator public StoreVersion(
- @JsonProperty("major") int major,
- @JsonProperty("minor") int minor) {
- this.major = major;
- this.minor = minor;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
-
- StoreVersion that = (StoreVersion) o;
-
- return major == that.major && minor == that.minor;
- }
-
- @Override
- public int hashCode() {
- int result = major;
- result = 31 * result + minor;
- return result;
- }
- }
-
}
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index b6feb55e2192..9584f075e6ae 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -18,15 +18,28 @@
package org.apache.spark.network.yarn;
import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
import java.nio.ByteBuffer;
import java.util.List;
+import java.util.Map;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Objects;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.server.api.*;
+import org.apache.spark.network.util.LevelDBProvider;
+import org.iq80.leveldb.DB;
+import org.iq80.leveldb.DBIterator;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,6 +81,22 @@ public class YarnShuffleService extends AuxiliaryService {
private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate";
private static final boolean DEFAULT_SPARK_AUTHENTICATE = false;
+ private static final String RECOVERY_FILE_NAME = "registeredExecutors.ldb";
+ private static final String SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb";
+
+ // just for testing when you want to find an open port
+ @VisibleForTesting
+ static int boundPort = -1;
+ private static final ObjectMapper mapper = new ObjectMapper();
+ private static final String APP_CREDS_KEY_PREFIX = "AppCreds";
+ private static final LevelDBProvider.StoreVersion CURRENT_VERSION =
+ new LevelDBProvider.StoreVersion(1, 0);
+
+ // just for integration tests that want to look at this file -- in general not sensible as
+ // a static
+ @VisibleForTesting
+ static YarnShuffleService instance;
+
// An entity that manages the shuffle secret per application
// This is used only if authentication is enabled
private ShuffleSecretManager secretManager;
@@ -75,6 +104,8 @@ public class YarnShuffleService extends AuxiliaryService {
// The actual server that serves shuffle files
private TransportServer shuffleServer = null;
+ private Configuration _conf = null;
+
// Handles registering executors and opening shuffle blocks
@VisibleForTesting
ExternalShuffleBlockHandler blockHandler;
@@ -83,14 +114,11 @@ public class YarnShuffleService extends AuxiliaryService {
@VisibleForTesting
File registeredExecutorFile;
- // just for testing when you want to find an open port
+ // Where to store & reload application secrets for recovering state after an NM restart
@VisibleForTesting
- static int boundPort = -1;
+ File secretsFile;
- // just for integration tests that want to look at this file -- in general not sensible as
- // a static
- @VisibleForTesting
- static YarnShuffleService instance;
+ private DB db;
public YarnShuffleService() {
super("spark_shuffle");
@@ -112,42 +140,86 @@ private boolean isAuthenticationEnabled() {
*/
@Override
protected void serviceInit(Configuration conf) {
+ _conf = conf;
- // In case this NM was killed while there were running spark applications, we need to restore
- // lost state for the existing executors. We look for an existing file in the NM's local dirs.
- // If we don't find one, then we choose a file to use to save the state next time. Even if
- // an application was stopped while the NM was down, we expect yarn to call stopApplication()
- // when it comes back
- registeredExecutorFile =
- findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs"));
-
- TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
- // If authentication is enabled, set up the shuffle server to use a
- // special RPC handler that filters out unauthenticated fetch requests
- boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
try {
- blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);
+ // In case this NM was killed while there were running Spark applications, we need to restore
+ // lost state for the existing executors. We look for an existing file in the NM's local dirs.
+ // If we don't find one, then we choose a file to use to save the state next time. Even if
+ // an application was stopped while the NM was down, we expect YARN to call stopApplication()
+ // when it comes back.
+ registeredExecutorFile = findRecoveryDb(RECOVERY_FILE_NAME);
+
+ TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
+ blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile);
+
+ // If authentication is enabled, set up the shuffle server to use a
+ // special RPC handler that filters out unauthenticated fetch requests
+ boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
+ List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
+ if (authEnabled) {
+ createSecretManager();
+ bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
+ }
+
+ int port = conf.getInt(
+ SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
+ TransportContext transportContext = new TransportContext(transportConf, blockHandler);
+ shuffleServer = transportContext.createServer(port, bootstraps);
+ // the port should normally be fixed, but for tests it's useful to find an open port
+ port = shuffleServer.getPort();
+ boundPort = port;
+ String authEnabledString = authEnabled ? "enabled" : "not enabled";
+ logger.info("Started YARN shuffle service for Spark on port {}. " +
+ "Authentication is {}. Registered executor file is {}", port, authEnabledString,
+ registeredExecutorFile);
} catch (Exception e) {
logger.error("Failed to initialize external shuffle service", e);
}
+ }
+
+ private void createSecretManager() throws IOException {
+ secretManager = new ShuffleSecretManager();
+ secretsFile = findRecoveryDb(SECRETS_RECOVERY_FILE_NAME);
+
+ // Make sure this is protected in case it's not in the NM recovery dir
+ FileSystem fs = FileSystem.getLocal(_conf);
+ fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short)0700));
- List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
- if (authEnabled) {
- secretManager = new ShuffleSecretManager();
- bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
+ db = LevelDBProvider.initLevelDB(secretsFile, CURRENT_VERSION, mapper);
+ logger.info("Recovery location is: " + secretsFile.getPath());
+ if (db != null) {
+ logger.info("Going to reload spark shuffle data");
+ DBIterator itr = db.iterator();
+ itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8));
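+ // leveldb returns keys in sorted order, so all AppCreds entries are contiguous;
+ // stop at the first key that no longer starts with the prefix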
+ while (itr.hasNext()) {
+ Map.Entry<byte[], byte[]> e = itr.next();
+ String key = new String(e.getKey(), StandardCharsets.UTF_8);
+ if (!key.startsWith(APP_CREDS_KEY_PREFIX)) {
+ break;
+ }
+ String id = parseDbAppKey(key);
+ ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class);
+ logger.info("Reloading tokens for app: " + id);
+ secretManager.registerApp(id, secret);
+ }
}
+ }
- int port = conf.getInt(
- SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
- TransportContext transportContext = new TransportContext(transportConf, blockHandler);
- shuffleServer = transportContext.createServer(port, bootstraps);
- // the port should normally be fixed, but for tests its useful to find an open port
- port = shuffleServer.getPort();
- boundPort = port;
- String authEnabledString = authEnabled ? "enabled" : "not enabled";
- logger.info("Started YARN shuffle service for Spark on port {}. " +
- "Authentication is {}. Registered executor file is {}", port, authEnabledString,
- registeredExecutorFile);
+ private static String parseDbAppKey(String s) throws IOException {
+ if (!s.startsWith(APP_CREDS_KEY_PREFIX)) {
+ throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX);
+ }
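+ // keys have the form "<prefix>;<json-encoded AppId>"; skip the prefix and the ';' separator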
+ String json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1);
+ AppId parsed = mapper.readValue(json, AppId.class);
+ return parsed.appId;
+ }
+
+ private static byte[] dbAppKey(AppId appExecId) throws IOException {
+ // we stick a common prefix on all the keys so we can find them in the DB
+ String appExecJson = mapper.writeValueAsString(appExecId);
+ String key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson);
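+ // e.g. AppCreds;{"appId":"application_1234_0001"} -- the appId value here is illustrative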
+ return key.getBytes(StandardCharsets.UTF_8);
}
@Override
@@ -157,6 +229,12 @@ public void initializeApplication(ApplicationInitializationContext context) {
ByteBuffer shuffleSecret = context.getApplicationDataForService();
logger.info("Initializing application {}", appId);
if (isAuthenticationEnabled()) {
+ AppId fullId = new AppId(appId);
+ if (db != null) {
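+ // persist the secret first, so it can be recovered if the NM restarts before
+ // the application finishes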
+ byte[] key = dbAppKey(fullId);
+ byte[] value = mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8);
+ db.put(key, value);
+ }
secretManager.registerApp(appId, shuffleSecret);
}
} catch (Exception e) {
@@ -170,6 +248,14 @@ public void stopApplication(ApplicationTerminationContext context) {
try {
logger.info("Stopping application {}", appId);
if (isAuthenticationEnabled()) {
+ AppId fullId = new AppId(appId);
+ if (db != null) {
+ try {
+ db.delete(dbAppKey(fullId));
+ } catch (IOException e) {
+ logger.error("Error deleting {} from executor state db", appId, e);
+ }
+ }
secretManager.unregisterApp(appId);
}
blockHandler.applicationRemoved(appId, false /* clean up local dirs */);
@@ -190,14 +276,15 @@ public void stopContainer(ContainerTerminationContext context) {
logger.info("Stopping container {}", containerId);
}
- private File findRegisteredExecutorFile(String[] localDirs) {
+ private File findRecoveryDb(String fileName) {
+ String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs");
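+ // reuse an existing recovery db found in any of the local dirs; otherwise
+ // fall back to creating one in the first local dir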
for (String dir: localDirs) {
- File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb");
+ File f = new File(new Path(dir).toUri().getPath(), fileName);
if (f.exists()) {
return f;
}
}
- return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb");
+ return new File(new Path(localDirs[0]).toUri().getPath(), fileName);
}
/**
@@ -212,6 +299,9 @@ protected void serviceStop() {
if (blockHandler != null) {
blockHandler.close();
}
+ if (db != null) {
+ db.close();
+ }
} catch (Exception e) {
logger.error("Exception when stopping service", e);
}
@@ -222,4 +312,38 @@ protected void serviceStop() {
public ByteBuffer getMetaData() {
return ByteBuffer.allocate(0);
}
+
+ /**
+ * Simply encodes an application ID.
+ */
+ public static class AppId {
+ public final String appId;
+
+ @JsonCreator
+ public AppId(@JsonProperty("appId") String appId) {
+ this.appId = appId;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ AppId appExecId = (AppId) o;
+ return Objects.equal(appId, appExecId.appId);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .toString();
+ }
+ }
+
}
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
index c33a9e6bbe25..0c2e0204356b 100644
--- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext}
import org.scalatest.{BeforeAndAfterEach, Matchers}
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkFunSuite
import org.apache.spark.network.shuffle.ShuffleTestAccessor
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -77,6 +78,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
test("executor state kept across NM restart") {
s1 = new YarnShuffleService
+ // set auth to true to test the secrets recovery
+ yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true)
s1.init(yarnConfig)
val app1Id = ApplicationId.newInstance(0, 1)
val app1Data: ApplicationInitializationContext =
@@ -89,6 +92,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
val execStateFile = s1.registeredExecutorFile
execStateFile should not be (null)
+ val secretsFile = s1.secretsFile
+ secretsFile should not be (null)
val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, SORT_MANAGER)
val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, SORT_MANAGER)
@@ -118,6 +123,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
s1.stop()
s2 = new YarnShuffleService
s2.init(yarnConfig)
+ s2.secretsFile should be (secretsFile)
s2.registeredExecutorFile should be (execStateFile)
val handler2 = s2.blockHandler
@@ -135,6 +141,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
s3 = new YarnShuffleService
s3.init(yarnConfig)
s3.registeredExecutorFile should be (execStateFile)
+ s3.secretsFile should be (secretsFile)
val handler3 = s3.blockHandler
val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3)
@@ -148,7 +155,10 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
test("removed applications should not be in registered executor file") {
s1 = new YarnShuffleService
+ yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, false)
s1.init(yarnConfig)
+ val secretsFile = s1.secretsFile
+ secretsFile should be (null)
val app1Id = ApplicationId.newInstance(0, 1)
val app1Data: ApplicationInitializationContext =
new ApplicationInitializationContext("user", app1Id, null)