7 changes: 5 additions & 2 deletions yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -804,8 +804,11 @@ private[spark] class Client(
env("SPARK_JAVA_OPTS") = value
}
// propagate PYSPARK_DRIVER_PYTHON and PYSPARK_PYTHON to driver in cluster mode
sys.env.get("PYSPARK_DRIVER_PYTHON").foreach(env("PYSPARK_DRIVER_PYTHON") = _)
sys.env.get("PYSPARK_PYTHON").foreach(env("PYSPARK_PYTHON") = _)
Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname =>
if (!env.contains(envname)) {
sys.env.get(envname).foreach(env(envname) = _)
}
}
}

sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp =>
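Why the guard matters: by the time this block runs, `env` has already been populated from `spark.yarn.appMasterEnv.*`, so checking `env.contains` lets an explicit conf win over the submitter's shell environment instead of being clobbered by it. Below is a minimal, self-contained sketch of that precedence rule; the `EnvPrecedence` object, `propagate` helper, and interpreter paths are hypothetical, not Spark code.

```scala
import scala.collection.mutable

object EnvPrecedence {
  // Copy the Python env vars from the local environment only when the
  // application-master env (e.g. spark.yarn.appMasterEnv.*) has not set them.
  def propagate(env: mutable.Map[String, String], localEnv: Map[String, String]): Unit = {
    Seq("PYSPARK_DRIVER_PYTHON", "PYSPARK_PYTHON").foreach { envname =>
      if (!env.contains(envname)) {
        localEnv.get(envname).foreach(env(envname) = _)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // Value set via spark.yarn.appMasterEnv.PYSPARK_PYTHON (illustrative path).
    val env = mutable.Map("PYSPARK_PYTHON" -> "/opt/conda/bin/python")
    // The submitter's shell environment.
    val local = Map("PYSPARK_PYTHON" -> "python2", "PYSPARK_DRIVER_PYTHON" -> "python2")
    propagate(env, local)
    // PYSPARK_PYTHON keeps the appMasterEnv value; PYSPARK_DRIVER_PYTHON
    // falls back to the local one.
    println(env)
  }
}
```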
26 changes: 22 additions & 4 deletions yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -128,6 +128,20 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     testPySpark(false)
   }
 
+  test("run Python application in yarn-cluster mode using " +
+    "spark.yarn.appMasterEnv to override local envvar") {
+    testPySpark(
+      clientMode = false,
+      extraConf = Map(
+        "spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
+          -> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python"),
+        "spark.yarn.appMasterEnv.PYSPARK_PYTHON"
+          -> sys.env.getOrElse("PYSPARK_PYTHON", "python")),
+      extraEnv = Map(
+        "PYSPARK_DRIVER_PYTHON" -> "not python",
+        "PYSPARK_PYTHON" -> "not python"))
+  }
+
   test("user class path first in client mode") {
     testUseClassPathFirst(true)
   }
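The new test exercises the fix end to end: the deliberately broken `"not python"` values in `extraEnv` must lose to the `spark.yarn.appMasterEnv.*` entries in `extraConf`, or the Python app fails to launch. For reference, a hedged sketch of how a user would rely on this behavior from application code; the object name and interpreter path are illustrative.

```scala
import org.apache.spark.SparkConf

object SubmitWithPinnedPython {
  def main(args: Array[String]): Unit = {
    // Force the application master (the driver in yarn-cluster mode) to use a
    // specific interpreter regardless of PYSPARK_PYTHON in the submitting
    // shell; with this patch the conf value takes precedence.
    val conf = new SparkConf()
      .set("spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON", "/opt/conda/bin/python")
      .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "/opt/conda/bin/python")
  }
}
```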
@@ -188,7 +202,10 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     checkResult(finalState, executorResult, "ORIGINAL")
   }
 
-  private def testPySpark(clientMode: Boolean): Unit = {
+  private def testPySpark(
+      clientMode: Boolean,
+      extraConf: Map[String, String] = Map(),
+      extraEnv: Map[String, String] = Map()): Unit = {
     val primaryPyFile = new File(tempDir, "test.py")
     Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8)
 
@@ -199,9 +216,9 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     val pythonPath = Seq(
       s"$sparkHome/python/lib/py4j-0.10.1-src.zip",
       s"$sparkHome/python")
-    val extraEnv = Map(
+    val extraEnvVars = Map(
       "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
-      "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator))
+      "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) ++ extraEnv
 
     val moduleDir =
       if (clientMode) {
@@ -223,7 +240,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     val finalState = runSpark(clientMode, primaryPyFile.getAbsolutePath(),
       sparkArgs = Seq("--py-files" -> pyFiles),
       appArgs = Seq(result.getAbsolutePath()),
-      extraEnv = extraEnv)
+      extraEnv = extraEnvVars,
+      extraConf = extraConf)
     checkResult(finalState, result)
   }
 