diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 2761f0da5f772..c448eee5fca23 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -82,8 +82,10 @@ class SparkContext(config: SparkConf) extends Logging {
   // The call site where this SparkContext was constructed.
   private val creationSite: CallSite = Utils.getCallSite()
 
-  // In order to prevent SparkContext from being created in executors.
-  SparkContext.assertOnDriver()
+  if (!config.get(ALLOW_SPARK_CONTEXT_IN_EXECUTORS)) {
+    // In order to prevent SparkContext from being created in executors.
+    SparkContext.assertOnDriver()
+  }
 
   // In order to prevent multiple SparkContexts from being active at the same time, mark this
   // context as having started construction.
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index bfdd06021757a..38eb90c57ef68 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1814,4 +1814,10 @@ package object config {
       .bytesConf(ByteUnit.BYTE)
       .createOptional
 
+  private[spark] val ALLOW_SPARK_CONTEXT_IN_EXECUTORS =
+    ConfigBuilder("spark.driver.allowSparkContextInExecutors")
+      .doc("If set to true, SparkContext can be created in executors.")
+      .version("3.0.1")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 2b1e110a39466..5533f42859bd1 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -951,17 +951,26 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
     }
   }
 
-  test("SPARK-32160: Disallow to create SparkContext in executors") {
+  test("SPARK-32160: Disallow to create SparkContext in executors if the config is set") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
 
     val error = intercept[SparkException] {
       sc.range(0, 1).foreach { _ =>
-        new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+        new SparkContext(new SparkConf().setAppName("test").setMaster("local")
+          .set(ALLOW_SPARK_CONTEXT_IN_EXECUTORS, false))
       }
     }.getMessage()
 
     assert(error.contains("SparkContext should only be created and accessed on the driver."))
   }
+
+  test("SPARK-32160: Allow to create SparkContext in executors") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local-cluster[3, 1, 1024]"))
+
+    sc.range(0, 1).foreach { _ =>
+      new SparkContext(new SparkConf().setAppName("test").setMaster("local")).stop()
+    }
+  }
 }
 
 object SparkContextSuite {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ecd171a0b8b41..3c447a10b7058 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -119,8 +119,10 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         ...
         ValueError:...
         """
-        # In order to prevent SparkContext from being created in executors.
-        SparkContext._assert_on_driver()
+        if (conf is not None and
+                conf.get("spark.driver.allowSparkContextInExecutors", "true").lower() != "true"):
+            # In order to prevent SparkContext from being created in executors.
+            SparkContext._assert_on_driver()
 
         self._callsite = first_spark_call() or CallSite(None, None, None)
         if gateway is not None and gateway.gateway_parameters.auth_token is None:
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 303635ddab12c..9468b2511ce03 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -268,13 +268,29 @@ def test_resources(self):
         self.assertEqual(len(resources), 0)
 
     def test_disallow_to_create_spark_context_in_executors(self):
-        # SPARK-32160: SparkContext should not be created in executors.
+        # SPARK-32160: SparkContext should not be created in executors if the config is set.
+
+        def create_spark_context():
+            conf = SparkConf().set("spark.driver.allowSparkContextInExecutors", "false")
+            with SparkContext(conf=conf):
+                pass
+
         with SparkContext("local-cluster[3, 1, 1024]") as sc:
             with self.assertRaises(Exception) as context:
-                sc.range(2).foreach(lambda _: SparkContext())
+                sc.range(2).foreach(lambda _: create_spark_context())
             self.assertIn("SparkContext should only be created and accessed on the driver.",
                           str(context.exception))
 
+    def test_allow_to_create_spark_context_in_executors(self):
+        # SPARK-32160: SparkContext can be created in executors.
+
+        def create_spark_context():
+            with SparkContext():
+                pass
+
+        with SparkContext("local-cluster[3, 1, 1024]") as sc:
+            sc.range(2).foreach(lambda _: create_spark_context())
+
 
 class ContextTestsWithResources(unittest.TestCase):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index e5d8710bfb69f..6c809c8592522 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -29,6 +29,7 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ALLOW_SPARK_CONTEXT_IN_EXECUTORS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
@@ -900,7 +901,13 @@ object SparkSession extends Logging {
     * @since 2.0.0
     */
    def getOrCreate(): SparkSession = synchronized {
-      assertOnDriver()
+      val sparkConf = new SparkConf()
+      options.foreach { case (k, v) => sparkConf.set(k, v) }
+
+      if (!sparkConf.get(ALLOW_SPARK_CONTEXT_IN_EXECUTORS)) {
+        assertOnDriver()
+      }
+
       // Get the session from current thread's active session.
       var session = activeThreadSession.get()
       if ((session ne null) && !session.sparkContext.isStopped) {
@@ -919,9 +926,6 @@ object SparkSession extends Logging {
 
       // No active nor global default session. Create a new one.
       val sparkContext = userSuppliedContext.getOrElse {
-        val sparkConf = new SparkConf()
-        options.foreach { case (k, v) => sparkConf.set(k, v) }
-
         // set a random app name if not given.
         if (!sparkConf.contains("spark.app.name")) {
           sparkConf.setAppName(java.util.UUID.randomUUID().toString)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 0a522fdbdeed8..6983cda5a35cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.internal.config.ALLOW_SPARK_CONTEXT_IN_EXECUTORS
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -240,4 +241,27 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(session.conf.get(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31532-2")
     assert(session.conf.get(WAREHOUSE_PATH) === "SPARK-31532-db-2")
   }
+
+  test("SPARK-32160: Disallow to create SparkSession in executors if the config is set") {
+    val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate()
+
+    val error = intercept[SparkException] {
+      session.range(1).foreach { v =>
+        SparkSession.builder.master("local")
+          .config(ALLOW_SPARK_CONTEXT_IN_EXECUTORS.key, false).getOrCreate()
+        ()
+      }
+    }.getMessage()
+
+    assert(error.contains("SparkSession should only be created and accessed on the driver."))
+  }
+
+  test("SPARK-32160: Allow to create SparkSession in executors") {
+    val session = SparkSession.builder().master("local-cluster[3, 1, 1024]").getOrCreate()
+
+    session.range(1).foreach { v =>
+      SparkSession.builder.master("local").getOrCreate().stop()
+      ()
+    }
+  }
 }
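
Usage sketch (not part of the patch): the snippet below mirrors the new tests and shows how an application would rely on the flag once this change is applied. The object name, app name, and master URL are illustrative only; the flag defaults to true in this patch.

import org.apache.spark.sql.SparkSession

// Hypothetical example application, assuming the patch above is applied.
object AllowSessionInExecutorsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local-cluster[2, 1, 1024]")
      .appName("allow-session-in-executors-example")
      .getOrCreate()

    // With spark.driver.allowSparkContextInExecutors at its default (true), a session
    // created inside a task is no longer rejected by the SPARK-32160 guard.
    spark.range(1).foreach { _ =>
      SparkSession.builder().master("local").getOrCreate().stop()
    }

    // Setting the flag to false on the inner builder restores the guard; the call would
    // then fail with "SparkSession should only be created and accessed on the driver.":
    //   SparkSession.builder().master("local")
    //     .config("spark.driver.allowSparkContextInExecutors", "false")
    //     .getOrCreate()

    spark.stop()
  }
}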