diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index bbdc9158d8e2b..f1033ecdf6fab 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1624,8 +1624,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
    * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
    * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
    * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+   * If addToCurrentClassLoader is true, attempt to add the JAR's classes to the current
+   * thread's class loader. In general, adding to the current thread's class loader affects all
+   * other application threads unless they have explicitly changed their class loader.
    */
   def addJar(path: String) {
+    addJar(path, false)
+  }
+
+  def addJar(path: String, addToCurrentClassLoader: Boolean) {
     if (path == null) {
       logWarning("null specified as parameter to addJar")
     } else {
@@ -1680,6 +1687,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       if (key != null) {
         addedJars(key) = System.currentTimeMillis
         logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+        if (addToCurrentClassLoader) {
+          val currentCL = Utils.getContextOrSparkClassLoader
+          currentCL match {
+            case cl: MutableURLClassLoader => cl.addURL(new URI(key).toURL())
+            case _ => logWarning(s"Unsupported class loader $currentCL; not adding jar to it")
+          }
+        }
       }
     }
     postEnvironmentUpdate()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 529d16b480399..7238331daa3ef 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -810,6 +810,26 @@ def addPyFile(self, path):
         import importlib
         importlib.invalidate_caches()
 
+    def addJar(self, path, addToCurrentClassLoader=False):
+        """
+        Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+        The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+        filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
+        If addToCurrentClassLoader is true, attempt to add the JAR's classes to the current
+        thread's class loader. In general, adding to the current thread's class loader affects
+        all other application threads unless they have explicitly changed their class loader.
+        """
+        self._jsc.sc().addJar(path, addToCurrentClassLoader)
+
+    def _loadClass(self, className):
+        """
+        .. note:: Experimental
+
+        Loads a JVM class using the MutableURLClassLoader used by Spark.
+        This function exists because Py4J uses a different class loader.
+        """
+        return self._jvm.org.apache.spark.util.Utils.getContextOrSparkClassLoader().loadClass(className)
+
     def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed.
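For reviewers, a minimal PySpark usage sketch of the API added above; the jar path and class names are hypothetical placeholders, not part of this patch:

    # Hypothetical usage sketch; /tmp/example-helpers.jar and the class name
    # below are placeholders, not part of this patch.
    from pyspark import SparkContext

    sc = SparkContext(appName="addJarExample")

    # Ship the jar to executors and, with the second argument set to True,
    # also add it to the current thread's class loader on the driver (when
    # that loader is a MutableURLClassLoader).
    sc.addJar("/tmp/example-helpers.jar", True)

    # Py4J loads classes through its own class loader, so use the helper
    # that goes through Spark's context class loader instead.
    sc._loadClass("com.example.helpers.HelperFunction")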
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 5bd94476597ab..36336ff18eb70 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -410,6 +410,17 @@ def test_add_file_locally(self):
         with open(download_path) as test_file:
             self.assertEqual("Hello World!\n", test_file.readline())
 
+    def test_add_jar(self):
+        # We shouldn't be able to load anything from the jar before it is added
+        self.assertRaises(Exception,
+                          lambda: self.sc._loadClass("sparkR.test.hello"))
+        # Load the new jar
+        path = os.path.join(SPARK_HOME, "R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar")
+        self.sc.addJar(path, True)
+        self.assertTrue(self.sc._jsc.sc().addedJars().toString().find("sparktestjar") != -1)
+        # Loading a different class from the newly added jar should now succeed
+        self.sc._loadClass("sparkR.test.basicFunction")
+
     def test_add_py_file_locally(self):
         # To ensure that we're actually testing addPyFile's effects, check that
         # this fails due to `userlibrary` not being on the Python path:
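As a quick sanity check outside the test suite, the same registration can be verified from a PySpark shell; the jar name here is a placeholder:

    # Hypothetical shell check mirroring test_add_jar's assertion. addedJars
    # on the Scala SparkContext maps each jar's resolved URI to the timestamp
    # at which it was added, so its string form should mention the jar.
    sc.addJar("/tmp/example-helpers.jar", True)
    assert "example-helpers" in sc._jsc.sc().addedJars().toString()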