3 changes: 2 additions & 1 deletion bin/utils.sh
@@ -35,7 +35,8 @@ function gatherSparkSubmitOpts() {
--master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \
--conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
--driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
--total-executor-cores | --executor-cores | --queue | --num-executors | --archives)
--total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \
--proxy-user)
if [[ $# -lt 2 ]]; then
"$SUBMIT_USAGE_FUNCTION"
exit 1;
1 change: 1 addition & 0 deletions bin/windows-utils.cmd
@@ -33,6 +33,7 @@ SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--
SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"
SET opts="%opts:~1,-1% \<--proxy-user\>"

echo %1 | findstr %opts% >nul
if %ERRORLEVEL% equ 0 (
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text

import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.network.sasl.SecretKeyHolder
import org.apache.spark.util.Utils

/**
* Spark class responsible for security.
@@ -203,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf)

// always add the current user and SPARK_USER to the viewAcls
private val defaultAclUsers = Set[String](System.getProperty("user.name", ""),
Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty)
Utils.getCurrentUserName())

setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", ""))
16 changes: 5 additions & 11 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -191,7 +191,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli

// log out Spark Version in Spark driver log
logInfo(s"Running Spark version $SPARK_VERSION")

private[spark] val conf = config.clone()
conf.validateSettings()

@@ -335,11 +335,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
executorEnvs ++= conf.getExecutorEnv

// Set SPARK_USER for user who is running SparkContext.
val sparkUser = Option {
Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name"))
}.getOrElse {
SparkContext.SPARK_UNKNOWN_USER
}
val sparkUser = Utils.getCurrentUserName()
executorEnvs("SPARK_USER") = sparkUser

// Create and start the scheduler
@@ -826,7 +822,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
assertNotStopped()
// The call to new NewHadoopJob automatically adds security credentials to conf,
// The call to new NewHadoopJob automatically adds security credentials to conf,
// so we don't need to explicitly add them ourselves
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
@@ -1606,8 +1602,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)

/**
* Default min number of partitions for Hadoop RDDs when not given by user
/**
* Default min number of partitions for Hadoop RDDs when not given by user
* Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
* The reasons for this are discussed in https://github.com/mesos/spark/pull/718
*/
@@ -1824,8 +1820,6 @@ object SparkContext extends Logging {

private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"

private[spark] val SPARK_UNKNOWN_USER = "<unknown>"

private[spark] val DRIVER_IDENTIFIER = "<driver>"

// The following deprecated objects have already been copied to `object AccumulatorParam` to
19 changes: 7 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -52,18 +52,13 @@ class SparkHadoopUtil extends Logging {
* do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems
*/
def runAsSparkUser(func: () => Unit) {
val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
if (user != SparkContext.SPARK_UNKNOWN_USER) {
logDebug("running as user: " + user)
val ugi = UserGroupInformation.createRemoteUser(user)
transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
ugi.doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
} else {
logDebug("running as SPARK_UNKNOWN_USER")
func()
}
val user = Utils.getCurrentUserName()
logDebug("running as user: " + user)
val ugi = UserGroupInformation.createRemoteUser(user)
transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
ugi.doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
}

def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
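For context, a minimal usage sketch of the simplified method (the `SparkHadoopUtil.get` accessor is assumed from the existing API; the body is hypothetical): after this change `runAsSparkUser` always wraps the supplied function in a UGI `doAs` for the user reported by `Utils.getCurrentUserName()`, instead of special-casing the old `SPARK_UNKNOWN_USER` sentinel.

```scala
import org.apache.spark.deploy.SparkHadoopUtil

// Hypothetical caller: executor-side code that needs to touch HDFS as the
// submitting user. The doAs wrapper is now applied unconditionally, and
// Hadoop credentials are transferred to the remote UGI before func() runs.
SparkHadoopUtil.get.runAsSparkUser { () =>
  // Hadoop FileSystem calls here run as Utils.getCurrentUserName(),
  // i.e. SPARK_USER if it is set, otherwise the current UGI user.
  println("running as the Spark user")
}
```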
56 changes: 47 additions & 9 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,12 +18,14 @@
package org.apache.spark.deploy

import java.io.{File, PrintStream}
import java.lang.reflect.{InvocationTargetException, Modifier}
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
import java.security.PrivilegedExceptionAction

import scala.collection.mutable.{ArrayBuffer, HashMap, Map}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
import org.apache.ivy.core.module.descriptor._
@@ -79,7 +81,7 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101

// Exposed for testing
private[spark] var exitFn: () => Unit = () => System.exit(-1)
private[spark] var exitFn: () => Unit = () => System.exit(1)
private[spark] var printStream: PrintStream = System.err
private[spark] def printWarning(str: String) = printStream.println("Warning: " + str)
private[spark] def printErrorAndExit(str: String) = {
@@ -126,6 +128,34 @@ object SparkSubmit {
*/
private[spark] def submit(args: SparkSubmitArguments): Unit = {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)

def doRunMain(): Unit = {
if (args.proxyUser != null) {
val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser,
UserGroupInformation.getCurrentUser())
try {
proxyUser.doAs(new PrivilegedExceptionAction[Unit]() {
override def run(): Unit = {
runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
}
})
} catch {
case e: Exception =>
// Hadoop's AuthorizationException suppresses the exception's stack trace, which
// makes the message printed to the output by the JVM not very helpful. Instead,
// detect exceptions with empty stack traces here, and treat them differently.
if (e.getStackTrace().length == 0) {
printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
exitFn()
} else {
throw e
}
}
} else {
runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
}
}

// In standalone cluster mode, there are two submission gateways:
// (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
// (2) The new REST-based gateway introduced in Spark 1.3
@@ -134,7 +164,7 @@
if (args.isStandaloneCluster && args.useRest) {
try {
printStream.println("Running Spark using the REST application submission protocol.")
runMain(childArgs, childClasspath, sysProps, childMainClass)
doRunMain()
} catch {
// Fail over to use the legacy submission gateway
case e: SubmitRestConnectionException =>
@@ -145,7 +175,7 @@
}
// In all other modes, just run the main class as prepared
} else {
runMain(childArgs, childClasspath, sysProps, childMainClass)
doRunMain()
}
}
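As a side note, impersonation here relies on Hadoop's standard proxy-user mechanism: the real submitting user must be authorized to impersonate the target user on the cluster (typically via `hadoop.proxyuser.<user>.hosts` / `.groups` in core-site.xml), otherwise the `doAs` call fails with an `AuthorizationException` whose stack trace is empty, which is exactly the case the catch block above reports and exits on. A self-contained sketch of the pattern (hypothetical names, not the SparkSubmit code itself):

```scala
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.security.UserGroupInformation

object ProxyUserSketch {
  def main(args: Array[String]): Unit = {
    // The real (authenticated) user submitting the job.
    val realUser = UserGroupInformation.getCurrentUser()
    // The effective user to impersonate; "alice" must be permitted by the
    // cluster's hadoop.proxyuser.* configuration.
    val proxyUgi = UserGroupInformation.createProxyUser("alice", realUser)

    proxyUgi.doAs(new PrivilegedExceptionAction[Unit] {
      override def run(): Unit = {
        // Code here sees "alice" as the current UGI user.
        println(UserGroupInformation.getCurrentUser().getUserName())
      }
    })
  }
}
```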

@@ -457,7 +487,7 @@
childClasspath: Seq[String],
sysProps: Map[String, String],
childMainClass: String,
verbose: Boolean = false) {
verbose: Boolean): Unit = {
if (verbose) {
printStream.println(s"Main class:\n$childMainClass")
printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
@@ -507,13 +537,21 @@
if (!Modifier.isStatic(mainMethod.getModifiers)) {
throw new IllegalStateException("The main method in the given main class must be static")
}

def findCause(t: Throwable): Throwable = t match {
case e: UndeclaredThrowableException =>
if (e.getCause() != null) findCause(e.getCause()) else e
case e: InvocationTargetException =>
if (e.getCause() != null) findCause(e.getCause()) else e
case e: Throwable =>
e
}

try {
mainMethod.invoke(null, childArgs.toArray)
} catch {
case e: InvocationTargetException => e.getCause match {
case cause: Throwable => throw cause
case null => throw e
}
case t: Throwable =>
throw findCause(t)
}
}

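To illustrate why both wrapper types matter (a standalone sketch, not code from the patch): `Method.invoke` wraps anything thrown by the user's main method in an `InvocationTargetException`, and Hadoop's `UserGroupInformation.doAs` can additionally wrap checked exceptions in an `UndeclaredThrowableException`, so unwrapping lets the user see the original failure rather than a reflection wrapper.

```scala
import java.lang.reflect.{InvocationTargetException, UndeclaredThrowableException}

object FindCauseDemo extends App {
  // Same shape as the helper added above.
  def findCause(t: Throwable): Throwable = t match {
    case e: UndeclaredThrowableException if e.getCause != null => findCause(e.getCause)
    case e: InvocationTargetException if e.getCause != null => findCause(e.getCause)
    case e => e
  }

  // A doubly wrapped failure, as reflection inside a doAs block can produce it.
  val original = new IllegalArgumentException("bad argument")
  val wrapped = new UndeclaredThrowableException(new InvocationTargetException(original))

  assert(findCause(wrapped) eq original) // the caller rethrows the real error
}
```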
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -57,6 +57,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var pyFiles: String = null
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
var proxyUser: String = null

// Standalone cluster mode only
var supervise: Boolean = false
@@ -405,6 +406,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
parse(tail)

case ("--proxy-user") :: value :: tail =>
proxyUser = value
parse(tail)

case ("--help" | "-h") :: tail =>
printUsageAndExit(0)

@@ -476,6 +481,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
|
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
|
| --proxy-user NAME User to impersonate when submitting the application.
|
| --help, -h Show this help message and exit
| --verbose, -v Print additional debug output
|
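For reference, the new flag is passed like any other spark-submit option that takes a value, e.g. `./bin/spark-submit --proxy-user alice --class org.example.MyApp app.jar` (the class and jar names here are placeholders). On a secured cluster the submitting user must additionally be allowed to impersonate `alice` through Hadoop's proxy-user settings, as noted above.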
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,6 +38,7 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._
@@ -1986,6 +1987,16 @@ private[spark] object Utils extends Logging {
throw new SparkException("Invalid master URL: " + sparkUrl, e)
}
}

/**
* Returns the current user name. This is the currently logged in user, unless that's been
* overridden by the `SPARK_USER` environment variable.
*/
def getCurrentUserName(): String = {
Option(System.getenv("SPARK_USER"))
.getOrElse(UserGroupInformation.getCurrentUser().getUserName())
}

}

/**
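A quick sketch of how the new helper resolves the user (hypothetical standalone object; `UserGroupInformation.getCurrentUser().getUserName()` is the Hadoop API the method delegates to):

```scala
import org.apache.hadoop.security.UserGroupInformation

object CurrentUserSketch extends App {
  // Mirrors the helper added above: SPARK_USER wins when it is set; otherwise
  // the Hadoop UGI user is used, which itself falls back to the OS login user
  // when no Kerberos/security configuration is present.
  def getCurrentUserName(): String =
    Option(System.getenv("SPARK_USER"))
      .getOrElse(UserGroupInformation.getCurrentUser().getUserName())

  println(getCurrentUserName()) // e.g. "alice" if SPARK_USER=alice is exported
}
```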