diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4e3fe00a2e9b..d595cb2a4d48 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -404,6 +404,7 @@ export("as.DataFrame",
        "setCurrentDatabase",
        "spark.lapply",
        "spark.addFile",
+       "spark.addJar",
        "spark.getSparkFilesRootDirectory",
        "spark.getSparkFiles",
        "sql",
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 8349b57a30a9..b4750b06f937 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -319,6 +319,30 @@ spark.addFile <- function(path, recursive = FALSE) {
   invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)), recursive))
 }
 
+#' Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+#'
+#' The \code{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported
+#' filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+#' If \code{addToCurrentClassLoader} is TRUE, the jar is also added to the current thread's class
+#' loader. In general, adding to the current thread's class loader will affect all other
+#' application threads unless they have explicitly changed their class loader.
+#'
+#' @rdname spark.addJar
+#' @param path The path of the jar to be added
+#' @param addToCurrentClassLoader Whether to also add the jar to the class loader of the current
+#'        driver thread.
+#' @export
+#' @examples
+#'\dontrun{
+#' spark.addJar("/path/to/something.jar", TRUE)
+#'}
+#' @note spark.addJar since 2.2.0
+spark.addJar <- function(path, addToCurrentClassLoader = FALSE) {
+  normalizedPath <- suppressWarnings(normalizePath(path))
+  sc <- callJMethod(getSparkContext(), "sc")
+  invisible(callJMethod(sc, "addJar", normalizedPath, addToCurrentClassLoader))
+}
+
 #' Get the root directory that contains files added through spark.addFile.
 #'
 #' @rdname spark.getSparkFilesRootDirectory
diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R
index 710485d56685..90008f103283 100644
--- a/R/pkg/tests/fulltests/test_context.R
+++ b/R/pkg/tests/fulltests/test_context.R
@@ -167,6 +167,19 @@ test_that("spark.lapply should perform simple transforms", {
   sparkR.session.stop()
 })
 
+test_that("add jar should work and allow usage of the jar on the driver node", {
+  sparkR.sparkContext(master = sparkRTestMaster)
+
+  destDir <- file.path(tempdir(), "testjar")
+  jarName <- callJStatic("org.apache.spark.TestUtils", "createDummyJar",
+                         destDir, "sparkrTests", "DummyClassForAddJarTest")
+
+  spark.addJar(jarName, addToCurrentClassLoader = TRUE)
+  testClass <- newJObject("sparkrTests.DummyClassForAddJarTest")
+  expect_true(class(testClass) == "jobj")
+  sparkR.session.stop()
+})
+
 test_that("add and get file to be downloaded with Spark job on every node", {
   sparkR.sparkContext(master = sparkRTestMaster)
   # Test add file.
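For illustration only (not part of the patch): a minimal Scala sketch of the driver-side flow that the SparkR documentation and test above describe, using the two-argument addJar introduced in SparkContext below. It assumes Spark-internal visibility of org.apache.spark.util.MutableURLClassLoader (the new SparkContextSuite test relies on the same), and the jar path and class name are hypothetical placeholders.

// Sketch: make a jar's classes loadable on the driver through the new two-argument addJar.
// Assumes spark-internal access to MutableURLClassLoader; "/path/to/udfs.jar" and
// "com.example.MyUdf" are hypothetical placeholders.
import java.net.URL

import org.apache.spark.SparkContext
import org.apache.spark.util.MutableURLClassLoader

object AddJarDriverSideSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "addJar-sketch")
    // Install a mutable loader as the context class loader so addJar has something to append to.
    val loader = new MutableURLClassLoader(new Array[URL](0),
      Thread.currentThread().getContextClassLoader)
    Thread.currentThread().setContextClassLoader(loader)

    sc.addJar("/path/to/udfs.jar", addToCurrentClassLoader = true)

    // The class can now be resolved on the driver without restarting the application.
    val clazz = Thread.currentThread().getContextClassLoader.loadClass("com.example.MyUdf")
    println(clazz.getName)
    sc.stop()
  }
}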
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b2a26c51d4de..23dc2f4683fd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1801,9 +1801,23 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future.
    * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems),
    *             an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
    */
-  def addJar(path: String) {
+  def addJar(path: String): Unit = {
+    addJar(path, false)
+  }
+
+  /**
+   * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future.
+   * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems),
+   *             an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+   * @param addToCurrentClassLoader if true, also adds the jar to the current thread's
+   *                                class loader. In general, adding to the current thread's
+   *                                class loader will affect all other application threads
+   *                                unless they have explicitly changed their class loader.
+   */
+  @DeveloperApi
+  def addJar(path: String, addToCurrentClassLoader: Boolean): Unit = {
     def addJarFile(file: File): String = {
       try {
         if (!file.exists()) {
@@ -1845,6 +1859,21 @@ class SparkContext(config: SparkConf) extends Logging {
           logInfo(s"Added JAR $path at $key with timestamp $timestamp")
           postEnvironmentUpdate()
         }
+
+        if (addToCurrentClassLoader) {
+          val currentCL = Utils.getContextOrSparkClassLoader
+          currentCL match {
+            case cl: MutableURLClassLoader =>
+              val uri = if (path.contains("\\")) {
+                // For local paths with backslashes on Windows, URI would throw an exception.
+                new File(path).toURI
+              } else {
+                new URI(path)
+              }
+              cl.addURL(uri.toURL)
+            case _ => logWarning(s"Unsupported class loader $currentCL; not adding the jar to it")
+          }
+        }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 3f912dc19151..da9326e65c28 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -168,6 +168,25 @@ private[spark] object TestUtils {
     createCompiledClass(className, destDir, sourceFile, classpathUrls)
   }
 
+  /** Creates a dummy compiled jar for the given package and class name, placed in destDir. */
+  def createDummyJar(destDir: String, packageName: String, className: String): String = {
+    val srcDir = new File(destDir, packageName)
+    srcDir.mkdirs()
+    val excSource = new JavaSourceFromString(new File(srcDir, className).toURI.getPath,
+      s"""package $packageName;
+         |
+         |public class $className implements java.io.Serializable {
+         |  public static String helloWorld(String arg) { return "Hello " + arg; }
+         |  public static int addStuff(int arg1, int arg2) { return arg1 + arg2; }
+         |}
+       """.stripMargin)
+    val excFile = createCompiledClass(className, srcDir, excSource, Seq.empty)
+    val jarFile = new File(destDir,
+      s"$packageName-$className-%s.jar".format(System.currentTimeMillis()))
+    val jarURL = createJar(Seq(excFile), jarFile, directoryPrefix = Some(packageName))
+    jarURL.toString
+  }
+
   /**
    * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
    */
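The core of the change above is the `cl.addURL(uri.toURL)` call. As a minimal sketch of the underlying JDK mechanism (plain URLClassLoader, not Spark API): appending a jar URL to a live class loader makes its classes resolvable from then on, and Spark's MutableURLClassLoader simply makes the protected addURL method callable. The class and jar names below are hypothetical placeholders.

// Sketch of the plain-JDK mechanism behind addToCurrentClassLoader: appending a URL to a
// live URLClassLoader makes previously unknown classes loadable. Names are placeholders.
import java.io.File
import java.net.{URL, URLClassLoader}

import scala.util.Try

object AddUrlMechanismSketch {
  // A stand-in for Spark's MutableURLClassLoader: a URLClassLoader with addURL exposed.
  class AppendableClassLoader(parent: ClassLoader)
    extends URLClassLoader(Array.empty[URL], parent) {
    def append(url: URL): Unit = addURL(url)
  }

  def main(args: Array[String]): Unit = {
    val loader = new AppendableClassLoader(Thread.currentThread().getContextClassLoader)

    // Before the jar is appended, the class cannot be resolved.
    assert(Try(loader.loadClass("com.example.MyUdf")).isFailure)

    // Mirrors `cl.addURL(uri.toURL)` in SparkContext.addJar.
    loader.append(new File("/path/to/udfs.jar").toURI.toURL)
    println(loader.loadClass("com.example.MyUdf").getName)
  }
}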
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 979270a527a6..5213d55c4040 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark
 
 import java.io.File
-import java.net.{MalformedURLException, URI}
+import java.net.{MalformedURLException, URI, URL}
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
@@ -34,7 +34,7 @@ import org.scalatest.concurrent.Eventually
 import org.scalatest.Matchers._
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{MutableURLClassLoader, ThreadUtils, Utils}
 
 class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
@@ -309,6 +309,34 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
     assert(sc.listJars().head.contains(tmpJar.getName))
   }
 
+  Seq("local_mode", "non_local_mode").foreach { schedulingMode =>
+    val tempDir = Utils.createTempDir().toString
+    val master = schedulingMode match {
+      case "local_mode" => "local"
+      case "non_local_mode" => "local-cluster[1,1,1024]"
+    }
+    val packageName = s"scala_$schedulingMode"
+    val className = "DummyClass"
+    val jarURI = TestUtils.createDummyJar(tempDir, packageName, className)
+
+    // Ensure we reset the context class loader once the test has been registered.
+    val originalClassLoader = Thread.currentThread().getContextClassLoader
+    try {
+      // A fresh mutable class loader that the test installs as the context class loader.
+      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+
+      test(s"jar can be added and used driver side in $schedulingMode") {
+        sc = new SparkContext(master, "test")
+        Thread.currentThread().setContextClassLoader(loader)
+        sc.addJar(jarURI, addToCurrentClassLoader = true)
+        val cl = Utils.getContextOrSparkClassLoader
+        cl.loadClass(s"$packageName.$className")
+      }
+    } finally {
+      Thread.currentThread().setContextClassLoader(originalClassLoader)
+    }
+  }
+
   test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") {
     try {
       sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index b089357e7b86..1c2e85075f34 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy
 
 import java.io._
 import java.net.URI
+import java.net.URL
 import java.nio.charset.StandardCharsets
 
 import scala.collection.mutable.ArrayBuffer
@@ -530,26 +531,14 @@ class SparkSubmitSuite
       Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator)
     assert(new File(rScriptDir).exists)
 
+    val tempDir = Utils.createTempDir().toString
     // compile a small jar containing a class that will be called from R code.
-    val tempDir = Utils.createTempDir()
-    val srcDir = new File(tempDir, "sparkrtest")
-    srcDir.mkdirs()
-    val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath,
-      """package sparkrtest;
-        |
-        |public class DummyClass implements java.io.Serializable {
-        |  public static String helloWorld(String arg) { return "Hello " + arg; }
-        |  public static int addStuff(int arg1, int arg2) { return arg1 + arg2; }
-        |}
-      """.stripMargin)
-    val excFile = TestUtils.createCompiledClass("DummyClass", srcDir, excSource, Seq.empty)
-    val jarFile = new File(tempDir, "sparkRTestJar-%s.jar".format(System.currentTimeMillis()))
-    val jarURL = TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("sparkrtest"))
+    val jarURL = TestUtils.createDummyJar(tempDir, "sparkrtest", "DummyClass")
 
     val args = Seq(
       "--name", "testApp",
       "--master", "local",
-      "--jars", jarURL.toString,
+      "--jars", jarURL,
       "--verbose",
       "--conf", "spark.ui.enabled=false",
       rScriptDir)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3be07325f416..554d595b712c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -863,6 +863,21 @@ def addPyFile(self, path):
         import importlib
         importlib.invalidate_caches()
 
+    def addJar(self, path, addToCurrentClassLoader=False):
+        """
+        Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+        The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+        filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+        If `addToCurrentClassLoader` is True, the jar is also added to the current thread's
+        class loader. In general, adding to the current thread's class loader will affect all
+        other application threads unless they have explicitly changed their class loader.
+
+        :param path: The path of the jar to be added
+        :param addToCurrentClassLoader: Whether to also add the jar to the current driver
+            class loader. This defaults to False.
+        """
+        self._jsc.sc().addJar(path, addToCurrentClassLoader)
+
     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed. The
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bb13de563cdd..dfc530865abc 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -35,6 +35,7 @@
 import hashlib
 
 from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaClass
 try:
     import xmlrunner
 except ImportError:
@@ -435,6 +436,19 @@ def test_add_file_locally(self):
         with open(download_path) as test_file:
             self.assertEqual("Hello World!\n", test_file.readline())
 
+    def test_add_jar(self):
+        jvm = self.sc._jvm
+        # We shouldn't be able to load anything from the package before the jar is added.
+        self.assertFalse(isinstance(jvm.pysparktests.DummyClass, JavaClass))
+        # Generate and compile the test jar.
+        destDir = os.path.join(SPARK_HOME, "python/test_support/jar")
+        jarName = jvm.org.apache.spark.TestUtils.createDummyJar(
+            destDir, "pysparktests", "DummyClass")
+        # Add the new jar to the driver.
+        self.sc.addJar(jarName, True)
+        # The class should now be loadable.
+        self.assertTrue(isinstance(jvm.pysparktests.DummyClass, JavaClass))
+
     def test_add_file_recursively_locally(self):
         path = os.path.join(SPARK_HOME, "python/test_support/hello")
         self.sc.addFile(path, True)
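All three API docs in this patch carry the same caveat about other application threads. A small sketch of why (plain JDK behaviour, not Spark API): a thread created afterwards inherits the creating thread's context class loader, so any jar appended to that shared loader instance is visible there too, unless the thread installs its own loader.

// Sketch: child threads inherit the parent's context class loader, so a jar appended to it
// (addToCurrentClassLoader = true) is visible to them unless they set their own loader.
object ClassLoaderSharingSketch {
  def main(args: Array[String]): Unit = {
    val driverLoader = Thread.currentThread().getContextClassLoader

    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        // Same loader instance as the creating thread, hence the caveat in the docs above.
        println(Thread.currentThread().getContextClassLoader eq driverLoader) // prints: true
      }
    })
    worker.start()
    worker.join()
    // A thread that calls Thread.currentThread().setContextClassLoader(...) opts out.
  }
}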