[SPARK-25299] Add shuffle map output un-registration hooks upon fetch failure (#609)

We realized that there is real complexity in deciding whether or not map outputs should be unregistered, and thus recomputed, after a fetch failure.

Previously we used a boolean, and then a three-way switch, but neither of these modes captures the intricacies of how this should work. In particular, for our async upload proof of concept, we end up unregistering all the map outputs written by an executor, even though doing so invalidates and forces the re-write of map outputs that were already persisted to the remote storage system.
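The new per-map-output hook added here, checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId), lets a plugin keep such outputs registered. As a minimal sketch (not part of this commit), a hypothetical async-upload plugin might implement it roughly as below, assuming it extends the local-disk components and tracks which outputs have finished uploading:

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents

// Hypothetical plugin: map outputs that reached remote storage survive the loss of the
// executor (or host) that wrote them, so they should stay registered on a fetch failure.
class RemoteStorageShuffleDriverComponents(sparkConf: SparkConf)
    extends LocalDiskShuffleDriverComponents(sparkConf) {

  // Assumed bookkeeping, populated as asynchronous uploads complete.
  private val uploadedOutputs = ConcurrentHashMap.newKeySet[(Int, Int)]()

  def markUploaded(shuffleId: Int, mapId: Int): Unit = {
    uploadedOutputs.add((shuffleId, mapId))
  }

  // Only outputs that are verifiably in remote storage report true; everything else is
  // unregistered and recomputed as before.
  override def checkIfMapOutputStoredOutsideExecutor(shuffleId: Int, mapId: Int): Boolean =
    uploadedOutputs.contains((shuffleId, mapId))
}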
mccheah authored and bulldozer-bot[bot] committed Oct 4, 2019
1 parent 5abde44 commit d551551
Showing 15 changed files with 191 additions and 177 deletions.
@@ -22,12 +22,6 @@

public interface ShuffleDriverComponents {

enum MapOutputUnregistrationStrategy {
MAP_OUTPUT_ONLY,
EXECUTOR,
HOST,
}

/**
* @return additional SparkConf values necessary for the executors.
*/
@@ -39,7 +33,18 @@ default void registerShuffle(int shuffleId) throws IOException {}

default void removeShuffle(int shuffleId, boolean blocking) throws IOException {}

default MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() {
return MapOutputUnregistrationStrategy.EXECUTOR;
/**
* Indicates whether or not the data stored for the given map output is available outside
* of the host of the mapper executor.
*
* @return true if it can be verified that the map output is stored outside of the mapper
* AND if the map output is available in such an external location; false otherwise.
*/
default boolean checkIfMapOutputStoredOutsideExecutor(int shuffleId, int mapId) {
return false;
}

default boolean unregisterOutputOnHostOnFetchFailure() {
return true;
}
}
@@ -37,7 +37,7 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) {

@Override
public ShuffleDriverComponents driver() {
return new LocalDiskShuffleDriverComponents();
return new LocalDiskShuffleDriverComponents(sparkConf);
}

@Override
@@ -19,25 +19,39 @@

import java.util.Map;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
import org.apache.spark.internal.config.package$;
import org.apache.spark.storage.BlockManagerMaster;

public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents {

private BlockManagerMaster blockManagerMaster;
private boolean shouldUnregisterOutputOnHostOnFetchFailure;
private final SparkConf sparkConf;

public LocalDiskShuffleDriverComponents(SparkConf sparkConf) {
this.sparkConf = sparkConf;
}

@VisibleForTesting
public LocalDiskShuffleDriverComponents(BlockManagerMaster blockManagerMaster) {
this.sparkConf = new SparkConf(false);
this.blockManagerMaster = blockManagerMaster;
}

@VisibleForTesting
public LocalDiskShuffleDriverComponents(
SparkConf sparkConf, BlockManagerMaster blockManagerMaster) {
this.sparkConf = sparkConf;
this.blockManagerMaster = blockManagerMaster;
}

@Override
public Map<String, String> initializeApplication() {
blockManagerMaster = SparkEnv.get().blockManager().master();
this.shouldUnregisterOutputOnHostOnFetchFailure =
SparkEnv.get().blockManager().externalShuffleServiceEnabled()
&& (boolean) SparkEnv.get().conf()
.get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE());
return ImmutableMap.of();
}

@@ -48,11 +62,20 @@ public void removeShuffle(int shuffleId, boolean blocking) {
}

@Override
public MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() {
if (shouldUnregisterOutputOnHostOnFetchFailure) {
return MapOutputUnregistrationStrategy.HOST;
}
return MapOutputUnregistrationStrategy.EXECUTOR;
public boolean unregisterOutputOnHostOnFetchFailure() {
boolean unregisterOutputOnHostOnFetchFailure = Boolean.parseBoolean(
sparkConf.get(
org.apache.spark.internal.config.package$.MODULE$
.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().key(),
org.apache.spark.internal.config.package$.MODULE$
.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().defaultValueString()));
boolean externalShuffleServiceEnabled = Boolean.parseBoolean(
sparkConf.get(
org.apache.spark.internal.config.package$.MODULE$
.SHUFFLE_SERVICE_ENABLED().key(),
org.apache.spark.internal.config.package$.MODULE$
.SHUFFLE_SERVICE_ENABLED().defaultValueString()));
return unregisterOutputOnHostOnFetchFailure && externalShuffleServiceEnabled;
}

private void checkInitialized() {
44 changes: 21 additions & 23 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -36,6 +36,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId}
import org.apache.spark.util._

@@ -125,31 +126,15 @@ private class ShuffleStatus(numPartitions: Int) {
}
}

/**
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
removeOutputsByFilter(x => x.host == host)
}

/**
* Removes all map outputs associated with the specified executor. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists), as they are
* still registered with that execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
removeOutputsByFilter(x => x.executorId == execId)
}

/**
* Removes all shuffle outputs which satisfies the filter. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
type MapId = Int
def removeOutputsByFilter(f: (MapId, BlockManagerId) => Boolean): Unit = synchronized {
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null && mapStatuses(mapId).location != null
&& f(mapStatuses(mapId).location)) {
&& f(mapId, mapStatuses(mapId).location)) {
decrementNumAvailableOutputs(mapStatuses(mapId).location)
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
@@ -368,7 +353,8 @@ private[spark] object ExecutorShuffleStatus extends Enumeration {
private[spark] class MapOutputTrackerMaster(
conf: SparkConf,
broadcastManager: BroadcastManager,
isLocal: Boolean)
isLocal: Boolean,
val shuffleDriverComponents: ShuffleDriverComponents)
extends MapOutputTracker(conf) {

// The size at which we use Broadcast to send the map output statuses to the executors
@@ -487,7 +473,7 @@ private[spark] class MapOutputTrackerMaster(
def unregisterAllMapOutput(shuffleId: Int) {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
shuffleStatus.removeOutputsByFilter((x, y) => true)
incrementEpoch()
case None =>
throw new SparkException(
@@ -527,7 +513,13 @@
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
shuffleStatuses.foreach { case (shuffleId, shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(
(mapId, location) => {
location.host == host &&
!shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId)
})
}
incrementEpoch()
}

@@ -537,7 +529,13 @@
* registered with this execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = {
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
shuffleStatuses.foreach { case (shuffleId, shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(
(mapId, location) => {
location.executorId == execId &&
!shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId)
})
}
incrementEpoch()
}

6 changes: 1 addition & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -495,11 +495,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser

val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
_shuffleDriverComponents = maybeIO.head.driver()
_shuffleDriverComponents = _env.shuffleDataIo.driver()
_shuffleDriverComponents.initializeApplication().asScala.foreach {
case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }

9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -42,7 +42,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{ShuffleDataIOUtils, ShuffleManager}
import org.apache.spark.shuffle.api.ShuffleDataIO
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}

@@ -71,6 +72,7 @@ class SparkEnv (
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val shuffleDataIo: ShuffleDataIO,
val conf: SparkConf) extends Logging {

private[spark] var isStopped = false
@@ -340,8 +342,10 @@ object SparkEnv extends Logging {

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

val shuffleDataIo = ShuffleDataIOUtils.loadShuffleDataIO(conf)

val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleDataIo.driver())
} else {
new MapOutputTrackerWorker(conf)
}
@@ -419,6 +423,7 @@
metricsSystem,
memoryManager,
outputCommitCoordinator,
shuffleDataIo,
conf)

// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
42 changes: 14 additions & 28 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -43,7 +43,8 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.shuffle.api.ShuffleDriverComponents.MapOutputUnregistrationStrategy
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util._
@@ -121,17 +122,19 @@ private[spark] class DAGScheduler(
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
shuffleDriverComponents: ShuffleDriverComponents,
clock: Clock = new SystemClock())
extends Logging {

def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
taskScheduler,
sc.listenerBus,
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
sc.env.blockManager.master,
sc.env)
sc = sc,
taskScheduler = taskScheduler,
listenerBus = sc.listenerBus,
mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
blockManagerMaster = sc.env.blockManager.master,
env = sc.env,
shuffleDriverComponents = sc.env.shuffleDataIo.driver())
}

def this(sc: SparkContext) = this(sc, sc.taskScheduler)
@@ -171,8 +174,6 @@

private[scheduler] val activeJobs = new HashSet[ActiveJob]

private[scheduler] val shuffleDriverComponents = sc.shuffleDriverComponents

/**
* Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
* and its values are arrays indexed by partition numbers. Each array value is the set of
@@ -199,14 +200,6 @@
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY)

/**
* Whether to unregister all the outputs on the host in condition that we receive a FetchFailure,
* this is set default to false, which means, we only unregister the outputs related to the exact
* executor(instead of the host) on a FetchFailure.
*/
private[scheduler] val unRegisterOutputOnHostOnFetchFailure =
sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE)

/**
* Number of consecutive stage attempts allowed before a stage is aborted.
*/
@@ -1673,8 +1666,7 @@
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
if (bmAddress.executorId == null) {
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.HOST) {
if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) {
val currentEpoch = task.epoch
val host = bmAddress.host
logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
@@ -1683,10 +1675,7 @@
}
} else {
val hostToUnregisterOutputs =
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.HOST) {
// We had a fetch failure with the external shuffle service, so we
// assume all shuffle data on the node is bad.
if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) {
Some(bmAddress.host)
} else {
// Unregister shuffle data just for one executor (we don't have any
@@ -1866,11 +1855,8 @@
logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
mapOutputTracker.removeOutputsOnHost(host)
case None =>
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.EXECUTOR) {
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
mapOutputTracker.removeOutputsOnExecutor(execId)
}
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
mapOutputTracker.removeOutputsOnExecutor(execId)
}
clearCacheLocs()

@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.api.ShuffleDataIO
import org.apache.spark.util.Utils

private[spark] object ShuffleDataIOUtils {

/**
* The prefix of spark config keys that are passed from the driver to the executor.
*/
val SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."

def loadShuffleDataIO(conf: SparkConf): ShuffleDataIO = {
val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.nonEmpty, s"At least one valid shuffle plugin must be specified by config " +
s"${SHUFFLE_IO_PLUGIN_CLASS.key}, but $configuredPluginClasses resulted in zero valid " +
s"plugins.")
maybeIO.head
}
}
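
For reference, a short usage sketch of the new loader on the driver side, mirroring the SparkEnv change above. ShuffleDataIOUtils is spark-internal (private[spark]), so this would live under the org.apache.spark package tree, and the plugin class name here is hypothetical:

package org.apache.spark.shuffle

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleDriverComponents}

object ShuffleDataIOLoadingExample {
  def loadDriverComponents(): ShuffleDriverComponents = {
    val conf = new SparkConf()
      .set(SHUFFLE_IO_PLUGIN_CLASS.key, "com.example.shuffle.RemoteStorageShuffleDataIO")
    // SparkEnv performs this lookup once per application; driver() yields the
    // ShuffleDriverComponents consulted by MapOutputTrackerMaster and the DAGScheduler.
    val dataIo: ShuffleDataIO = ShuffleDataIOUtils.loadShuffleDataIO(conf)
    dataIo.driver()
  }
}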