Commit ab9131d

[SPARK-25299] Driver lifecycle api (#533)

Introduce driver shuffle lifecycle APIs

yifeih authored and bulldozer-bot[bot] committed May 7, 2019
1 parent 3a760e7 commit ab9131d
Showing 11 changed files with 201 additions and 8 deletions.
@@ -27,5 +27,8 @@
*/
@Experimental
public interface ShuffleDataIO {
String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin.";

ShuffleDriverComponents driver();
ShuffleExecutorComponents executor();
}
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.shuffle;

import java.io.IOException;
import java.util.Map;

public interface ShuffleDriverComponents {

/**
* @return additional SparkConf values necessary for the executors.
*/
Map<String, String> initializeApplication();

void cleanupApplication() throws IOException;

void removeShuffleData(int shuffleId, boolean blocking) throws IOException;
}
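
For orientation, here is a minimal sketch of what a third-party plugin might look like against these new interfaces. The package, class names, config key, and endpoint value below are hypothetical and are not part of this commit; the executor-side components are deliberately left out of the sketch.

package com.example.shuffle;  // hypothetical plugin package, illustration only

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;

// Entry point: Spark loads this class reflectively (see the Utils.loadExtensions calls
// in the SparkContext and SortShuffleManager hunks below) and asks it for the driver-
// and executor-side components.
public class ExampleShuffleDataIO implements ShuffleDataIO {
  private final SparkConf sparkConf;  // kept for parity with DefaultShuffleDataIO

  public ExampleShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleDriverComponents driver() {
    return new ExampleShuffleDriverComponents();
  }

  @Override
  public ShuffleExecutorComponents executor() {
    // Executor-side components are out of scope for this sketch.
    throw new UnsupportedOperationException("not part of this sketch");
  }
}

// Driver-side lifecycle: publish extra configuration for executors, remove per-shuffle
// data when asked, and release resources when the application stops.
class ExampleShuffleDriverComponents implements ShuffleDriverComponents {

  @Override
  public Map<String, String> initializeApplication() {
    // Entries returned here are written into the SparkConf under the
    // "spark.shuffle.plugin." prefix and handed back to the executors.
    return Collections.singletonMap("endpoint", "shuffle-store.example.com:7337");
  }

  @Override
  public void cleanupApplication() throws IOException {
    // Called from SparkContext.stop(); release application-wide resources here.
  }

  @Override
  public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
    // Called from ContextCleaner when a shuffle goes out of scope on the driver.
  }
}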
@@ -19,6 +19,8 @@

import org.apache.spark.annotation.Experimental;

import java.util.Map;

/**
* :: Experimental ::
* An interface for building shuffle support for Executors
@@ -27,7 +29,7 @@
*/
@Experimental
public interface ShuffleExecutorComponents {
void initializeExecutor(String appId, String execId);
void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs);

ShuffleWriteSupport writes();

@@ -18,8 +18,10 @@
package org.apache.spark.shuffle.sort.io;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents;

public class DefaultShuffleDataIO implements ShuffleDataIO {

@@ -33,4 +35,9 @@ public DefaultShuffleDataIO(SparkConf sparkConf) {
public ShuffleExecutorComponents executor() {
return new DefaultShuffleExecutorComponents(sparkConf);
}

@Override
public ShuffleDriverComponents driver() {
return new DefaultShuffleDriverComponents();
}
}
@@ -28,6 +28,8 @@
import org.apache.spark.shuffle.io.DefaultShuffleReadSupport;
import org.apache.spark.storage.BlockManager;

import java.util.Map;

public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents {

private final SparkConf sparkConf;
@@ -41,7 +43,7 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
}

@Override
public void initializeExecutor(String appId, String execId) {
public void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs) {
blockManager = SparkEnv.get().blockManager();
mapOutputTracker = SparkEnv.get().mapOutputTracker();
serializerManager = SparkEnv.get().serializerManager();
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.sort.lifecycle;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.SparkEnv;
import org.apache.spark.api.shuffle.ShuffleDriverComponents;
import org.apache.spark.storage.BlockManagerMaster;

import java.io.IOException;
import java.util.Map;

public class DefaultShuffleDriverComponents implements ShuffleDriverComponents {

private BlockManagerMaster blockManagerMaster;

@Override
public Map<String, String> initializeApplication() {
blockManagerMaster = SparkEnv.get().blockManager().master();
return ImmutableMap.of();
}

@Override
public void cleanupApplication() {
// do nothing
}

@Override
public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
checkInitialized();
blockManagerMaster.removeShuffle(shuffleId, blocking);
}

private void checkInitialized() {
if (blockManagerMaster == null) {
throw new IllegalStateException("Driver components must be initialized before using");
}
}
}
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled

import scala.collection.JavaConverters._

import org.apache.spark.api.shuffle.ShuffleDriverComponents
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -58,7 +59,9 @@ private class CleanupTaskWeakReference(
* to be processed when the associated object goes out of scope of the application. Actual
* cleanup is performed in a separate daemon thread.
*/
private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
private[spark] class ContextCleaner(
sc: SparkContext,
shuffleDriverComponents: ShuffleDriverComponents) extends Logging {

/**
* A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
@@ -222,7 +225,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
try {
logDebug("Cleaning shuffle " + shuffleId)
mapOutputTrackerMaster.unregisterShuffle(shuffleId)
blockManagerMaster.removeShuffle(shuffleId, blocking)
shuffleDriverComponents.removeShuffleData(shuffleId, blocking)
listeners.asScala.foreach(_.shuffleCleaned(shuffleId))
logInfo("Cleaned shuffle " + shuffleId)
} catch {
@@ -270,7 +273,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
}
13 changes: 12 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -43,6 +43,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.conda.CondaEnvironment
import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
@@ -215,6 +216,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
private var _shutdownHookRef: AnyRef = _
private var _statusStore: AppStatusStore = _
private var _heartbeater: Heartbeater = _
private var _shuffleDriverComponents: ShuffleDriverComponents = _

/* ------------------------------------------------------------------------------------- *
| Accessors and public fields. These provide access to the internal state of the |
@@ -491,6 +493,14 @@ class SparkContext(config: SparkConf) extends SafeLogging {
executorEnvs ++= _conf.getExecutorEnv
executorEnvs("SPARK_USER") = sparkUser

val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
val maybeIO = Utils.loadExtensions(
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
_shuffleDriverComponents = maybeIO.head.driver()
_shuffleDriverComponents.initializeApplication().asScala.foreach {
case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }

// We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
// retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
_heartbeatReceiver = env.rpcEnv.setupEndpoint(
@@ -560,7 +570,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {

_cleaner =
if (_conf.get(CLEANER_REFERENCE_TRACKING)) {
Some(new ContextCleaner(this))
Some(new ContextCleaner(this, _shuffleDriverComponents))
} else {
None
}
@@ -1950,6 +1960,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
}
_heartbeater = null
}
_shuffleDriverComponents.cleanupApplication()
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
env.rpcEnv.stop(_heartbeatReceiver)
@@ -19,6 +19,8 @@ package org.apache.spark.shuffle.sort

import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents}
import org.apache.spark.internal.{config, Logging}
@@ -225,7 +227,12 @@ private[spark] object SortShuffleManager extends Logging {
classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
val executorComponents = maybeIO.head.executor()
executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId)
val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX)
.toMap
executorComponents.initializeExecutor(
conf.getAppId,
SparkEnv.get.executorId,
extraConfigs.asJava)
executorComponents
}
}
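
Taken together with the SparkContext hunk above, this completes the config round trip: the map returned by the driver's initializeApplication() is written into the SparkConf under ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX, and SortShuffleManager strips that prefix again with getAllWithPrefix before passing the entries to initializeExecutor() on each executor. A small illustrative sketch of that mechanism in isolation; the package, class name, key, and value are hypothetical:

package com.example.shuffle;  // hypothetical package, illustration only

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDataIO;

// Illustrative only: how a driver-provided entry travels to the executors.
public final class PluginConfRoundTrip {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf();

    // Driver side: SparkContext prefixes every entry returned by
    // ShuffleDriverComponents.initializeApplication().
    conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + "endpoint",
        "shuffle-store.example.com:7337");

    // Executor side: SortShuffleManager strips the prefix before calling
    // ShuffleExecutorComponents.initializeExecutor(appId, execId, extraConfigs).
    for (Tuple2<String, String> entry :
        conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX)) {
      System.out.println(entry._1() + " -> " + entry._2()); // prints: endpoint -> ...
    }
  }
}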
@@ -210,7 +210,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
/**
* A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
*/
private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
private class SaveAccumContextCleaner(sc: SparkContext) extends
ContextCleaner(sc, null) {
private val accumsRegistered = new ArrayBuffer[Long]

override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = {
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import java.util

import com.google.common.collect.ImmutableMap

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport}
import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport

class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext {
test(s"test serialization of shuffle initialization conf to executors") {
val testConf = new SparkConf()
.setAppName("testing")
.setMaster("local-cluster[2,1,1024]")
.set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO")

sc = new SparkContext(testConf)

sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3)
.groupByKey()
.collect()
}
}

class TestShuffleDriverComponents extends ShuffleDriverComponents {
override def initializeApplication(): util.Map[String, String] =
ImmutableMap.of("test-key", "test-value")

override def cleanupApplication(): Unit = {}

override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
}

class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents()

override def executor(): ShuffleExecutorComponents =
new TestShuffleExecutorComponents(sparkConf)
}

class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents {
override def initializeExecutor(appId: String, execId: String,
extraConfigs: util.Map[String, String]): Unit = {
assert(extraConfigs.get("test-key") == "test-value")
}

override def writes(): ShuffleWriteSupport = {
val blockManager = SparkEnv.get.blockManager
val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager)
new DefaultShuffleWriteSupport(sparkConf, blockResolver)
}
}
