diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java
index 4cb40f6dd00b8..dd7c0ac7320cb 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDataIO.java
@@ -27,5 +27,8 @@
  */
 @Experimental
 public interface ShuffleDataIO {
+  String SHUFFLE_SPARK_CONF_PREFIX = "spark.shuffle.plugin.";
+
+  ShuffleDriverComponents driver();
   ShuffleExecutorComponents executor();
 }
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
new file mode 100644
index 0000000000000..6a0ec8d44fd4f
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+import java.util.Map;
+
+public interface ShuffleDriverComponents {
+
+  /**
+   * @return additional SparkConf values necessary for the executors.
+   */
+  Map<String, String> initializeApplication();
+
+  void cleanupApplication() throws IOException;
+
+  void removeShuffleData(int shuffleId, boolean blocking) throws IOException;
+}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
index 4fc20bad9938b..d6a017bce1878 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
@@ -19,6 +19,8 @@
 
 import org.apache.spark.annotation.Experimental;
 
+import java.util.Map;
+
 /**
  * :: Experimental ::
  * An interface for building shuffle support for Executors
@@ -27,7 +29,7 @@
  */
 @Experimental
 public interface ShuffleExecutorComponents {
-  void initializeExecutor(String appId, String execId);
+  void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs);
 
   ShuffleWriteSupport writes();
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
index 906600c0f15fc..7c124c1fe68bc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
@@ -18,8 +18,10 @@
 package org.apache.spark.shuffle.sort.io;
 
 import org.apache.spark.SparkConf;
+import org.apache.spark.api.shuffle.ShuffleDriverComponents;
 import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
 import org.apache.spark.api.shuffle.ShuffleDataIO;
+import org.apache.spark.shuffle.sort.lifecycle.DefaultShuffleDriverComponents;
 
 public class DefaultShuffleDataIO implements ShuffleDataIO {
 
@@ -33,4 +35,9 @@ public DefaultShuffleDataIO(SparkConf sparkConf) {
   public ShuffleExecutorComponents executor() {
     return new DefaultShuffleExecutorComponents(sparkConf);
   }
+
+  @Override
+  public ShuffleDriverComponents driver() {
+    return new DefaultShuffleDriverComponents();
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
index 76e87a6740259..bb2db97fa9c95 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -24,6 +24,8 @@
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.BlockManager;
 
+import java.util.Map;
+
 public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents {
 
   private final SparkConf sparkConf;
@@ -35,7 +37,7 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
   }
 
   @Override
-  public void initializeExecutor(String appId, String execId) {
+  public void initializeExecutor(String appId, String execId, Map<String, String> extraConfigs) {
     blockManager = SparkEnv.get().blockManager();
     blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
   }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
new file mode 100644
index 0000000000000..a3eddc8ec930e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.lifecycle;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.shuffle.ShuffleDriverComponents;
+import org.apache.spark.storage.BlockManagerMaster;
+
+import java.io.IOException;
+import java.util.Map;
+
+public class DefaultShuffleDriverComponents implements ShuffleDriverComponents {
+
+  private BlockManagerMaster blockManagerMaster;
+
+  @Override
+  public Map<String, String> initializeApplication() {
+    blockManagerMaster = SparkEnv.get().blockManager().master();
+    return ImmutableMap.of();
+  }
+
+  @Override
+  public void cleanupApplication() {
+    // do nothing
+  }
+
+  @Override
+  public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
+    checkInitialized();
+    blockManagerMaster.removeShuffle(shuffleId, blocking);
+  }
+
+  private void checkInitialized() {
+    if (blockManagerMaster == null) {
+      throw new IllegalStateException("Driver components must be initialized before using");
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 305ec46a364a2..fa28e54116d25 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, Scheduled
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.api.shuffle.ShuffleDriverComponents
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -58,7 +59,9 @@ private class CleanupTaskWeakReference(
 * to be processed when the associated object goes out of scope of the application. Actual
 * cleanup is performed in a separate daemon thread.
 */
-private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
+private[spark] class ContextCleaner(
+    sc: SparkContext,
+    shuffleDriverComponents: ShuffleDriverComponents) extends Logging {
 
   /**
    * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
@@ -222,7 +225,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     try {
       logDebug("Cleaning shuffle " + shuffleId)
       mapOutputTrackerMaster.unregisterShuffle(shuffleId)
-      blockManagerMaster.removeShuffle(shuffleId, blocking)
+      shuffleDriverComponents.removeShuffleData(shuffleId, blocking)
       listeners.asScala.foreach(_.shuffleCleaned(shuffleId))
       logInfo("Cleaned shuffle " + shuffleId)
     } catch {
@@ -270,7 +273,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }
 
-  private def blockManagerMaster = sc.env.blockManager.master
   private def broadcastManager = sc.env.broadcastManager
   private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 56c66e0e99db9..999f180193d84 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -43,6 +43,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.conda.CondaEnvironment
 import org.apache.spark.api.conda.CondaEnvironment.CondaSetupInstructions
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{CondaRunner, LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat}
@@ -215,6 +216,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
   private var _shutdownHookRef: AnyRef = _
   private var _statusStore: AppStatusStore = _
   private var _heartbeater: Heartbeater = _
+  private var _shuffleDriverComponents: ShuffleDriverComponents = _
 
   /* ------------------------------------------------------------------------------------- *
    | Accessors and public fields. These provide access to the internal state of the        |
@@ -491,6 +493,14 @@ class SparkContext(config: SparkConf) extends SafeLogging {
     executorEnvs ++= _conf.getExecutorEnv
     executorEnvs("SPARK_USER") = sparkUser
 
+    val configuredPluginClasses = conf.get(SHUFFLE_IO_PLUGIN_CLASS)
+    val maybeIO = Utils.loadExtensions(
+      classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
+    require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
+    _shuffleDriverComponents = maybeIO.head.driver()
+    _shuffleDriverComponents.initializeApplication().asScala.foreach {
+      case (k, v) => _conf.set(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX + k, v) }
+
     // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
     // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
     _heartbeatReceiver = env.rpcEnv.setupEndpoint(
@@ -560,7 +570,7 @@
 
     _cleaner =
       if (_conf.get(CLEANER_REFERENCE_TRACKING)) {
-        Some(new ContextCleaner(this))
+        Some(new ContextCleaner(this, _shuffleDriverComponents))
       } else {
         None
       }
@@ -1950,6 +1960,7 @@ class SparkContext(config: SparkConf) extends SafeLogging {
       }
       _heartbeater = null
     }
+    _shuffleDriverComponents.cleanupApplication()
     if (env != null && _heartbeatReceiver != null) {
       Utils.tryLogNonFatalError {
         env.rpcEnv.stop(_heartbeatReceiver)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 5da7b5cb35e6d..b5cd0fd558825 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.shuffle.sort
 
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark._
 import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents}
 import org.apache.spark.internal.{config, Logging}
@@ -219,7 +221,12 @@ private[spark] object SortShuffleManager extends Logging {
       classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
     require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
     val executorComponents = maybeIO.head.executor()
-    executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId)
+    val extraConfigs = conf.getAllWithPrefix(ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX)
+      .toMap
+    executorComponents.initializeExecutor(
+      conf.getAppId,
+      SparkEnv.get.executorId,
+      extraConfigs.asJava)
     executorComponents
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index 62824a5bec9d1..28cbeeda7a88d 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -210,7 +210,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
   /**
    * A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
    */
-  private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
+  private class SaveAccumContextCleaner(sc: SparkContext) extends
+      ContextCleaner(sc, null) {
     private val accumsRegistered = new ArrayBuffer[Long]
 
     override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
new file mode 100644
index 0000000000000..dbb954945a8b6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.util
+
+import com.google.common.collect.ImmutableMap
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport}
+import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
+
+class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext {
+  test(s"test serialization of shuffle initialization conf to executors") {
+    val testConf = new SparkConf()
+      .setAppName("testing")
+      .setMaster("local-cluster[2,1,1024]")
+      .set(SHUFFLE_IO_PLUGIN_CLASS, "org.apache.spark.shuffle.TestShuffleDataIO")
+
+    sc = new SparkContext(testConf)
+
+    sc.parallelize(Seq((1, "one"), (2, "two"), (3, "three")), 3)
+      .groupByKey()
+      .collect()
+  }
+}
+
+class TestShuffleDriverComponents extends ShuffleDriverComponents {
+  override def initializeApplication(): util.Map[String, String] =
+    ImmutableMap.of("test-key", "test-value")
+
+  override def cleanupApplication(): Unit = {}
+
+  override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
+}
+
+class TestShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
+  override def driver(): ShuffleDriverComponents = new TestShuffleDriverComponents()
+
+  override def executor(): ShuffleExecutorComponents =
+    new TestShuffleExecutorComponents(sparkConf)
+}
+
+class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents {
+  override def initializeExecutor(appId: String, execId: String,
+      extraConfigs: util.Map[String, String]): Unit = {
+    assert(extraConfigs.get("test-key") == "test-value")
+  }
+
+  override def writes(): ShuffleWriteSupport = {
+    val blockManager = SparkEnv.get.blockManager
+    val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager)
+    new DefaultShuffleWriteSupport(sparkConf, blockResolver)
+  }
+}
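
For illustration only, here is a rough sketch (not part of the patch) of what a third-party plugin built against this API could look like, modeled on the TestShuffleDataIO classes in the suite above. The package, class names, and the "storage-endpoint" key are hypothetical, and writes() is left unimplemented rather than wired to a real storage backend; the plugin would be selected the same way the test does, by setting SHUFFLE_IO_PLUGIN_CLASS to the implementation's fully qualified class name.

// Hypothetical example: a plugin whose driver publishes an endpoint that every
// executor reads back through the extraConfigs map.
package com.example.shuffle

import java.util

import com.google.common.collect.ImmutableMap

import org.apache.spark.SparkConf
import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport}

class MyShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO {
  override def driver(): ShuffleDriverComponents = new MyDriverComponents

  override def executor(): ShuffleExecutorComponents = new MyExecutorComponents(sparkConf)
}

class MyDriverComponents extends ShuffleDriverComponents {
  // Whatever is returned here is stored in the driver's SparkConf under
  // ShuffleDataIO.SHUFFLE_SPARK_CONF_PREFIX and handed back to initializeExecutor below.
  override def initializeApplication(): util.Map[String, String] =
    ImmutableMap.of("storage-endpoint", "https://shuffle-store.example.com")

  override def cleanupApplication(): Unit = {}

  override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = {}
}

class MyExecutorComponents(sparkConf: SparkConf) extends ShuffleExecutorComponents {
  private var endpoint: String = _

  override def initializeExecutor(
      appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
    endpoint = extraConfigs.get("storage-endpoint")
  }

  override def writes(): ShuffleWriteSupport = {
    // A real plugin would return a ShuffleWriteSupport that ships shuffle blocks to `endpoint`;
    // the built-in path instead delegates to DefaultShuffleWriteSupport, as the test above shows.
    throw new UnsupportedOperationException("left unimplemented in this sketch")
  }
}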