48 changes: 47 additions & 1 deletion core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy
import java.io.{File, IOException}
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
import java.nio.file.Files
import java.security.PrivilegedExceptionAction
import java.text.ParseException

@@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties

import org.apache.commons.lang3.StringUtils
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
@@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils {
RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
}

// In client mode, download remote files.
if (deployMode == CLIENT) {

Contributor: Sorry, I may not have enough background knowledge. Why do we only do this for client mode?

Contributor (author): It seems YARN/Mesos cluster mode can already handle remote files. I haven't tested it, because we are using client mode.

val hadoopConf = new HadoopConfiguration()
args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
}

// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local.
// In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
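
The argument rewriting above leans on the Option(...).map(...).orNull idiom: the download
step runs only when an argument was actually supplied, and null (these are Java-style
nullable Strings) passes through untouched. A minimal standalone sketch, with a
hypothetical stand-in for downloadFile/downloadFileList:

    def transform(s: String): String = s.toUpperCase  // stand-in for the download step
    val set: String = "a.jar"
    val unset: String = null
    assert(Option(set).map(transform).orNull == "A.JAR") // non-null is transformed
    assert(Option(unset).map(transform).orNull == null)  // null passes through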
@@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils {
.mkString(",")
if (merged == "") null else merged
}

/**
* Download a list of remote files to temporary local files. If a file is already local, the
* original path is returned.
* @param fileList A comma-separated list of file paths.
* @param hadoopConf The Hadoop configuration used to access remote file systems.
* @return A comma-separated list of local file paths.
*/
private[deploy] def downloadFileList(
fileList: String,
hadoopConf: HadoopConfiguration): String = {
require(fileList != null, "fileList cannot be null.")
fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
}

/**
* Download a file from a remote location to a local temporary directory. If the input path
* is already local, it is returned unchanged.
*/
private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
require(path != null, "path cannot be null.")
val uri = Utils.resolveURI(path)
uri.getScheme match {
case "file" | "local" =>
path

case _ =>
val fs = FileSystem.get(uri, hadoopConf)
val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
// scalastyle:off println
printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
// scalastyle:on println
fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
Utils.resolveURI(tmpFile.getAbsolutePath).toString
}
}
}
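
Every path above goes through Utils.resolveURI, which treats scheme-less paths as local
files, so only the "file" and "local" schemes short-circuit the download. A minimal
standalone sketch of that dispatch, approximating resolveURI with java.net.URI:

    import java.net.URI

    // Paths with no scheme, or with the "file"/"local" schemes, are already local
    // and are returned unchanged; anything else is copied to a temp directory
    // through the Hadoop FileSystem API.
    def isAlreadyLocal(path: String): Boolean = {
      val scheme = Option(new URI(path).getScheme).getOrElse("file")
      scheme == "file" || scheme == "local"
    }

    assert(isAlreadyLocal("/tmp/app.jar"))          // no scheme defaults to "file"
    assert(isAlreadyLocal("local:/opt/lib.jar"))
    assert(!isAlreadyLocal("s3a://bucket/app.jar"))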

/** Provides utility functions to be used inside SparkSubmit. */
95 changes: 93 additions & 2 deletions core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -18,12 +18,15 @@
package org.apache.spark.deploy

import java.io._
import java.net.URI
import java.nio.charset.StandardCharsets

import scala.collection.mutable.ArrayBuffer
import scala.io.Source

import com.google.common.io.ByteStreams
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.Timeouts
@@ -535,7 +538,7 @@ class SparkSubmitSuite

test("resolves command line argument paths correctly") {
val jars = "/jar1,/jar2" // --jars
val files = "hdfs:/file1,file2" // --files
val files = "local:/file1,file2" // --files

Contributor: Could you expand on why we are changing this?

Contributor (author): To keep the test from trying to download the file from HDFS.

Contributor: It is kind of difficult to test downloading a file from HDFS right now, but we should cover this scenario in the future.

val archives = "file:/archive1,archive2" // --archives
val pyFiles = "py-file1,py-file2" // --py-files

@@ -587,7 +590,7 @@

test("resolves config paths correctly") {
val jars = "/jar1,/jar2" // spark.jars
val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files
val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles

@@ -705,6 +708,87 @@
}
// scalastyle:on println

private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
if (sourcePath == outputPath) {
return
}

val sourceUri = new URI(sourcePath)
val outputUri = new URI(outputPath)
assert(outputUri.getScheme === "file")

// The path and filename are preserved.
assert(outputUri.getPath.endsWith(sourceUri.getPath))
assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
FileUtils.readFileToString(new File(sourceUri.getPath)))
}

private def deleteTempOutputFile(outputPath: String): Unit = {
val outputFile = new File(new URI(outputPath).getPath)
if (outputFile.exists) {
outputFile.delete()
}
}

test("downloadFile - invalid url") {
intercept[IOException] {
SparkSubmit.downloadFile("abc:/my/file", new Configuration())
}
}

test("downloadFile - file doesn't exist") {
val hadoopConf = new Configuration()
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
intercept[FileNotFoundException] {
SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
}
}

test("downloadFile does not download local file") {
// An empty path is considered a local file.
assert(SparkSubmit.downloadFile("", new Configuration()) === "")
assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
}

test("download one file to local") {
val jarFile = File.createTempFile("test", ".jar")
jarFile.deleteOnExit()
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
checkDownloadedFile(sourcePath, outputPath)
deleteTempOutputFile(outputPath)
}

test("download list of files to local") {
val jarFile = File.createTempFile("test", ".jar")
jarFile.deleteOnExit()
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
// Set s3a implementation to local file system for testing.
hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
// Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true")
val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",")

assert(outputPaths.length === sourcePaths.length)
sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>
checkDownloadedFile(sourcePath, outputPath)
deleteTempOutputFile(outputPath)
}
}

// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
private def runSparkSubmit(args: Seq[String]): Unit = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -807,3 +891,10 @@ object UserClasspathFirstTest {
}
}
}

class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem {
override def copyToLocalFile(src: Path, dst: Path): Unit = {
// Ignore the scheme for testing.
super.copyToLocalFile(new Path(src.toUri.getPath), dst)
}
}
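
For reference: Hadoop resolves the implementation class for a URI scheme from the
fs.<scheme>.impl configuration key, which is why the tests point fs.s3a.impl at this
class and disable the instance cache. A minimal sketch mirroring that setup
(hypothetical usage, not part of the patch):

    import java.net.URI
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.FileSystem

    val conf = new Configuration()
    conf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
    conf.set("fs.s3a.impl.disable.cache", "true")
    // Returns a TestFileSystem instance whose copyToLocalFile ignores the scheme.
    val fs = FileSystem.get(new URI("s3a://bucket/key"), conf)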