[SPARK-25299] Add shuffle map output un-registration hooks upon fetch failure #609

Merged: 7 commits, merged on Oct 4, 2019
ShuffleDriverComponents.java

@@ -22,12 +22,6 @@

public interface ShuffleDriverComponents {

enum MapOutputUnregistrationStrategy {
MAP_OUTPUT_ONLY,
EXECUTOR,
HOST,
}

/**
* @return additional SparkConf values necessary for the executors.
*/
@@ -39,7 +33,18 @@ default void registerShuffle(int shuffleId) throws IOException {}

default void removeShuffle(int shuffleId, boolean blocking) throws IOException {}

default MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() {
return MapOutputUnregistrationStrategy.EXECUTOR;
/**
* Indicates whether or not the data stored for the given map output is available outside
* of the host of the mapper executor.
*
* @return true if it can be verified that the map output is stored outside of the mapper
* AND if the map output is available in such an external location; false otherwise.
*/
default boolean checkIfMapOutputStoredOutsideExecutor(int shuffleId, int mapId) {
return false;
}

default boolean unregisterOutputOnHostOnFetchFailure() {
return true;
}
}
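
To make the contract concrete, here is a minimal sketch (Scala) of how a plugin that mirrors shuffle data to remote storage might use these hooks. RemoteStorageClient and RemoteBackedShuffleDriverComponents are hypothetical names, not part of this change; the interface's other methods are omitted, so the class is left abstract.

import org.apache.spark.shuffle.api.ShuffleDriverComponents

// Hypothetical helper, not part of this PR: reports whether a map output has already
// been persisted somewhere other than the mapper's local disk.
trait RemoteStorageClient {
  def isBackedUp(shuffleId: Int, mapId: Int): Boolean
}

// Sketch of a remote-storage-backed plugin. Only the two hooks added here are
// overridden; the remaining interface methods are left to a concrete subclass.
abstract class RemoteBackedShuffleDriverComponents(client: RemoteStorageClient)
  extends ShuffleDriverComponents {

  // An output that was already mirrored survives the loss of its mapper executor,
  // so the map output tracker can keep it registered.
  override def checkIfMapOutputStoredOutsideExecutor(shuffleId: Int, mapId: Int): Boolean =
    client.isBackedUp(shuffleId, mapId)

  // Local disk is not the source of truth, so a fetch failure should not wipe
  // every output registered against the failing host.
  override def unregisterOutputOnHostOnFetchFailure(): Boolean = false
}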
LocalDiskShuffleDataIO.java

@@ -37,7 +37,7 @@ public LocalDiskShuffleDataIO(SparkConf sparkConf) {

@Override
public ShuffleDriverComponents driver() {
return new LocalDiskShuffleDriverComponents();
return new LocalDiskShuffleDriverComponents(sparkConf);
}

@Override
LocalDiskShuffleDriverComponents.java

@@ -19,25 +19,39 @@

import java.util.Map;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.api.ShuffleDriverComponents;
import org.apache.spark.internal.config.package$;
import org.apache.spark.storage.BlockManagerMaster;

public class LocalDiskShuffleDriverComponents implements ShuffleDriverComponents {

private BlockManagerMaster blockManagerMaster;
private boolean shouldUnregisterOutputOnHostOnFetchFailure;
private final SparkConf sparkConf;

public LocalDiskShuffleDriverComponents(SparkConf sparkConf) {
this.sparkConf = sparkConf;
}

@VisibleForTesting
public LocalDiskShuffleDriverComponents(BlockManagerMaster blockManagerMaster) {
this.sparkConf = new SparkConf(false);
this.blockManagerMaster = blockManagerMaster;
}

@VisibleForTesting
public LocalDiskShuffleDriverComponents(
SparkConf sparkConf, BlockManagerMaster blockManagerMaster) {
this.sparkConf = sparkConf;
this.blockManagerMaster = blockManagerMaster;
}

@Override
public Map<String, String> initializeApplication() {
blockManagerMaster = SparkEnv.get().blockManager().master();
this.shouldUnregisterOutputOnHostOnFetchFailure =
SparkEnv.get().blockManager().externalShuffleServiceEnabled()
&& (boolean) SparkEnv.get().conf()
.get(package$.MODULE$.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE());
return ImmutableMap.of();
}

@@ -48,11 +62,20 @@ public void removeShuffle(int shuffleId, boolean blocking) {
}

@Override
public MapOutputUnregistrationStrategy unregistrationStrategyOnFetchFailure() {
if (shouldUnregisterOutputOnHostOnFetchFailure) {
return MapOutputUnregistrationStrategy.HOST;
}
return MapOutputUnregistrationStrategy.EXECUTOR;
public boolean unregisterOutputOnHostOnFetchFailure() {
boolean unregisterOutputOnHostOnFetchFailure = Boolean.parseBoolean(
sparkConf.get(
org.apache.spark.internal.config.package$.MODULE$
.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().key(),
org.apache.spark.internal.config.package$.MODULE$
.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE().defaultValueString()));
boolean externalShuffleServiceEnabled = Boolean.parseBoolean(
sparkConf.get(
org.apache.spark.internal.config.package$.MODULE$
.SHUFFLE_SERVICE_ENABLED().key(),
org.apache.spark.internal.config.package$.MODULE$
.SHUFFLE_SERVICE_ENABLED().defaultValueString()));
return unregisterOutputOnHostOnFetchFailure && externalShuffleServiceEnabled;
}

private void checkInitialized() {
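For this built-in local-disk implementation, the new hook reduces to the two existing flags; a minimal script-style sketch of exercising it (config keys as in upstream Spark):

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents

// Host-wide unregistration is only reported when both the external shuffle service
// and the unregister-on-host flag are enabled.
val conf = new SparkConf(false)
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")

val components = new LocalDiskShuffleDriverComponents(conf)
assert(components.unregisterOutputOnHostOnFetchFailure())
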
44 changes: 21 additions & 23 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -36,6 +36,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockAttemptId, ShuffleBlockId}
import org.apache.spark.util._

@@ -125,31 +126,15 @@
}
}

/**
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
removeOutputsByFilter(x => x.host == host)
}

/**
* Removes all map outputs associated with the specified executor. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists), as they are
* still registered with that execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
removeOutputsByFilter(x => x.executorId == execId)
}

/**
* Removes all shuffle outputs which satisfies the filter. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
type MapId = Int
def removeOutputsByFilter(f: (MapId, BlockManagerId) => Boolean): Unit = synchronized {
for (mapId <- 0 until mapStatuses.length) {
if (mapStatuses(mapId) != null && mapStatuses(mapId).location != null
&& f(mapStatuses(mapId).location)) {
&& f(mapId, mapStatuses(mapId).location)) {
decrementNumAvailableOutputs(mapStatuses(mapId).location)
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
@@ -368,7 +353,8 @@ private[spark] object ExecutorShuffleStatus extends Enumeration {
private[spark] class MapOutputTrackerMaster(
conf: SparkConf,
broadcastManager: BroadcastManager,
isLocal: Boolean)
isLocal: Boolean,
val shuffleDriverComponents: ShuffleDriverComponents)
extends MapOutputTracker(conf) {

// The size at which we use Broadcast to send the map output statuses to the executors
@@ -487,7 +473,7 @@ private[spark] class MapOutputTrackerMaster(
def unregisterAllMapOutput(shuffleId: Int) {
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(x => true)
shuffleStatus.removeOutputsByFilter((x, y) => true)
incrementEpoch()
case None =>
throw new SparkException(
@@ -527,7 +513,13 @@
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
shuffleStatuses.foreach { case (shuffleId, shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(
(mapId, location) => {
location.host == host &&
!shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId)
})
}
incrementEpoch()
}

@@ -537,7 +529,13 @@
* registered with this execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = {
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
shuffleStatuses.foreach { case (shuffleId, shuffleStatus) =>
shuffleStatus.removeOutputsByFilter(
(mapId, location) => {
location.executorId == execId &&
!shuffleDriverComponents.checkIfMapOutputStoredOutsideExecutor(shuffleId, mapId)
})
}
incrementEpoch()
}

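The filter now receives the map id alongside the location, which is what lets the tracker consult the plugin per output; it also allows map-level scoping, as in this hypothetical helper (not part of the diff):

import org.apache.spark.storage.BlockManagerId

// Drop exactly one map output and keep the rest. ShuffleStatus is the
// package-private bookkeeping class modified above.
def removeSingleOutput(status: ShuffleStatus, failedMapId: Int): Unit = {
  status.removeOutputsByFilter((mapId: Int, _: BlockManagerId) => mapId == failedMapId)
}
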
6 changes: 1 addition & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -495,11 +495,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser

val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
_shuffleDriverComponents = maybeIO.head.driver()
_shuffleDriverComponents = _env.shuffleDataIo.driver()
_shuffleDriverComponents.initializeApplication().asScala.foreach {
case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }

9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -42,7 +42,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{ShuffleDataIOUtils, ShuffleManager}
import org.apache.spark.shuffle.api.ShuffleDataIO
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}

@@ -71,6 +72,7 @@ class SparkEnv (
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val shuffleDataIo: ShuffleDataIO,
Reviewer:

Is it weird to have the shuffleDataIO as part of the developer API? My concern is that users could use shuffleDataIO.driver() in the executor and vice versa, but maybe it's not a problem.

If you wanted to get around this, you could pass the driver components, or pass in a function, as part of the method call to MapOutputTracker, since you already have the driver components in the DAGScheduler.

Author:

We don't want to pass functional modules around through method calls - modules should be dependency injected at construction time.

Reviewer:

What's the rationale for not passing functional modules through method calls?

Author:

Ownership of the module becomes unclear - the dependency tree should be more or less static and clear.

Author:

I think it's fine to eventually have ShuffleDataIO in this API, but it'll be tricky when it comes to upstream Spark because our APIs will still be in experimental status for now - I'm ok with deferring to the community to see what they think here, since I can't think of a much better way to set up the dependency injection.

val conf: SparkConf) extends Logging {

private[spark] var isStopped = false
@@ -340,8 +342,10 @@ object SparkEnv extends Logging {

val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

val shuffleDataIo = ShuffleDataIOUtils.loadShuffleDataIO(conf)

val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleDataIo.driver())
} else {
new MapOutputTrackerWorker(conf)
}
Expand Down Expand Up @@ -419,6 +423,7 @@ object SparkEnv extends Logging {
metricsSystem,
memoryManager,
outputCommitCoordinator,
shuffleDataIo,
conf)

// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
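With ShuffleDataIO loaded in SparkEnv, the driver components reach the tracker through its constructor rather than through individual method calls; a test-style sketch of that wiring (constructor shapes assumed from core, not taken from this PR's tests):

import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents

val conf = new SparkConf(false)
val broadcastManager = new BroadcastManager(true, conf, new SecurityManager(conf))

// The injected components decide, per map output, whether a fetch failure should
// unregister it.
val tracker = new MapOutputTrackerMaster(
  conf,
  broadcastManager,
  true, // isLocal
  new LocalDiskShuffleDriverComponents(conf))
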
42 changes: 14 additions & 28 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -43,7 +43,8 @@ import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.shuffle.api.ShuffleDriverComponents.MapOutputUnregistrationStrategy
import org.apache.spark.shuffle.api.ShuffleDriverComponents
import org.apache.spark.shuffle.sort.lifecycle.LocalDiskShuffleDriverComponents
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util._
@@ -121,17 +122,19 @@ private[spark] class DAGScheduler(
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
shuffleDriverComponents: ShuffleDriverComponents,
clock: Clock = new SystemClock())
extends Logging {

def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
taskScheduler,
sc.listenerBus,
sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
sc.env.blockManager.master,
sc.env)
sc = sc,
taskScheduler = taskScheduler,
listenerBus = sc.listenerBus,
mapOutputTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
blockManagerMaster = sc.env.blockManager.master,
env = sc.env,
shuffleDriverComponents = sc.env.shuffleDataIo.driver())
}

def this(sc: SparkContext) = this(sc, sc.taskScheduler)
@@ -171,8 +174,6 @@

private[scheduler] val activeJobs = new HashSet[ActiveJob]

private[scheduler] val shuffleDriverComponents = sc.shuffleDriverComponents

/**
* Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
* and its values are arrays indexed by partition numbers. Each array value is the set of
@@ -199,14 +200,6 @@
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY)

/**
* Whether to unregister all the outputs on the host in condition that we receive a FetchFailure,
* this is set default to false, which means, we only unregister the outputs related to the exact
* executor(instead of the host) on a FetchFailure.
*/
private[scheduler] val unRegisterOutputOnHostOnFetchFailure =
sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE)

/**
* Number of consecutive stage attempts allowed before a stage is aborted.
*/
@@ -1673,8 +1666,7 @@
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
if (bmAddress.executorId == null) {
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.HOST) {
if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) {
val currentEpoch = task.epoch
val host = bmAddress.host
logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
@@ -1683,10 +1675,7 @@
}
} else {
val hostToUnregisterOutputs =
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.HOST) {
// We had a fetch failure with the external shuffle service, so we
// assume all shuffle data on the node is bad.
if (shuffleDriverComponents.unregisterOutputOnHostOnFetchFailure()) {
Some(bmAddress.host)
} else {
// Unregister shuffle data just for one executor (we don't have any
@@ -1866,11 +1855,8 @@
logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
mapOutputTracker.removeOutputsOnHost(host)
case None =>
if (shuffleDriverComponents.unregistrationStrategyOnFetchFailure() ==
MapOutputUnregistrationStrategy.EXECUTOR) {
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
mapOutputTracker.removeOutputsOnExecutor(execId)
}
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
mapOutputTracker.removeOutputsOnExecutor(execId)
}
clearCacheLocs()

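Condensed, the DAGScheduler's fetch-failure decision after this change looks like the sketch below (simplified; the per-output exemption for externally stored data is applied later, inside MapOutputTrackerMaster):

import org.apache.spark.shuffle.api.ShuffleDriverComponents

// Choose the scope of unregistration: the whole host when the plugin asks for it,
// otherwise just the failing executor.
def unregistrationScope(
    components: ShuffleDriverComponents,
    failedHost: String,
    failedExecId: String): Either[String, String] = {
  if (components.unregisterOutputOnHostOnFetchFailure()) {
    Left(failedHost)
  } else {
    Right(failedExecId)
  }
}
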
ShuffleDataIOUtils.scala (new file)

@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.api.ShuffleDataIO
import org.apache.spark.util.Utils

private[spark] object ShuffleDataIOUtils {

/**
* The prefix of spark config keys that are passed from the driver to the executor.
*/
val SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin."

def loadShuffleDataIO(conf: SparkConf): ShuffleDataIO = {
val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.nonEmpty, s"At least one valid shuffle plugin must be specified by config " +
s"${SHUFFLE_IO_PLUGIN_CLASS.key}, but $configuredPluginClasses resulted in zero valid " +
s"plugins.")
maybeIO.head
}
}
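
A short sketch of how this loader is used (the default plugin class and the exact key come from SHUFFLE_IO_PLUGIN_CLASS; the custom class name below is purely illustrative):

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.ShuffleDataIOUtils

val conf = new SparkConf(false)
// Optionally point the plugin config at a custom ShuffleDataIO implementation:
// conf.set(SHUFFLE_IO_PLUGIN_CLASS.key, "com.example.MyShuffleDataIO")

val shuffleDataIo = ShuffleDataIOUtils.loadShuffleDataIO(conf)
// SparkEnv hands the driver-side half to MapOutputTrackerMaster and the DAGScheduler.
val driverComponents = shuffleDataIo.driver()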