166 changes: 133 additions & 33 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -17,17 +17,21 @@

package org.apache.spark.deploy

import java.io.{File, IOException}
import java.io._
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
import java.nio.file.Files
import java.security.PrivilegedExceptionAction
import java.security.{KeyStore, PrivilegedExceptionAction}
import java.security.cert.X509Certificate
import java.text.ParseException
import javax.net.ssl._

import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties

import com.google.common.io.ByteStreams
import org.apache.commons.io.FileUtils
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -310,33 +314,33 @@ object SparkSubmit extends CommandLineUtils {
RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
}

// In client mode, download remote files.
if (deployMode == CLIENT) {
val hadoopConf = new HadoopConfiguration()
args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
}

// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local.
// In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
if (args.isPython && !isYarnCluster && !isMesosCluster) {
if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}")
val hadoopConf = new HadoopConfiguration()
val targetDir = Files.createTempDirectory("tmp").toFile
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = {
FileUtils.deleteQuietly(targetDir)
}
val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",")
if (nonLocalPyFiles.nonEmpty) {
printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles")
}
}
})
// scalastyle:on runtimeaddshutdownhook

// Require all R files to be local
if (args.isR && !isYarnCluster && !isMesosCluster) {
if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}")
}
// Resolve glob path for different resources.
args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull
args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull
args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull
args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull

// In client mode, download remote files.
if (deployMode == CLIENT) {
args.primaryResource = Option(args.primaryResource).map {
downloadFile(_, targetDir, args.sparkProperties, hadoopConf)
}.orNull
args.jars = Option(args.jars).map {
downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
}.orNull
args.pyFiles = Option(args.pyFiles).map {
downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
}.orNull
}

// The following modes are not supported or applicable
@@ -841,36 +845,132 @@ object SparkSubmit extends CommandLineUtils {
* Download a list of remote files to temp local files. If the file is local, the original file
* will be returned.
* @param fileList A comma separated file list.
* @param targetDir A temporary directory into which downloaded files are saved
* @param sparkProperties Spark properties
* @return A comma separated local files list.
*/
private[deploy] def downloadFileList(
fileList: String,
targetDir: File,
Contributor:
Let's explain the meaning of each param.

sparkProperties: Map[String, String],
hadoopConf: HadoopConfiguration): String = {
require(fileList != null, "fileList cannot be null.")
fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
fileList.split(",")
.map(downloadFile(_, targetDir, sparkProperties, hadoopConf))
.mkString(",")
}

/**
* Download a file from the remote to a local temporary directory. If the input path points to
* a local path, returns it with no operation.
* @param path A file path from where the file will be downloaded.
* @param targetDir A temporary directory into which downloaded files are saved
* @param sparkProperties Spark properties
* @return Path of the downloaded file, or the original path if it is already local.
*/
private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
private[deploy] def downloadFile(
path: String,
targetDir: File,
Contributor:
Ditto.

sparkProperties: Map[String, String],
hadoopConf: HadoopConfiguration): String = {
require(path != null, "path cannot be null.")
val uri = Utils.resolveURI(path)
uri.getScheme match {
case "file" | "local" =>
path
case "file" | "local" => path
case "http" | "https" | "ftp" =>
val uc = uri.toURL.openConnection()
uc match {
case https: HttpsURLConnection =>
val trustStore = sparkProperties.get("spark.ssl.fs.trustStore")
Contributor:
Should we make this a common util?

Contributor:
Also, should we test against this?

Contributor Author:
There is a common util method in Utils, but using it here as before would initialize Logging prematurely, which causes the issues @vanzin mentioned, so it is reimplemented here. Do you have any thoughts?

.orElse(sparkProperties.get("spark.ssl.trustStore"))
Contributor:
Should we move these properties to internal/config?

Contributor Author:
These configurations are dynamic, namespaced as spark.ssl.${ns}.xxx, so it is not easy to leverage internal/config. Let me think about it.

Contributor:
Makes sense, let's keep it this way.

val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword")
.orElse(sparkProperties.get("spark.ssl.trustStorePassword"))
.map(_.toCharArray)
.orNull
val protocol = sparkProperties.get("spark.ssl.fs.protocol")
.orElse(sparkProperties.get("spark.ssl.protocol"))
if (protocol.isEmpty) {
printErrorAndExit("spark ssl protocol is required when enabling SSL connection.")
}

val trustStoreManagers = trustStore.map { t =>
var input: InputStream = null
try {
input = new FileInputStream(new File(t))
val ks = KeyStore.getInstance(KeyStore.getDefaultType)
ks.load(input, trustStorePwd)
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
tmf.init(ks)
tmf.getTrustManagers
} finally {
if (input != null) {
input.close()
input = null
}
}
}.getOrElse {
Array({
new X509TrustManager {
override def getAcceptedIssuers: Array[X509Certificate] = null
override def checkClientTrusted(
x509Certificates: Array[X509Certificate], s: String) {}
override def checkServerTrusted(
x509Certificates: Array[X509Certificate], s: String) {}
}: TrustManager
})
}
val sslContext = SSLContext.getInstance(protocol.get)
sslContext.init(null, trustStoreManagers, null)
https.setSSLSocketFactory(sslContext.getSocketFactory)
https.setHostnameVerifier(new HostnameVerifier {
override def verify(s: String, sslSession: SSLSession): Boolean = false
})

case _ =>
}

uc.setConnectTimeout(60 * 1000)
uc.setReadTimeout(60 * 1000)
uc.connect()
val in = uc.getInputStream
val fileName = new Path(uri).getName
val tempFile = new File(targetDir, fileName)
val out = new FileOutputStream(tempFile)
// scalastyle:off println
printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.")
// scalastyle:on println
try {
ByteStreams.copy(in, out)
} finally {
in.close()
out.close()
}
tempFile.toURI.toString
case _ =>
val fs = FileSystem.get(uri, hadoopConf)
val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
val tmpFile = new File(targetDir, new Path(uri).getName)
// scalastyle:off println
printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
// scalastyle:on println
fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
Utils.resolveURI(tmpFile.getAbsolutePath).toString
tmpFile.toURI.toString
}
}

private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
require(paths != null, "paths cannot be null.")
paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
val uri = Utils.resolveURI(path)
uri.getScheme match {
case "local" | "http" | "https" | "ftp" => Array(path)
case _ =>
val fs = FileSystem.get(uri, hadoopConf)
Option(fs.globStatus(new Path(uri))).map { status =>
status.filter(_.isFile).map(_.getPath.toUri.toString)
}.getOrElse(Array(path))
}
}.mkString(",")
}
}

/** Provides utility functions to be used inside SparkSubmit. */
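The review discussion on the trustStore lookup above asks whether the spark.ssl.fs.* / spark.ssl.* fallback could become a common utility. A minimal sketch of one possible shape for such a helper follows; it is not part of this PR, and the object and method names are invented purely for illustration.

import scala.collection.mutable

object SslPropertyResolver {
  // Resolve a namespaced SSL property such as "spark.ssl.fs.trustStore",
  // falling back to the global "spark.ssl.trustStore" when the namespaced
  // key is absent. Mirrors the inline lookups in downloadFile above.
  def resolve(
      sparkProperties: mutable.Map[String, String],
      ns: String,
      key: String): Option[String] = {
    sparkProperties.get(s"spark.ssl.$ns.$key")
      .orElse(sparkProperties.get(s"spark.ssl.$key"))
  }
}

// Example usage, matching the lookups performed for the "fs" namespace:
// val trustStore = SslPropertyResolver.resolve(props, "fs", "trustStore")
// val protocol   = SslPropertyResolver.resolve(props, "fs", "protocol")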
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -520,7 +520,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
| (Default: client).
| --class CLASS_NAME Your application's main class (for Java / Scala apps).
| --name NAME A name of your application.
| --jars JARS Comma-separated list of local jars to include on the driver
Contributor:
This is a new feature (support for non-local jars); shall we create a separate JIRA ticket?

Contributor:
And we should also add a test case for it.

Contributor Author:
Thanks @cloud-fan for your review. Supporting remote jars was addressed in #18078, but it left some code and docs to be updated, and it does not support downloading jars over HTTP(S) or FTP, so that support is added here as well. So basically this PR addresses the remaining issues from #18078 and adds glob support.

| --jars JARS Comma-separated list of jars to include on the driver
| and executor classpaths.
| --packages Comma-separated list of maven coordinates of jars to include
| on the driver and executor classpaths. Will search the local
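As the thread above notes, --jars (and the other resource options) now accept glob patterns and remote URIs in addition to local paths. Below is a small sketch of how the new resolution can be exercised programmatically, modeled on the test added in SparkSubmitSuite further down; the class name and jar paths are placeholders, and running it assumes a YARN-configured environment (HADOOP_CONF_DIR or YARN_CONF_DIR set). It lives in the org.apache.spark.deploy package, as the test does, because prepareSubmitEnvironment is package-private.

package org.apache.spark.deploy

object GlobJarsExample {
  def main(cmdArgs: Array[String]): Unit = {
    // Placeholder paths only; per this PR, http(s)/ftp/hdfs URIs are also accepted.
    val args = new SparkSubmitArguments(Seq(
      "--class", "com.example.MyApp",
      "--name", "globExample",
      "--master", "yarn",
      "--deploy-mode", "client",
      "--jars", "/opt/libs/*.jar",
      "/opt/app/my-app.jar"))

    // Globs are expanded and, in client mode, remote jars are downloaded to a
    // temporary directory before reaching the classpath; the resolved list is
    // expected to appear under spark.yarn.dist.jars, as the test asserts.
    val sysProps = SparkSubmit.prepareSubmitEnvironment(args)._3
    println(sysProps.get("spark.yarn.dist.jars"))
  }
}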
68 changes: 59 additions & 9 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.deploy
import java.io._
import java.net.URI
import java.nio.charset.StandardCharsets
import java.nio.file.Files

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import com.google.common.io.ByteStreams
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
@@ -42,7 +44,6 @@ import org.apache.spark.TestUtils.JavaSourceFromString
import org.apache.spark.scheduler.EventLoggingListener
import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils}


trait TestPrematureExit {
suite: SparkFunSuite =>

@@ -726,6 +727,47 @@ class SparkSubmitSuite
Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar"))
Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar"))
}

test("support glob path") {
val tmpJarDir = Utils.createTempDir()
val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir)
val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir)

val tmpFileDir = Utils.createTempDir()
val file1 = File.createTempFile("tmpFile1", "", tmpFileDir)
val file2 = File.createTempFile("tmpFile2", "", tmpFileDir)

val tmpPyFileDir = Utils.createTempDir()
val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir)
val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir)

val tmpArchiveDir = Utils.createTempDir()
val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir)
val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir)

val args = Seq(
"--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
"--name", "testApp",
"--master", "yarn",
"--deploy-mode", "client",
"--jars", s"${tmpJarDir.getAbsolutePath}/*.jar",
"--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*",
"--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*",
"--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip",
jar2.toString)

val appArgs = new SparkSubmitArguments(args)
val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
sysProps("spark.yarn.dist.jars").split(",").toSet should be
(Set(jar1.toURI.toString, jar2.toURI.toString))
sysProps("spark.yarn.dist.files").split(",").toSet should be
(Set(file1.toURI.toString, file2.toURI.toString))
sysProps("spark.submit.pyFiles").split(",").toSet should be
(Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath))
sysProps("spark.yarn.dist.archives").split(",").toSet should be
(Set(archive1.toURI.toString, archive2.toURI.toString))
}

// scalastyle:on println

private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
@@ -738,7 +780,7 @@ class SparkSubmitSuite
assert(outputUri.getScheme === "file")

// The file name is preserved.
assert(outputUri.getPath.endsWith(sourceUri.getPath))
assert(outputUri.getPath.endsWith(new Path(sourceUri).getName))
assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
FileUtils.readFileToString(new File(sourceUri.getPath)))
}
@@ -752,25 +794,29 @@ class SparkSubmitSuite

test("downloadFile - invalid url") {
intercept[IOException] {
SparkSubmit.downloadFile("abc:/my/file", new Configuration())
SparkSubmit.downloadFile(
"abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration())
}
}

test("downloadFile - file doesn't exist") {
val hadoopConf = new Configuration()
val tmpDir = Utils.createTempDir()
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
intercept[FileNotFoundException] {
SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf)
}
}

test("downloadFile does not download local file") {
// empty path is considered as local file.
assert(SparkSubmit.downloadFile("", new Configuration()) === "")
assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
val tmpDir = Files.createTempDirectory("tmp").toFile
assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "")
assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty,
new Configuration()) === "/local/file")
}

test("download one file to local") {
@@ -779,12 +825,14 @@ class SparkSubmitSuite
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
val tmpDir = Files.createTempDirectory("tmp").toFile
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
val outputPath =
SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf)
checkDownloadedFile(sourcePath, outputPath)
deleteTempOutputFile(outputPath)
}
@@ -795,12 +843,14 @@ class SparkSubmitSuite
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
val tmpDir = Files.createTempDirectory("tmp").toFile
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",")
val outputPaths = SparkSubmit.downloadFileList(
sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",")

assert(outputPaths.length === sourcePaths.length)
sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>