diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java index c3068c936b410..ed73864dc37d2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDriverComponents.java @@ -22,12 +22,6 @@ public interface ShuffleDriverComponents { - enum MapOutputUnregistrationStrategy { - MAP_OUTPUT_ONLY, - EXECUTOR, - HOST, - } - /** * @return additional SparkConf values necessary for the executors. */ @@ -39,7 +33,18 @@ default void registerShuffle(int shuffleId) throws IOException {} default void removeShuffle(int shuffleId, boolean blocking) throws IOException {} - default MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() { - return MapOutputUnregistrationStrategy.EXECUTOR; + /** + * Indicates whether the data for the given map output is stored and available outside + * the host of the mapper executor. + * + * @return true if it can be verified that the map output is stored in a location external + * to the mapper's executor and host, and is available at that location; false otherwise. + */ + default boolean checkIfMapOutputStoredOutsideExecutor(int shuffleId, int mapId) { + return false; + } + + default boolean unregisterOutputOnHostOnFetchFailure() { + return true; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java index 77fcd34f962bf..9b068e56cc5da 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -37,7 +37,7 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) { @Override public ShuffleDriverComponents driver() { - return new LocalDiskShuffleDriverComponents(); + return new LocalDiskShuffleDriverComponents(sparkConf); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java index 080bba49a4dc0..1df45140aad36 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/LocalDiskShuffleDriverComponents.java @@ -19,25 +19,39 @@ import java.util.Map; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.api.ShuffleDriverComponents; -import org.apache.spark.internal.config.package$; import org.apache.spark.storage.BlockManagerMaster; public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents { private BlockManagerMaster blockManagerMaster; - private boolean shouldUnregisterOutputOnHostOnFetchFailure; + private final SparkConf sparkConf; + + public LocalDiskShuffleDriverComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @VisibleForTesting + public LocalDiskShuffleDriverComponents(BlockManagerMaster blockManagerMaster) { + this.sparkConf = new SparkConf(false); + this.blockManagerMaster = blockManagerMaster; + } + + @VisibleForTesting + public LocalDiskShuffleDriverComponents( + SparkConf sparkConf, BlockManagerMaster blockManagerMaster) { + this.sparkConf = 
sparkConf; + this.blockManagerMaster = blockManagerMaster; + } @Override public Map initializeApplication() { blockManagerMaster = SparkEnv.get().blockManager().master(); - this.shouldUnregisterOutputOnHostOnFetchFailure = - SparkEnv.get().blockManager().externalShuffleServiceEnabled() - && (boolean) SparkEnv.get().conf() - .get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE()); return ImmutableMap.of(); } @@ -48,11 +62,20 @@ public void removeShuffle(int shuffleId, boolean blocking) { } @Override - public MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() { - if (shouldUnregisterOutputOnHostOnFetchFailure) { - return MapOutputUnregistrationStrategy.HOST; - } - return MapOutputUnregistrationStrategy.EXECUTOR; + public boolean unregisterOutputOnHostOnFetchFailure() { + boolean unregisterOutputOnHostOnFetchFailure = Boolean.parseBoolean( + sparkConf.get( + org.apache.spark.internal.config.package$.MODULE$ + .UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().key(), + org.apache.spark.internal.config.package$.MODULE$ + .UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().defaultValueString())); + boolean externalShuffleServiceEnabled = Boolean.parseBoolean( + sparkConf.get( + org.apache.spark.internal.config.package$.MODULE$ + .SHUFFLE_SERVICE_ENABLED().key(), + org.apache.spark.internal.config.package$.MODULE$ + .SHUFFLE_SERVICE_ENABLED().defaultValueString())); + return unregisterOutputOnHostOnFetchFailure && externalShuffleServiceEnabled; } private void checkInitialized() { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a169192225d7e..673bfa138c144 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -36,6 +36,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} import org.apache.spark.util._ @@ -125,31 +126,15 @@ private class ShuffleStatus(numPartitions: Int) { } } - /** - * Removes all shuffle outputs associated with this host. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists). - */ - def removeOutputsOnHost(host: String): Unit = { - removeOutputsByFilter(x => x.host == host) - } - - /** - * Removes all map outputs associated with the specified executor. Note that this will also - * remove outputs which are served by an external shuffle server (if one exists), as they are - * still registered with that execId. - */ - def removeOutputsOnExecutor(execId: String): Unit = synchronized { - removeOutputsByFilter(x => x.executorId == execId) - } - /** * Removes all shuffle outputs which satisfies the filter. Note that this will also * remove outputs which are served by an external shuffle server (if one exists). 
*/ - def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized { + type MapId = Int + def removeOutputsByFilter(f: (MapId, BlockManagerId) => Boolean): Unit = synchronized { for (mapId <- 0 until mapStatuses.length) { if (mapStatuses(mapId) != null && mapStatuses(mapId).location != null - && f(mapStatuses(mapId).location)) { + && f(mapId, mapStatuses(mapId).location)) { decrementNumAvailableOutputs(mapStatuses(mapId).location) mapStatuses(mapId) = null invalidateSerializedMapOutputStatusCache() @@ -368,7 +353,8 @@ private[spark] object ExecutorShuffleStatus extends Enumeration { private[spark] class MapOutputTrackerMaster( conf: SparkConf, broadcastManager: BroadcastManager, - isLocal: Boolean) + isLocal: Boolean, + val shuffleDriverComponents: ShuffleDriverComponents) extends MapOutputTracker(conf) { // The size at which we use Broadcast to send the map output statuses to the executors @@ -487,7 +473,7 @@ private[spark] class MapOutputTrackerMaster( def unregisterAllMapOutput(shuffleId: Int) { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => - shuffleStatus.removeOutputsByFilter(x => true) + shuffleStatus.removeOutputsByFilter((x, y) => true) incrementEpoch() case None => throw new SparkException( @@ -527,7 +513,13 @@ private[spark] class MapOutputTrackerMaster( * outputs which are served by an external shuffle server (if one exists). */ def removeOutputsOnHost(host: String): Unit = { - shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) } + shuffleStatuses.foreach { case (shuffleId, shuffleStatus) => + shuffleStatus.removeOutputsByFilter( + (mapId, location) => { + location.host == host && + !shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId) + }) + } incrementEpoch() } @@ -537,7 +529,13 @@ private[spark] class MapOutputTrackerMaster( * registered with this execId. 
*/ def removeOutputsOnExecutor(execId: String): Unit = { - shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + shuffleStatuses.foreach { case (shuffleId, shuffleStatus) => + shuffleStatus.removeOutputsByFilter( + (mapId, location) => { + location.executorId == execId && + !shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId) + }) + } incrementEpoch() } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c1ee4f46e1499..4d2e677945fe2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -495,11 +495,7 @@ class SparkContext(config: SparkConf) extends SafeLogging { executorEnvs ++= _conf.getExecutorEnv executorEnvs("SPARK_USER") = sparkUser - val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - val maybeIO = Utils.loadExtensions( - classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) - require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") - _shuffleDriverComponents = maybeIO.head.driver() + _shuffleDriverComponents = _env.shuffleDataIo.driver() _shuffleDriverComponents.initializeApplication().asScala.foreach { case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a81a6b97e1212..a88242336daf6 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -42,7 +42,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{ShuffleDataIOUtils, ShuffleManager} +import org.apache.spark.shuffle.api.ShuffleDataIO import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -71,6 +72,7 @@ class SparkEnv ( val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, + val shuffleDataIo: ShuffleDataIO, val conf: SparkConf) extends Logging { private[spark] var isStopped = false @@ -340,8 +342,10 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + val shuffleDataIo = ShuffleDataIOUtils.loadShuffleDataIO(conf) + val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleDataIo.driver()) } else { new MapOutputTrackerWorker(conf) } @@ -419,6 +423,7 @@ object SparkEnv extends Logging { metricsSystem, memoryManager, outputCommitCoordinator, + shuffleDataIo, conf) // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4403d6a5f08e5..9f242585e9ddb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -43,7 +43,8 @@ import org.apache.spark.network.util.JavaUtils import 
org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.shuffle.api.ShuffleDriverComponents.MapOutputUnregistrationStrategy +import org.apache.spark.shuffle.api.ShuffleDriverComponents +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ @@ -121,17 +122,19 @@ private[spark] class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, + shuffleDriverComponents: ShuffleDriverComponents, clock: Clock = new SystemClock()) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { this( - sc, - taskScheduler, - sc.listenerBus, - sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], - sc.env.blockManager.master, - sc.env) + sc = sc, + taskScheduler = taskScheduler, + listenerBus = sc.listenerBus, + mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + blockManagerMaster = sc.env.blockManager.master, + env = sc.env, + shuffleDriverComponents = sc.env.shuffleDataIo.driver()) } def this(sc: SparkContext) = this(sc, sc.taskScheduler) @@ -171,8 +174,6 @@ private[spark] class DAGScheduler( private[scheduler] val activeJobs = new HashSet[ActiveJob] - private[scheduler] val shuffleDriverComponents = sc.shuffleDriverComponents - /** * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids * and its values are arrays indexed by partition numbers. Each array value is the set of @@ -199,14 +200,6 @@ private[spark] class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) - /** - * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, - * this is set default to false, which means, we only unregister the outputs related to the exact - * executor(instead of the host) on a FetchFailure. - */ - private[scheduler] val unRegisterOutputOnHostOnFetchFailure = - sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) - /** * Number of consecutive stage attempts allowed before a stage is aborted. */ @@ -1673,8 +1666,7 @@ private[spark] class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { if (bmAddress.executorId == null) { - if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() == - MapOutputUnregistrationStrategy.HOST) { + if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) { val currentEpoch = task.epoch val host = bmAddress.host logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) @@ -1683,10 +1675,7 @@ private[spark] class DAGScheduler( } } else { val hostToUnregisterOutputs = - if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() == - MapOutputUnregistrationStrategy.HOST) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. 
+ if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) { Some(bmAddress.host) } else { // Unregister shuffle data just for one executor (we don't have any @@ -1866,11 +1855,8 @@ private[spark] class DAGScheduler( logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch)) mapOutputTracker.removeOutputsOnHost(host) case None => - if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() == - MapOutputUnregistrationStrategy.EXECUTOR) { - logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - mapOutputTracker.removeOutputsOnExecutor(execId) - } + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) + mapOutputTracker.removeOutputsOnExecutor(execId) } clearCacheLocs() diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleDataIOUtils.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleDataIOUtils.scala new file mode 100644 index 0000000000000..faee5a500855e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleDataIOUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.api.ShuffleDataIO +import org.apache.spark.util.Utils + +private[spark] object ShuffleDataIOUtils { + + /** + * The prefix of spark config keys that are passed from the driver to the executor. + */ + val SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin." 
+ + def loadShuffleDataIO(conf: SparkConf): ShuffleDataIO = { + val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.nonEmpty, s"At least one valid shuffle plugin must be specified by config " + + s"${SHUFFLE_IO_PLUGIN_CLASS.key}, but $configuredPluginClasses resulted in zero valid " + + s"plugins.") + maybeIO.head + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 610c04ace3b6f..df5a66ff8751f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -22,10 +22,9 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} -import org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -221,13 +220,10 @@ private[spark] object SortShuffleManager extends Logging { } private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { - val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) - val maybeIO = Utils.loadExtensions( - classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) - require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") - val executorComponents = maybeIO.head.executor() val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX) - .toMap + .toMap + val env = SparkEnv.get + val executorComponents = env.shuffleDataIo.executor() executorComponents.initializeExecutor( conf.getAppId, SparkEnv.get.executorId, diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 2b1258d7c923a..4b347ad835dd8 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,7 +29,8 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId} +import org.apache.spark.shuffle.api.ShuffleDriverComponents +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockAttemptId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -37,7 +38,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { private def newTrackerMaster(sparkConf: SparkConf = conf) = { val broadcastManager = new BroadcastManager(true, sparkConf, new SecurityManager(sparkConf)) - new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + val driverComponents = mock(classOf[ShuffleDriverComponents]) + when(driverComponents.checkIfMapOutputStoredOutsideExecutor(any(), any())).thenReturn(false) + new MapOutputTrackerMaster(sparkConf, broadcastManager, true, driverComponents) } def createRpcEnv(name: String, host: String = 
"localhost", port: Int = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala index cafdfacb4dc16..ac06157c8c8e3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerShufflePluginSuite.scala @@ -16,72 +16,27 @@ */ package org.apache.spark.scheduler -import java.util +import java.util.{Collections, Map => JMap} import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} -import org.apache.spark.internal.config import org.apache.spark.rdd.RDD -import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents} -import org.apache.spark.shuffle.api.ShuffleDriverComponents.MapOutputUnregistrationStrategy -import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.storage.BlockManagerId -class PluginShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) - override def driver(): ShuffleDriverComponents = - new PluginShuffleDriverComponents(localDiskShuffleDataIO.driver()) +class PluginShuffleDriverComponents extends ShuffleDriverComponents { - override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() -} - -class PluginShuffleDriverComponents(delegate: ShuffleDriverComponents) - extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = - delegate.initializeApplication() - - override def cleanupApplication(): Unit = - delegate.cleanupApplication() - - override def removeShuffle(shuffleId: Int, blocking: Boolean): Unit = - delegate.removeShuffle(shuffleId, blocking) - - override def unregistrationStrategyOnFetchFailure(): - ShuffleDriverComponents.MapOutputUnregistrationStrategy = - MapOutputUnregistrationStrategy.HOST -} - -class AsyncShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { - val localDiskShuffleDataIO = new LocalDiskShuffleDataIO(sparkConf) - override def driver(): ShuffleDriverComponents = - new AsyncShuffleDriverComponents(localDiskShuffleDataIO.driver()) - - override def executor(): ShuffleExecutorComponents = localDiskShuffleDataIO.executor() -} - -class AsyncShuffleDriverComponents(delegate: ShuffleDriverComponents) - extends ShuffleDriverComponents { - override def initializeApplication(): util.Map[String, String] = - delegate.initializeApplication() - - override def cleanupApplication(): Unit = - delegate.cleanupApplication() - - override def removeShuffle(shuffleId: Int, blocking: Boolean): Unit = - delegate.removeShuffle(shuffleId, blocking) + override def initializeApplication(): JMap[String, String] = Collections.emptyMap() - override def unregistrationStrategyOnFetchFailure(): - ShuffleDriverComponents.MapOutputUnregistrationStrategy = - MapOutputUnregistrationStrategy.MAP_OUTPUT_ONLY + override def unregisterOutputOnHostOnFetchFailure(): Boolean = true } class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { - private def setupTest(pluginClass: Class[_]): (RDD[_], Int) = { + private def setupTest(): (RDD[_], Int) = { afterEach() val conf = new SparkConf() // unregistering all outputs on a host is enabled for the individual file server case - conf.set(config.SHUFFLE_IO_PLUGIN_CLASS, pluginClass.getName) - init(conf) + 
init(conf, (_, _) => new PluginShuffleDriverComponents) val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId @@ -89,32 +44,8 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { (reduceRdd, shuffleId) } - test("Test async") { - val (reduceRdd, shuffleId) = setupTest(classOf[AsyncShuffleDataIO]) - submit(reduceRdd, Array(0, 1)) - - // Perform map task - val mapStatus1 = makeMapStatus("exec1", "hostA") - val mapStatus2 = makeMapStatus("exec1", "hostA") - complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus1, mapStatus2)) - - - // perform reduce task - complete(taskSets(1), Seq((Success, 42), - (FetchFailed(BlockManagerId("exec1", "hostA", 1234), shuffleId, 1, 0, "ignored"), null))) - assertMapShuffleLocations(shuffleId, Seq(mapStatus1, null)) - - scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, mapStatus2))) - - complete(taskSets(3), Seq((Success, 43))) - assert(results === Map(0 -> 42, 1 -> 43)) - assertDataStructuresEmpty() - } - test("Test simple file server") { - val (reduceRdd, shuffleId) = setupTest(classOf[PluginShuffleDataIO]) + val (reduceRdd, shuffleId) = setupTest() submit(reduceRdd, Array(0, 1)) // Perform map task @@ -130,7 +61,7 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { } test("Test simple file server fetch failure") { - val (reduceRdd, shuffleId) = setupTest(classOf[PluginShuffleDataIO]) + val (reduceRdd, shuffleId) = setupTest() submit(reduceRdd, Array(0, 1)) // Perform map task @@ -151,7 +82,7 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { } test("Test simple file fetch server - duplicate host") { - val (reduceRdd, shuffleId) = setupTest(classOf[PluginShuffleDataIO]) + val (reduceRdd, shuffleId) = setupTest() submit(reduceRdd, Array(0, 1)) // Perform map task @@ -172,7 +103,7 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { } test("Test DFS case - empty BlockManagerId") { - val (reduceRdd, shuffleId) = setupTest(classOf[PluginShuffleDataIO]) + val (reduceRdd, shuffleId) = setupTest() submit(reduceRdd, Array(0, 1)) val mapStatus = makeEmptyMapStatus() @@ -186,7 +117,7 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { } test("Test DFS case - fetch failure") { - val (reduceRdd, shuffleId) = setupTest(classOf[PluginShuffleDataIO]) + val (reduceRdd, shuffleId) = setupTest() submit(reduceRdd, Array(0, 1)) // Perform map task @@ -215,6 +146,6 @@ class DAGSchedulerShufflePluginSuite extends DAGSchedulerSuite { def assertMapShuffleLocations(shuffleId: Int, set: Seq[MapStatus]): Unit = { val actualShuffleLocations = mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses - assert(set === actualShuffleLocations.toSeq) + assert(actualShuffleLocations.toSeq === set) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8e8b90600d16e..d92ef0c0accdb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -35,6 +35,8 @@ import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} +import 
org.apache.spark.shuffle.api.ShuffleDriverComponents +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util._ @@ -195,6 +197,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi var securityMgr: SecurityManager = null var scheduler: DAGScheduler = null var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null + var shuffleDriverComponents: ShuffleDriverComponents = null /** * Set of cache locations to return from our mock BlockManagerMaster. @@ -236,7 +239,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi init(new SparkConf()) } - def init(testConf: SparkConf): Unit = { + def init(testConf: SparkConf) { + init(testConf, (conf, bmMaster) => new LocalDiskShuffleDriverComponents(conf, bmMaster)) + } + + def init( + testConf: SparkConf, + shuffleDriverComponents: (SparkConf, BlockManagerMaster) => ShuffleDriverComponents): Unit = { sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() @@ -250,7 +259,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi results.clear() securityMgr = new SecurityManager(conf) broadcastManager = new BroadcastManager(true, conf, securityMgr) - mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) { + mapOutputTracker = new MapOutputTrackerMaster( + conf, broadcastManager, true, sc.env.shuffleDataIo.driver()) { override def sendTracker(message: Any): Unit = { // no-op, just so we can stop this to avoid leaking threads } @@ -261,7 +271,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) + sc.env, + shuffleDriverComponents(testConf, blockManagerMaster)) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -674,7 +685,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) + sc.env, + shuffleDriverComponents) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index b158b6c291d55..1d383c55dbdc8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -39,7 +39,9 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents import org.apache.spark.storage.StorageLevel._ import org.apache.spark.util.Utils @@ -52,9 +54,10 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite protected var rpcEnv: RpcEnv = null protected var master: BlockManagerMaster = null + protected var driverComponents: ShuffleDriverComponents = null + protected var 
mapOutputTracker: MapOutputTrackerMaster = null protected lazy val securityMgr = new SecurityManager(conf) protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) - protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped @@ -101,6 +104,9 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus(conf))), conf, true) + driverComponents = new LocalDiskShuffleDriverComponents(master) + mapOutputTracker = new MapOutputTrackerMaster( + conf, bcastManager, true, driverComponents) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 40c6424b0cb13..322bafc2f14a4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -50,7 +50,9 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -68,9 +70,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val allStores = ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null + var driverComponents: ShuffleDriverComponents = null + var mapOutputTracker: MapOutputTrackerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test @@ -130,6 +133,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus(conf))), conf, true) + driverComponents = new LocalDiskShuffleDriverComponents(master) + mapOutputTracker = new MapOutputTrackerMaster( + new SparkConf(false), bcastManager, true, driverComponents) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index f0960d00a75f5..a259260db7706 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -39,7 +39,9 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus 
import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{KryoSerializer, SerializerManager} +import org.apache.spark.shuffle.api.ShuffleDriverComponents import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ @@ -70,7 +72,6 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val streamId = 1 val securityMgr = new SecurityManager(conf, encryptionKey) val broadcastManager = new BroadcastManager(true, conf, securityMgr) - val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf, encryptionKey) @@ -78,6 +79,8 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val blockManagerSize = 10000000 val blockManagerBuffer = new ArrayBuffer[BlockManager]() + var mapOutputTracker: MapOutputTrackerMaster = null + var driverComponents: ShuffleDriverComponents = null var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null @@ -91,6 +94,9 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus(conf))), conf, true) + driverComponents = new LocalDiskShuffleDriverComponents(blockManagerMaster) + mapOutputTracker = new MapOutputTrackerMaster( + conf, broadcastManager, true, driverComponents) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf)
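For reviewers, a minimal sketch (not part of this patch) of how a third-party plugin might implement the two new ShuffleDriverComponents hooks. The class name, the markMapOutputRemote helper, and the in-memory bookkeeping set are hypothetical; only the overridden methods come from the interface changed above. As the diff shows, MapOutputTrackerMaster consults checkIfMapOutputStoredOutsideExecutor before dropping entries in removeOutputsOnHost/removeOutputsOnExecutor, and DAGScheduler consults unregisterOutputOnHostOnFetchFailure on a FetchFailed, so a plugin that persists map output outside the executor can keep those outputs registered across executor and host loss.

import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.spark.shuffle.api.ShuffleDriverComponents;

// Hypothetical example plugin: map outputs are uploaded to remote storage by the
// plugin's own write path, so they survive the loss of the mapper executor or host.
public class RemoteStorageShuffleDriverComponents implements ShuffleDriverComponents {

  // (shuffleId, mapId) pairs whose output is known to be durably stored remotely.
  private final Set<Long> remoteMapOutputs = ConcurrentHashMap.newKeySet();

  @Override
  public Map<String, String> initializeApplication() {
    return Collections.emptyMap();
  }

  // Assumed to be called by the plugin's writer once an upload has completed.
  public void markMapOutputRemote(int shuffleId, int mapId) {
    remoteMapOutputs.add(key(shuffleId, mapId));
  }

  @Override
  public boolean checkIfMapOutputStoredOutsideExecutor(int shuffleId, int mapId) {
    // Return true only when the output is verifiably available outside the executor;
    // MapOutputTrackerMaster then keeps it registered on executor/host removal.
    return remoteMapOutputs.contains(key(shuffleId, mapId));
  }

  @Override
  public boolean unregisterOutputOnHostOnFetchFailure() {
    // Outputs do not live on the mapper host, so a fetch failure should not wipe
    // every output registered against that host.
    return false;
  }

  private static long key(int shuffleId, int mapId) {
    return (((long) shuffleId) << 32) | (mapId & 0xFFFFFFFFL);
  }
}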