Commit

Add non-local test.
Marcelo Vanzin committed Oct 23, 2019
1 parent ccb950e · commit e38a8da
Showing 1 changed file with 51 additions and 0 deletions.
@@ -17,12 +17,15 @@

package org.apache.spark.internal.plugin

import java.io.File
import java.nio.charset.StandardCharsets
import java.util.{Map => JMap}

import scala.collection.JavaConverters._
import scala.concurrent.duration._

import com.codahale.metrics.Gauge
import com.google.common.io.Files
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{mock, spy, verify, when}
import org.scalatest.BeforeAndAfterEach
@@ -32,6 +35,7 @@ import org.apache.spark.{ExecutorPlugin => _, _}
import org.apache.spark.api.plugin._
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.util.Utils

class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with LocalSparkContext {

@@ -117,6 +121,53 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
assert(TestSparkPlugin.driverPlugin != null)
}

test("plugin initialization in non-local mode") {
val path = Utils.createTempDir()

val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
.set(PLUGINS, Seq(classOf[NonLocalModeSparkPlugin].getName()))
.set(NonLocalModeSparkPlugin.TEST_PATH_CONF, path.getAbsolutePath())

sc = new SparkContext(conf)
TestUtils.waitUntilExecutorsUp(sc, 2, 10000)

eventually(timeout(10.seconds), interval(100.millis)) {
val children = path.listFiles()
assert(children != null)
assert(children.length >= 3)
}
}
}

class NonLocalModeSparkPlugin extends SparkPlugin {

override def driverPlugin(): DriverPlugin = {
new DriverPlugin() {
override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID())
Map.empty.asJava
}
}
}

override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin() {
override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID())
}
}
}
}

object NonLocalModeSparkPlugin {
val TEST_PATH_CONF = "spark.nonLocalPlugin.path"

def writeFile(conf: SparkConf, id: String): Unit = {
val path = conf.get(TEST_PATH_CONF)
Files.write(id, new File(path, id), StandardCharsets.UTF_8)
}
}

class TestSparkPlugin extends SparkPlugin {
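For readers skimming the diff: the new test exercises the plugin API end to end on a local-cluster master. As a rough sketch (the class and config names below are illustrative and not part of this commit), a user-facing plugin built on the same SparkPlugin / DriverPlugin / ExecutorPlugin interfaces would look roughly like this:

import java.util.{Collections, Map => JMap}

import org.apache.spark.SparkContext
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}

// Hypothetical example class; not part of the commit above.
class MyMetricsPlugin extends SparkPlugin {

  override def driverPlugin(): DriverPlugin = new DriverPlugin() {
    override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = {
      // Runs once on the driver; the returned map is handed to each executor-side init().
      Collections.singletonMap("example.flag", "true")
    }
  }

  override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin() {
    override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
      // Runs once per executor, with the extra configuration produced by the driver side.
    }
  }
}

Such a plugin would be registered with, for example, --conf spark.plugins=com.example.MyMetricsPlugin, which is the same mechanism the test drives through the PLUGINS config entry.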
