diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 3fc756b9ef40..20aab2330c62 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -409,6 +409,7 @@ export("as.DataFrame",
        "setCurrentDatabase",
        "spark.lapply",
        "spark.addFile",
+       "spark.addJar",
        "spark.getSparkFilesRootDirectory",
        "spark.getSparkFiles",
        "sql",
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 443c2ff8f9ac..9b9c2a1480a9 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -319,6 +319,32 @@ spark.addFile <- function(path, recursive = FALSE) {
   invisible(callJMethod(sc, "addFile", suppressWarnings(normalizePath(path)), recursive))
 }
 
+#' Adds a JAR dependency for Spark tasks to be executed in the future.
+#'
+#' The \code{path} passed can be either a local file, a file in HDFS (or other Hadoop-supported
+#' filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+#' If \code{addToCurrentClassLoader} is true, the jar is also added to the class loader of the
+#' current thread in the backing JVM. In general, adding to the current thread's class loader
+#' affects all other application threads unless they have explicitly changed their class loader.
+#'
+#' Note: the \code{addToCurrentClassLoader} parameter is a developer API and may change or be
+#' removed in minor versions of Spark.
+#'
+#' @rdname spark.addJar
+#' @param path The path of the jar to be added
+#' @param addToCurrentClassLoader Whether to add the jar to the current driver class loader.
+#' @export
+#' @examples
+#'\dontrun{
+#' spark.addJar("/path/to/something.jar", TRUE)
+#'}
+#' @note spark.addJar since 2.3.0
+spark.addJar <- function(path, addToCurrentClassLoader = FALSE) {
+  normalizedPath <- suppressWarnings(normalizePath(path))
+  sc <- callJMethod(getSparkContext(), "sc")
+  invisible(callJMethod(sc, "addJar", normalizedPath, addToCurrentClassLoader))
+}
+
 #' Get the root directory that contains files added through spark.addFile.
 #'
 #' @rdname spark.getSparkFilesRootDirectory
diff --git a/R/pkg/tests/fulltests/test_context.R b/R/pkg/tests/fulltests/test_context.R
index 77635c5a256b..9742296adb0a 100644
--- a/R/pkg/tests/fulltests/test_context.R
+++ b/R/pkg/tests/fulltests/test_context.R
@@ -168,6 +168,20 @@ test_that("spark.lapply should perform simple transforms", {
   sparkR.session.stop()
 })
 
+test_that("add jar should work and allow usage of the jar on the driver node", {
+  sparkR.sparkContext()
+
+  destDir <- file.path(tempdir(), "testjar")
+  jarFile <- callJStatic("org.apache.spark.TestUtils", "createDummyJar",
+                         destDir, "sparkrTests", "DummyClassForAddJarTest")
+  jarPath <- callJMethod(jarFile, "getAbsolutePath")
+
+  spark.addJar(jarPath, addToCurrentClassLoader = TRUE)
+  testClass <- newJObject("sparkrTests.DummyClassForAddJarTest")
+  expect_true(class(testClass) == "jobj")
+  unlink(destDir)
+})
+
 test_that("add and get file to be downloaded with Spark job on every node", {
   sparkR.sparkContext(master = sparkRTestMaster)
   # Test add file.
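For readers tracing the R wrapper: `callJMethod(getSparkContext(), "sc")` unwraps the Scala SparkContext from the JavaSparkContext, and the second `callJMethod` invokes the two-argument `addJar` overload introduced in the core change below. A rough Scala sketch of that JVM-side chain (the helper name is illustrative, not part of the patch):

import org.apache.spark.api.java.JavaSparkContext

// Illustrative helper: what spark.addJar() asks the backing JVM to do.
def addJarFromR(jsc: JavaSparkContext, normalizedPath: String,
                addToCurrentClassLoader: Boolean): Unit = {
  val sc = jsc.sc                                       // callJMethod(getSparkContext(), "sc")
  sc.addJar(normalizedPath, addToCurrentClassLoader)    // callJMethod(sc, "addJar", ...)
}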
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6f25d346e6e5..b6257c8a7a6b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1802,7 +1802,21 @@ class SparkContext(config: SparkConf) extends Logging {
    * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems),
    * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
    */
-  def addJar(path: String) {
+  def addJar(path: String): Unit = {
+    addJar(path, addToCurrentClassLoader = false)
+  }
+
+  /**
+   * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future.
+   *
+   * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems),
+   * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+   * @param addToCurrentClassLoader if true, also adds the jar to the current thread's class
+   * loader. In general, adding to the current thread's class loader affects all other
+   * application threads unless they have explicitly changed their class loader.
+   */
+  @DeveloperApi
+  def addJar(path: String, addToCurrentClassLoader: Boolean): Unit = {
     def addJarFile(file: File): String = {
       try {
         if (!file.exists()) {
@@ -1838,12 +1852,21 @@ class SparkContext(config: SparkConf) extends Logging {
           case _ => path
         }
       }
+
       if (key != null) {
         val timestamp = System.currentTimeMillis
         if (addedJars.putIfAbsent(key, timestamp).isEmpty) {
           logInfo(s"Added JAR $path at $key with timestamp $timestamp")
           postEnvironmentUpdate()
         }
+
+        if (addToCurrentClassLoader) {
+          Utils.getContextOrSparkClassLoader match {
+            case cl: MutableURLClassLoader => cl.addURL(Utils.resolveURI(path).toURL)
+            case cl => logWarning(
+              s"Unsupported class loader $cl; the jar was not added to the thread's class loader.")
+          }
+        }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index a80016dd22fc..867baab5a0f0 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -168,6 +168,27 @@ private[spark] object TestUtils {
     createCompiledClass(className, destDir, sourceFile, classpathUrls)
   }
 
+  /** Creates a dummy compiled jar for the given package and class name, placed in destDir. */
+  def createDummyJar(destDir: String, packageName: String, className: String): File = {
+    val srcDir = new File(destDir, packageName)
+    srcDir.mkdirs()
+    val excSource = new JavaSourceFromString(
+      new File(srcDir, className).toURI.getPath,
+      s"""
+        |package $packageName;
+        |
+        |public class $className implements java.io.Serializable {
+        |  public static String helloWorld(String arg) { return "Hello " + arg; }
+        |  public static int addStuff(int arg1, int arg2) { return arg1 + arg2; }
+        |}
+      """.stripMargin)
+    val excFile = createCompiledClass(className, srcDir, excSource, Seq.empty)
+    val jarFile = new File(destDir,
+      s"$packageName-$className-%s.jar".format(System.currentTimeMillis()))
+    createJar(Seq(excFile), jarFile, directoryPrefix = Some(packageName))
+    jarFile
+  }
+
   /**
    * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
    */
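The driver-side half of `addToCurrentClassLoader` hinges on the thread's context class loader being a `MutableURLClassLoader`, Spark's `URLClassLoader` subclass with a public `addURL`. A minimal, self-contained sketch of that mechanism, using an illustrative stand-in class rather than Spark's:

import java.net.{URL, URLClassLoader}

// Stand-in for org.apache.spark.util.MutableURLClassLoader: a URLClassLoader whose
// addURL is public, so jar URLs can be appended after the loader has been created.
class SimpleMutableClassLoader(urls: Array[URL], parent: ClassLoader)
  extends URLClassLoader(urls, parent) {
  override def addURL(url: URL): Unit = super.addURL(url)
}

// Roughly what addJar(path, addToCurrentClassLoader = true) does on the driver:
// append the jar's URL when the loader is mutable, otherwise warn and skip.
def addJarToContextLoader(jarUrl: URL): Unit = {
  Thread.currentThread().getContextClassLoader match {
    case cl: SimpleMutableClassLoader => cl.addURL(jarUrl)
    case cl => println(s"Unsupported class loader $cl; $jarUrl stays invisible on this thread.")
  }
}

Once the URL is appended, any lookup that goes through the context class loader can resolve classes from the new jar without restarting the JVM.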
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 0ed5f26863da..7339049920a9 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark
 
 import java.io.File
-import java.net.{MalformedURLException, URI}
+import java.net.{MalformedURLException, URI, URL}
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
@@ -34,7 +34,7 @@ import org.scalatest.Matchers._
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.{MutableURLClassLoader, ThreadUtils, Utils}
 
 class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
 
@@ -309,6 +309,34 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
     assert(sc.listJars().head.contains(tmpJar.getName))
   }
 
+  Seq("local_mode", "non_local_mode").foreach { schedulingMode =>
+    val tempDir = Utils.createTempDir().toString
+    val master = schedulingMode match {
+      case "local_mode" => "local"
+      case "non_local_mode" => "local-cluster[1,1,1024]"
+    }
+    val packageName = s"scala_$schedulingMode"
+    val className = "DummyClass"
+    val jarURI = TestUtils.createDummyJar(tempDir, packageName, className).toURI
+
+    // ensure we reset the class loader after the test completes
+    val originalClassLoader = Thread.currentThread.getContextClassLoader
+    try {
+      // class loader that the added jar's URL should become visible through
+      val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader)
+
+      test(s"jar can be added and used driver side in $schedulingMode") {
+        sc = new SparkContext(master, "test")
+        Thread.currentThread().setContextClassLoader(loader)
+        sc.addJar(jarURI.toString, addToCurrentClassLoader = true)
+        val cl = Utils.getContextOrSparkClassLoader
+        cl.loadClass(s"$packageName.$className")
+      }
+    } finally {
+      Thread.currentThread.setContextClassLoader(originalClassLoader)
+    }
+  }
+
   test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") {
     try {
       sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
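Because the context class loader is global to the thread, the suite above saves it, swaps in a fresh `MutableURLClassLoader`, and restores the original in a `finally` block. The same save/swap/restore discipline can be distilled into a small reusable helper; this is only a sketch of the pattern, not code from the patch:

import java.net.{URL, URLClassLoader}

// Run a block with a fresh child class loader installed as the thread's context
// class loader, and always restore the original so later code is unaffected.
def withFreshContextClassLoader[T](body: => T): T = {
  val original = Thread.currentThread().getContextClassLoader
  val fresh = new URLClassLoader(Array.empty[URL], original)
  Thread.currentThread().setContextClassLoader(fresh)
  try body
  finally Thread.currentThread().setContextClassLoader(original)
}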
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index cfbf56fb8c36..e02dfb578fdc 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy
 
 import java.io._
 import java.net.URI
+import java.net.URL
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
 
@@ -549,21 +550,9 @@ class SparkSubmitSuite
       Seq(sparkHome, "R", "pkg", "tests", "fulltests", "jarTest.R").mkString(File.separator)
     assert(new File(rScriptDir).exists)
 
+    val tempDir = Utils.createTempDir().getAbsolutePath
     // compile a small jar containing a class that will be called from R code.
-    val tempDir = Utils.createTempDir()
-    val srcDir = new File(tempDir, "sparkrtest")
-    srcDir.mkdirs()
-    val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath,
-      """package sparkrtest;
-        |
-        |public class DummyClass implements java.io.Serializable {
-        |  public static String helloWorld(String arg) { return "Hello " + arg; }
-        |  public static int addStuff(int arg1, int arg2) { return arg1 + arg2; }
-        |}
-      """.stripMargin)
-    val excFile = TestUtils.createCompiledClass("DummyClass", srcDir, excSource, Seq.empty)
-    val jarFile = new File(tempDir, "sparkRTestJar-%s.jar".format(System.currentTimeMillis()))
-    val jarURL = TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("sparkrtest"))
+    val jarURL = TestUtils.createDummyJar(tempDir, "sparkrtest", "DummyClass").toURI.toURL
 
     val args = Seq(
       "--name", "testApp",
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a33f6dcf31fc..55c601b98202 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -860,6 +860,23 @@ def addPyFile(self, path):
         import importlib
         importlib.invalidate_caches()
 
+    def addJar(self, path, addToCurrentClassLoader=False):
+        """
+        Adds a JAR dependency for Spark tasks to be executed in the future.
+        The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+        filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+        If `addToCurrentClassLoader` is true, the jar is also added to the class loader of the
+        current thread in the backing JVM. In general, adding to the current thread's class
+        loader affects all other application threads unless they have explicitly changed it.
+
+        .. note:: The `addToCurrentClassLoader` parameter is a developer API and may change or be
+            removed in minor versions of Spark.
+
+        :param path: The path of the jar to be added
+        :param addToCurrentClassLoader: Whether to add the jar to the current driver class loader.
+        """
+        self._jsc.sc().addJar(path, addToCurrentClassLoader)
+
     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed. The
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index da99872da2f0..c7eb8b899233 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -35,6 +35,7 @@ import hashlib
 
 from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaClass
 try:
     import xmlrunner
 except ImportError:
@@ -435,6 +436,22 @@ def test_add_file_locally(self):
         with open(download_path) as test_file:
             self.assertEqual("Hello World!\n", test_file.readline())
 
+    def test_add_jar(self):
+        jvm = self.sc._jvm
+        # We shouldn't be able to load anything from the package before the jar is added
+        self.assertFalse(isinstance(jvm.pysparktests.DummyClass, JavaClass))
+        destDir = tempfile.mkdtemp()
+        try:
+            # Generate and compile the test jar
+            jarPath = jvm.org.apache.spark.TestUtils.createDummyJar(
+                destDir, "pysparktests", "DummyClass").getAbsolutePath()
+            # Load the new jar
+            self.sc.addJar(jarPath, True)
+            # Try to load the class
+            self.assertTrue(isinstance(jvm.pysparktests.DummyClass, JavaClass))
+        finally:
+            shutil.rmtree(destDir)
+
     def test_add_file_recursively_locally(self):
         path = os.path.join(SPARK_HOME, "python/test_support/hello")
         self.sc.addFile(path, True)
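Taken together, the new overload lets driver-side code load classes from a jar registered at runtime, which is what the SparkR and PySpark wrappers rely on. A hedged end-to-end usage sketch in Scala (the jar path and class name are invented for illustration, and it assumes the thread's context class loader is a mutable one, as in the tests above):

import org.apache.spark.SparkContext

// Illustrative only: register a jar for executors and for the driver's current thread,
// then load one of its classes reflectively on the driver.
def loadDriverSideHelper(sc: SparkContext): Any = {
  sc.addJar("/tmp/my-udfs.jar", addToCurrentClassLoader = true)
  val cls = Thread.currentThread().getContextClassLoader.loadClass("com.example.MyHelper")
  cls.getDeclaredConstructor().newInstance()
}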